I have a relatively simple requirement, but surprisingly it does not seem to be straightforward to implement in PyTorch. Given a neural network with $P$ parameters that outputs a vector of length $Y$, and a batch of $B$ data inputs, I would like to calculate the gradient of each output, for each input, with respect to the model's parameters.
In other words, I would like the following function:
    def calculate_gradients(model, X):
        """
        Args:
            model: nn module with P parameters in total that outputs a tensor of size (B, Y).
            X: torch tensor of shape (B, .).

        Returns:
            torch tensor of shape (B, Y, P)
        """
        # function logic here
Unfortunately, I don't currently see an obvious way of calculating this efficiently, especially without aggregating over the data or target dimensions. The minimal working example below loops over every input and every output dimension, calling autograd once per pair, but surely there is a more efficient way?
    import torch
    import torch.nn as nn
    from torchvision import datasets, transforms

    ###### SETUP ######

    class MLP(nn.Module):
        def __init__(self, input_size, hidden_size, output_size):
            super(MLP, self).__init__()
            self.fc1 = nn.Linear(input_size, hidden_size)
            self.relu = nn.ReLU()
            self.fc2 = nn.Linear(hidden_size, output_size)

        def forward(self, x):
            h = self.fc1(x)
            pred = self.fc2(self.relu(h))
            return pred

    train_dataset = datasets.MNIST(
        root='./data',
        train=True,
        download=True,
        transform=transforms.Compose([
            transforms.ToTensor(),
            transforms.Normalize((0.5,), (0.5,)),
        ]),
    )
    train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=2, shuffle=False)

    X, y = next(iter(train_dataloader))  # take the first batch of data (shuffle=False)
    net = MLP(28*28, 20, 10)  # define a network

    ###### CALCULATE GRADIENTS ######

    def calculate_gradients(model, X):
        # Create a tensor to hold the gradients, shape (B, Y, P) with Y = 10
        gradients = torch.zeros(X.shape[0], 10, sum(p.numel() for p in model.parameters()))

        # Calculate the gradients for each input and target dimension
        for i in range(X.shape[0]):
            for j in range(10):
                model.zero_grad()
                output = model(X[i])
                # Gradient of the scalar output[j] w.r.t. every parameter tensor
                grads = torch.autograd.grad(output[j], model.parameters())
                # Flatten the gradients and store them
                gradients[i, j, :] = torch.cat([g.view(-1) for g in grads])

        return gradients

    grads = calculate_gradients(net, X.view(X.shape[0], -1))
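
For comparison, one direction that might avoid the double loop is the `torch.func` API. This is only a sketch, and it assumes PyTorch 2.0 or later, where `functional_call`, `jacrev` and `vmap` are available. The helper name `calculate_gradients_vectorized` and the small `nn.Sequential` network with random inputs are illustrative stand-ins, not part of the example above. The idea is to treat the model as a pure function of its parameters, take the Jacobian of one sample's output with respect to them, and map that over the batch.

```python
import torch
import torch.nn as nn
from torch.func import functional_call, jacrev, vmap


def calculate_gradients_vectorized(model, X):
    """Return a (B, Y, P) tensor holding d output[b, y] / d parameter[p]."""
    # Detached copies so the parameters act as plain inputs to a pure function.
    params = {name: p.detach() for name, p in model.named_parameters()}

    def single_output(params, x):
        # Run the model on one sample: add a batch dimension, then remove it.
        return functional_call(model, params, (x.unsqueeze(0),)).squeeze(0)

    # jacrev differentiates w.r.t. argument 0 (the params dict);
    # vmap maps over the batch dimension of X only (params are shared).
    jac = vmap(jacrev(single_output), in_dims=(None, 0))(params, X)

    # jac[name] has shape (B, Y, *param_shape); flatten each entry and
    # concatenate in the same order as model.parameters().
    B = X.shape[0]
    return torch.cat(
        [jac[name].reshape(B, -1, p.numel()) for name, p in model.named_parameters()],
        dim=2,
    )


# Illustrative stand-in network and data (assumptions, not the MNIST setup above).
torch.manual_seed(0)
net = nn.Sequential(nn.Linear(8, 5), nn.ReLU(), nn.Linear(5, 3))
X = torch.randn(4, 8)

grads = calculate_gradients_vectorized(net, X)
print(grads.shape)  # (B, Y, P) = (4, 3, 63) for this toy network
```

For the MLP above, the same call would presumably be `calculate_gradients_vectorized(net, X.view(X.shape[0], -1))`. Note that the result holds $B \times Y \times P$ floats, so memory rather than compute may become the limit for larger models.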