Recently I have been trying to reproduce the results of https://github.com/zexuanqiu/CIBHash, but I run into a loss explosion every time right after evaluation. I am using the CIFAR-10 dataset from the official site.
For instance,
- Run the code with `python main.py cifar16 --train --dataset cifar10 --encode_length 16 --cuda` (which keeps the default `validate_frequency=20`): the code evaluates itself after epoch 20 and then continues training, and the loss explosion occurs in epoch 21.
- Run the code with `python main.py cifar16 --train --dataset cifar10 --encode_length 16 --cuda --validate_frequency=3` (setting `validate_frequency=3`): the loss explosion reliably occurs in epoch 4.
Here's its `run_training_session` function:
```python
def run_training_session(self, run_num, logger):
    self.train()

    # Scramble hyperparameters if number of runs is greater than 1.
    if self.hparams.num_runs > 1:
        logger.log('RANDOM RUN: %d/%d' % (run_num, self.hparams.num_runs))
        for hparam, values in self.get_hparams_grid().items():
            assert hasattr(self.hparams, hparam)
            self.hparams.__dict__[hparam] = random.choice(values)

    random.seed(self.hparams.seed)
    torch.manual_seed(self.hparams.seed)
    self.define_parameters()

    # if encode_length is 16, then at least 80 epochs!
    if self.hparams.encode_length == 16:
        self.hparams.epochs = max(80, self.hparams.epochs)

    logger.log('hparams: %s' % self.flag_hparams())
    device = torch.device('cuda' if self.hparams.cuda else 'cpu')
    self.to(device)

    optimizer = self.configure_optimizers()
    train_loader, val_loader, _, database_loader = self.data.get_loaders(
        self.hparams.batch_size, self.hparams.num_workers,
        shuffle_train=True, get_test=False)

    best_val_perf = float('-inf')
    best_state_dict = None
    bad_epochs = 0

    try:
        for epoch in range(1, self.hparams.epochs + 1):
            forward_sum = {}
            num_steps = 0

            for batch_num, batch in enumerate(train_loader):
                optimizer.zero_grad()

                imgi, imgj, _ = batch
                imgi = imgi.to(device)
                imgj = imgj.to(device)

                forward = self.forward(imgi, imgj, device)

                for key in forward:
                    if key in forward_sum:
                        forward_sum[key] += forward[key]
                    else:
                        forward_sum[key] = forward[key]
                num_steps += 1

                if math.isnan(forward_sum['loss']):
                    logger.log('Stopping epoch because loss is NaN')
                    break

                forward['loss'].backward()
                optimizer.step()

            if math.isnan(forward_sum['loss']):
                logger.log('Stopping training session because loss is NaN')
                break

            logger.log('End of epoch {:3d}'.format(epoch), False)
            logger.log(''.join([' | {:s} {:8.4f}'.format(
                key, forward_sum[key] / num_steps) for key in forward_sum]), True)

            if epoch % self.hparams.validate_frequency == 0:
                print('evaluating...')
                val_perf = self.evaluate(database_loader, val_loader,
                                         self.data.topK, device)
                logger.log(' | val perf {:8.4f}'.format(val_perf), False)

                if val_perf > best_val_perf:
                    best_val_perf = val_perf
                    bad_epochs = 0
                    logger.log('\t\t*Best model so far, deep copying*')
                    best_state_dict = deepcopy(self.state_dict())
                else:
                    bad_epochs += 1
                    logger.log('\t\tBad epoch %d' % bad_epochs)
                if bad_epochs > self.hparams.num_bad_epochs:
                    break

    except KeyboardInterrupt:
        logger.log('-' * 89)
        logger.log('Exiting from training early')

    return best_state_dict, best_val_perf
```

And here's the `forward` function of the CIBHash model:
```python
def forward(self, imgi, imgj, device):
    imgi = self.vgg.features(imgi)
    imgi = imgi.view(imgi.size(0), -1)
    imgi = self.vgg.classifier(imgi)
    prob_i = torch.sigmoid(self.encoder(imgi))
    z_i = hash_layer(prob_i - torch.empty_like(prob_i).uniform_().to(prob_i.device))

    imgj = self.vgg.features(imgj)
    imgj = imgj.view(imgj.size(0), -1)
    imgj = self.vgg.classifier(imgj)
    prob_j = torch.sigmoid(self.encoder(imgj))
    z_j = hash_layer(prob_j - torch.empty_like(prob_j).uniform_().to(prob_j.device))

    kl_loss = (self.compute_kl(prob_i, prob_j) + self.compute_kl(prob_j, prob_i)) / 2
    contra_loss = self.criterion(z_i, z_j, device)
    loss = contra_loss + self.hparams.weight * kl_loss

    return {'loss': loss, 'contra_loss': contra_loss, 'kl_loss': kl_loss}
```

I have tried replacing `z_i` and `z_j` as suggested in https://github.com/zexuanqiu/CIBHash/issues/6, but it did not prevent the NaN problem.
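One factor I suspect (this is only a guess, and the function below is my own reconstruction for illustration, not the repository's `compute_kl`) is that `prob_i`/`prob_j` can saturate at exactly 0 or 1, so a Bernoulli-style KL term would hit `log(0)` and produce inf/NaN. Clamping the probabilities would be one way to rule this out:

```python
import torch

EPS = 1e-6  # assumed value; keeps probabilities away from exactly 0 and 1

def bernoulli_kl(prob_a, prob_b, eps=EPS):
    """KL(Bernoulli(prob_a) || Bernoulli(prob_b)) per bit, summed over bits and averaged over the batch.

    A stand-in for compute_kl, written only to show where log(0) could appear
    and how clamping avoids it; it may not match the repository's implementation.
    """
    prob_a = prob_a.clamp(eps, 1 - eps)
    prob_b = prob_b.clamp(eps, 1 - eps)
    kl = prob_a * (prob_a / prob_b).log() + (1 - prob_a) * ((1 - prob_a) / (1 - prob_b)).log()
    return kl.sum(dim=1).mean()
```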
I have also tried gradient clipping, but it did not help.
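Concretely, the clipping I tried was along these lines, inserted between `backward()` and `step()` in `run_training_session` (the max-norm value 1.0 is just my own choice, not something from the repo):

```python
# inside the batch loop of run_training_session
forward['loss'].backward()
# clip the global gradient norm before the optimizer step; 1.0 is an assumed value
torch.nn.utils.clip_grad_norm_(self.parameters(), max_norm=1.0)
optimizer.step()
```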
According to the author's reply, they did not encounter any NaN problem when training the model (https://github.com/zexuanqiu/CIBHash/issues/7).
I expect the code to finish the training session without the loss becoming NaN. Could anyone tell me what factors might cause this problem, or whether there is a potential solution to the NaN loss?
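In the meantime, my plan is to instrument the training step roughly as below to find out which term turns NaN first (anomaly detection plus per-term checks; where exactly to place the checks is just my guess):

```python
import torch

# Enable once before training; backward() will then report the op that produced the NaN.
torch.autograd.set_detect_anomaly(True)

# Inside the batch loop, after forward() and before backward():
forward = self.forward(imgi, imgj, device)
for key, value in forward.items():
    # each value is a scalar tensor ('loss', 'contra_loss', 'kl_loss')
    if torch.isnan(value).any():
        logger.log('NaN first appeared in term %s at batch %d' % (key, batch_num))
```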
