Channel: Active questions tagged python - Stack Overflow

PyTorch CNN model's inference performance depends on batch size


I have a CNN model that is essentially VGG16. I train it on my own data and then use the trained model to run inference on the same dataset as part of model evaluation. The code is roughly as follows:

batch_size = 16
misclassified_count = 0
data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
with torch.no_grad():
    # Loop through the images in batches
    for batch_idx, (images, y_trues) in enumerate(data_loader):
        images = images.to(device)
        output = vgg6(images)
        _, predicted = torch.max(output, dim=1)
        for i, pred_label in enumerate(predicted):
            y_true = y_trues[i].item()
            pred_label = pred_label.item()
            if pred_label == y_true:
                continue
            misclassified_count += 1
print(f'{misclassified_count} misclassified samples found')

The problem is that as I increase batch_size, the number of misclassified samples decreases:

batch_size=1, 2872 misclassified samples found
batch_size=2, 2133 misclassified samples found
batch_size=4, 1637 misclassified samples found
batch_size=8, 1364 misclassified samples found
batch_size=16, 1097 misclassified samples found

Any thoughts? Is there any flaw in my code?
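
One thing that may be worth checking, although the snippet does not show it either way, is whether the model is in evaluation mode during inference. torch.no_grad() only disables gradient tracking. It does not switch layers such as BatchNorm out of training mode, and in training mode BatchNorm normalizes with the statistics of the current batch, so predictions can change with batch_size. The sketch below illustrates that effect on a tiny stand-in network with synthetic data. The network, the data and the predict_logits helper are all hypothetical, and this assumes the real model contains BatchNorm layers (for example a vgg16_bn-style variant), which the question does not confirm.

```python
import torch
from torch import nn

torch.manual_seed(0)

# Hypothetical stand-in for the real network: a tiny CNN with one BatchNorm layer.
model = nn.Sequential(
    nn.Conv2d(3, 8, kernel_size=3, padding=1),
    nn.BatchNorm2d(8),
    nn.ReLU(),
    nn.AdaptiveAvgPool2d(1),
    nn.Flatten(),
    nn.Linear(8, 4),
)

# Synthetic images; each sample gets its own offset so per-sample statistics differ.
images = torch.randn(16, 3, 32, 32) + 2.0 * torch.randn(16, 1, 1, 1)


def predict_logits(model, images, batch_size):
    """Run the model over the images in chunks of batch_size and return all logits."""
    outputs = []
    with torch.no_grad():  # disables autograd only; does not change train/eval mode
        for start in range(0, len(images), batch_size):
            outputs.append(model(images[start:start + batch_size]))
    return torch.cat(outputs)


# Training mode: BatchNorm uses the statistics of whatever batch it is given.
model.train()
same_in_train = torch.allclose(
    predict_logits(model, images, 1),
    predict_logits(model, images, 16),
    atol=1e-4,
)

# Evaluation mode: BatchNorm uses its stored running statistics instead.
model.eval()
same_in_eval = torch.allclose(
    predict_logits(model, images, 1),
    predict_logits(model, images, 16),
    atol=1e-4,
)

print(f'train mode, outputs match across batch sizes: {same_in_train}')
print(f'eval mode, outputs match across batch sizes: {same_in_eval}')
```

If this turns out to be the cause, calling vgg6.eval() once before the evaluation loop should make the misclassified count independent of batch_size. If the count still varies after that, the cause lies elsewhere.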

