Why do I keep running into NaN loss when training the CIBHash model?

Recently I have been trying to reproduce the results of https://github.com/zexuanqiu/CIBHash; however, I run into a loss explosion (NaN) every time right after evaluation. I am using the CIFAR-10 dataset from the official site.

For instance,

  • Running python main.py cifar16 --train --dataset cifar10 --encode_length 16 --cuda (with the default validate_frequency=20), the code evaluates itself after epoch 20 and then continues training; the loss explodes in epoch 21.
  • Running python main.py cifar16 --train --dataset cifar10 --encode_length 16 --cuda --validate_frequency=3, the loss consistently explodes in epoch 4.

Sample output: [screenshot of the training log]

Here's its run_training_session function:

def run_training_session(self, run_num, logger):
    self.train()

    # Scramble hyperparameters if number of runs is greater than 1.
    if self.hparams.num_runs > 1:
        logger.log('RANDOM RUN: %d/%d' % (run_num, self.hparams.num_runs))
        for hparam, values in self.get_hparams_grid().items():
            assert hasattr(self.hparams, hparam)
            self.hparams.__dict__[hparam] = random.choice(values)

    random.seed(self.hparams.seed)
    torch.manual_seed(self.hparams.seed)
    self.define_parameters()

    # if encode_length is 16, then at least 80 epochs!
    if self.hparams.encode_length == 16:
        self.hparams.epochs = max(80, self.hparams.epochs)

    logger.log('hparams: %s' % self.flag_hparams())
    device = torch.device('cuda' if self.hparams.cuda else 'cpu')
    self.to(device)

    optimizer = self.configure_optimizers()
    train_loader, val_loader, _, database_loader = self.data.get_loaders(
        self.hparams.batch_size, self.hparams.num_workers,
        shuffle_train=True, get_test=False)
    best_val_perf = float('-inf')
    best_state_dict = None
    bad_epochs = 0

    try:
        for epoch in range(1, self.hparams.epochs + 1):
            forward_sum = {}
            num_steps = 0
            for batch_num, batch in enumerate(train_loader):
                optimizer.zero_grad()
                imgi, imgj, _ = batch
                imgi = imgi.to(device)
                imgj = imgj.to(device)

                forward = self.forward(imgi, imgj, device)
                for key in forward:
                    if key in forward_sum:
                        forward_sum[key] += forward[key]
                    else:
                        forward_sum[key] = forward[key]
                num_steps += 1

                if math.isnan(forward_sum['loss']):
                    logger.log('Stopping epoch because loss is NaN')
                    break

                forward['loss'].backward()
                optimizer.step()

            if math.isnan(forward_sum['loss']):
                logger.log('Stopping training session because loss is NaN')
                break

            logger.log('End of epoch {:3d}'.format(epoch), False)
            logger.log(''.join([' | {:s} {:8.4f}'.format(
                key, forward_sum[key] / num_steps)
                                for key in forward_sum]), True)

            if epoch % self.hparams.validate_frequency == 0:
                print('evaluating...')
                val_perf = self.evaluate(database_loader, val_loader, self.data.topK, device)
                logger.log(' | val perf {:8.4f}'.format(val_perf), False)

                if val_perf > best_val_perf:
                    best_val_perf = val_perf
                    bad_epochs = 0
                    logger.log('\t\t*Best model so far, deep copying*')
                    best_state_dict = deepcopy(self.state_dict())
                else:
                    bad_epochs += 1
                    logger.log('\t\tBad epoch %d' % bad_epochs)
                if bad_epochs > self.hparams.num_bad_epochs:
                    break

    except KeyboardInterrupt:
        logger.log('-' * 89)
        logger.log('Exiting from training early')

    return best_state_dict, best_val_perf
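
To pinpoint which operation first produces the NaN, PyTorch's anomaly detection can be enabled before training. The set_detect_anomaly call is standard PyTorch, but where to place it and the extra per-term check below are my own debugging sketch, not part of the repo:

import torch

# Debugging sketch: make autograd raise an error at the first backward op that
# produces NaN (this slows training, so use it only while debugging).
torch.autograd.set_detect_anomaly(True)

# Inside the batch loop, the individual loss terms could also be checked, e.g.:
# forward = self.forward(imgi, imgj, device)
# for key, value in forward.items():
#     if torch.isnan(value).any():
#         logger.log('NaN in %s at epoch %d, batch %d' % (key, epoch, batch_num))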

And here's the forward function of the CIBHash model:

def forward(self, imgi, imgj, device):
    imgi = self.vgg.features(imgi)
    imgi = imgi.view(imgi.size(0), -1)
    imgi = self.vgg.classifier(imgi)
    prob_i = torch.sigmoid(self.encoder(imgi))
    z_i = hash_layer(prob_i - torch.empty_like(prob_i).uniform_().to(prob_i.device))

    imgj = self.vgg.features(imgj)
    imgj = imgj.view(imgj.size(0), -1)
    imgj = self.vgg.classifier(imgj)
    prob_j = torch.sigmoid(self.encoder(imgj))
    z_j = hash_layer(prob_j - torch.empty_like(prob_j).uniform_().to(prob_j.device))

    kl_loss = (self.compute_kl(prob_i, prob_j) + self.compute_kl(prob_j, prob_i)) / 2
    contra_loss = self.criterion(z_i, z_j, device)
    loss = contra_loss + self.hparams.weight * kl_loss

    return {'loss': loss, 'contra_loss': contra_loss, 'kl_loss': kl_loss}
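
One common source of NaN in a loss like this is the KL term hitting log(0) once the sigmoid saturates to exactly 0 or 1. I don't know whether the repo's compute_kl already guards against that, but a clamped Bernoulli KL would look roughly like the sketch below (the eps value and the function name are my own, hypothetical):

import torch

def bernoulli_kl_stable(p, q, eps=1e-6):
    # Keep probabilities strictly inside (0, 1) so the logs stay finite.
    p = p.clamp(eps, 1.0 - eps)
    q = q.clamp(eps, 1.0 - eps)
    # Element-wise KL between Bernoulli(p) and Bernoulli(q), summed over bits
    # and averaged over the batch.
    kl = p * (p / q).log() + (1.0 - p) * ((1.0 - p) / (1.0 - q)).log()
    return kl.sum(dim=1).mean()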

I have tried replacing z_i and z_j as instructed in https://github.com/zexuanqiu/CIBHash/issues/6; however, that did not prevent the NaN problem.

I have also tried gradient clipping, but it did not help.
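
For reference, this is roughly how I inserted the clipping between backward() and step() in run_training_session (the max_norm=1.0 value is an arbitrary choice of mine):

forward['loss'].backward()
# Clip the global gradient norm before the optimizer update.
torch.nn.utils.clip_grad_norm_(self.parameters(), max_norm=1.0)
optimizer.step()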

According to the author's reply, they did not run into any NaN problem when training the model (https://github.com/zexuanqiu/CIBHash/issues/7).

I expect the training session to finish without the loss becoming NaN. Could anyone tell me what might be causing this, or suggest a potential fix for the NaN loss?

