Channel: Active questions tagged python - Stack Overflow

Issue with Prediction Shape in Hugging Face Accelerated Inference


I am having an issue using Hugging Face Accelerate to run inference on a test dataset.

The snippet below shows my inference code. I have four GPUs and set num_processes=4 in notebook_launcher:

```python
import torch
from accelerate import Accelerator, notebook_launcher
from torch.utils.data import DataLoader
from tqdm.auto import tqdm

# CustomDataset, test and model are defined elsewhere in the notebook (not shown).


def infer_ddp_accelerate(model):
    accelerator = Accelerator(mixed_precision="fp16")
    test_dataset = CustomDataset(test)
    test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

    # Send everything through `accelerator.prepare`
    model, test_loader = accelerator.prepare(model, test_loader)
    # len() of the prepared loader is the number of batches per process, not samples
    accelerator.print(f"Testing on {len(test_loader)} batches per process")

    model.eval()
    prediction = []
    with torch.no_grad():
        for batch in tqdm(test_loader):
            input_ids = batch["input_ids"]
            attention_mask = batch["attention_mask"]
            outputs = model(input_ids, attention_mask)
            scores = torch.sigmoid(outputs.squeeze())
            # Collect this step's scores from all processes
            prediction.append(accelerator.gather(scores))

    prediction = torch.cat(prediction, dim=0).cpu().numpy()
    accelerator.print(f"{prediction.shape}")
    accelerator.print(f"{prediction}")


notebook_launcher(infer_ddp_accelerate, args=(model,), num_processes=4)
```

The test dataset has 57949 rows, and I expect one model score per row. However, prediction.shape reports 57984 rows instead of 57949.

I suspect that, because there are four GPUs, the effective batch size becomes 32 * 4 = 128, which gives 453 batches per process (ceil(57949 / 128) = 453).

Since 453 * 128 = 57984, Accelerate appears to pad the last step so that the gathered predictions contain 57984 rows, 35 more than the dataset.
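The arithmetic above can be checked with a small sketch. It assumes every process receives the same number of full batches of 32, which is how Accelerate's default `even_batches=True` behaviour is understood to work for a non-shuffled map-style dataset. `padded_prediction_count` is a hypothetical helper for illustration, not part of Accelerate.

```python
import math


def padded_prediction_count(num_samples: int, batch_size: int, num_processes: int) -> int:
    """Rows returned by gathering every step, assuming (not Accelerate's code)
    that each process gets the same number of full batches."""
    global_batch = batch_size * num_processes  # 32 * 4 = 128 rows per step
    steps = math.ceil(num_samples / global_batch)  # batches per process
    return steps * global_batch


n = 57949
total = padded_prediction_count(n, batch_size=32, num_processes=4)
print(total, total - n)  # 57984 35
```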

How can I make the prediction shape exactly 57949 in the code above?

Is it correct to simply take the first 57949 of the 57984 elements, or is there a better way to handle this?
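Whether slicing is safe depends on the order in which indices come back from `gather`. The pure-Python simulation below models the assumed behaviour for `shuffle=False` with the default `split_batches=False`: batch i goes to process i % 4, the shortfall is filled by wrapping around to the start of the dataset, and each `gather` call joins the per-process batches in process order. `shard_batches` and `gather_all` are hypothetical helpers and a simplified model, not Accelerate's implementation. Under these assumptions the first 57949 gathered rows are the dataset in its original order and the trailing 35 are repeats of the first rows, so `prediction[:len(test_dataset)]` would be correct. Accelerate also provides `accelerator.gather_for_metrics`, which is intended for this case and drops the duplicated samples from the last batch while gathering.

```python
def shard_batches(num_samples, batch_size, num_processes):
    """Simplified model (an assumption, not Accelerate's code) of how a
    prepared, non-shuffled DataLoader hands out batches with even batches:
    batch i goes to process i % num_processes, and the last step is filled
    by wrapping around to the start of the dataset."""
    global_batch = batch_size * num_processes
    steps = -(-num_samples // global_batch)  # ceiling division
    # Sample indices in order, cycling back to index 0 once the data runs out.
    padded = [i % num_samples for i in range(steps * global_batch)]
    per_process = [[] for _ in range(num_processes)]
    for b in range(steps * num_processes):
        batch = padded[b * batch_size:(b + 1) * batch_size]
        per_process[b % num_processes].append(batch)
    return per_process


def gather_all(per_process):
    """Mimic calling gather once per step and then concatenating the steps:
    at each step the batches of processes 0..N-1 are joined in order."""
    gathered = []
    for step in range(len(per_process[0])):
        for batches in per_process:
            gathered.extend(batches[step])
    return gathered


n = 57949
gathered = gather_all(shard_batches(n, batch_size=32, num_processes=4))
print(len(gathered))                   # 57984
print(gathered[:n] == list(range(n)))  # True
print(gathered[n:n + 5])               # [0, 1, 2, 3, 4]
```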

Thank you.

