Here is some python code to reproduce my issue:
```python
import torch

n, m = 9, 4
x = torch.arange(0, n * m).reshape(n, m)
print(x.shape)
print(x)
# torch.Size([9, 4])
# tensor([[ 0,  1,  2,  3],
#         [ 4,  5,  6,  7],
#         [ 8,  9, 10, 11],
#         [12, 13, 14, 15],
#         [16, 17, 18, 19],
#         [20, 21, 22, 23],
#         [24, 25, 26, 27],
#         [28, 29, 30, 31],
#         [32, 33, 34, 35]])

list_of_indices = [
    [],
    [2, 3],
    [1],
    [],
    [],
    [],
    [0, 1, 2, 3],
    [],
    [0, 3],
]
print(list_of_indices)

for i, indices in enumerate(list_of_indices):
    x[i, indices] = -1
print(x)
# tensor([[ 0,  1,  2,  3],
#         [ 4,  5, -1, -1],
#         [ 8, -1, 10, 11],
#         [12, 13, 14, 15],
#         [16, 17, 18, 19],
#         [20, 21, 22, 23],
#         [-1, -1, -1, -1],
#         [28, 29, 30, 31],
#         [-1, 33, 34, -1]])
```
I have a list of lists of indices, and I want to set the corresponding entries of `x` to a specific value (here `-1`). Each sublist in `list_of_indices` corresponds to one row of `x` and contains the column indices to set to `-1` in that row. This is easy to do with a for-loop, but I feel like PyTorch should allow doing it much more efficiently.
I tried the following:
```python
x[torch.arange(len(list_of_indices)), list_of_indices] = -1
```
but it resulted in
```
IndexError: shape mismatch: indexing tensors could not be broadcast together with shapes [9], [9, 0]
```
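For context, the one workaround I have found is to flatten the ragged list into explicit row and column index tensors of the same 1-D shape, so they broadcast trivially. It still uses Python loops to build the indices, so I am not sure it is the idiomatic solution:

```python
import torch

n, m = 9, 4
x = torch.arange(0, n * m).reshape(n, m)
list_of_indices = [[], [2, 3], [1], [], [], [], [0, 1, 2, 3], [], [0, 3]]

# Repeat each row index once per column index in its sublist,
# and flatten all column indices into one tensor.
rows = torch.tensor([i for i, idx in enumerate(list_of_indices) for _ in idx])
cols = torch.tensor([j for idx in list_of_indices for j in idx])

# Single advanced-indexing assignment; same result as the for-loop above.
x[rows, cols] = -1
print(x)
```

This avoids the shape mismatch because both indexing tensors are flat and equal-length, but the index construction itself is not vectorized.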
I tried to find others with the same problem, but the number of questions about tensor indexing is so large that I may well have missed it.