Channel: Active questions tagged python - Stack Overflow

PyTorch: how to efficiently calculate the gradients of all outputs with respect to the parameters


I have a relatively simple requirement, but surprisingly it does not seem to be straightforward to implement in PyTorch. Given a neural network with $P$ parameters that outputs a vector of length $Y$, and a batch of $B$ data inputs, I would like to calculate the gradient of every output with respect to the model's parameters, separately for each input (i.e. one $Y \times P$ Jacobian per sample).
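Written out, with $f(x;\theta) \in \mathbb{R}^Y$ denoting the network output for input $x$ and parameter vector $\theta \in \mathbb{R}^P$ (notation introduced here for illustration), the requested tensor is

$$
G_{b,y,p} = \frac{\partial f_y(x_b;\theta)}{\partial \theta_p}, \qquad b = 1,\dots,B,\quad y = 1,\dots,Y,\quad p = 1,\dots,P.
$$

A single backward pass on the summed outputs (e.g. calling `.backward()` on `model(X).sum()`) only yields the aggregate over both the batch and the output dimensions, because the parameters are shared across samples:

$$
\frac{\partial}{\partial \theta_p} \sum_{b=1}^{B} \sum_{y=1}^{Y} f_y(x_b;\theta) = \sum_{b=1}^{B} \sum_{y=1}^{Y} G_{b,y,p}.
$$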

In other words, I would like the following function:

```python
def calculate_gradients(model, X):
    """
    Args:
        model: nn.Module with P parameters in total that outputs a tensor of size (B, Y).
        X: torch tensor of shape (B, ...).
    Returns:
        torch tensor of shape (B, Y, P)
    """
    # function logic here
```
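One possible way to fill in this stub is sketched below, assuming PyTorch 2.0 or later (where the `torch.func` module is available) and a model whose forward pass works under `vmap` (plain `Linear`/`ReLU` layers as in the MLP below; e.g. BatchNorm in training mode would not). The idea is to treat the parameters as an explicit input via `torch.func.functional_call`, take the Jacobian of a single sample's output with `jacrev`, and vectorize over the batch with `vmap`. The inner helper `f` is hypothetical, introduced only for this sketch.

```python
import torch
import torch.nn as nn
from torch.func import functional_call, jacrev, vmap


def calculate_gradients(model, X):
    """Return per-sample Jacobians d(output)/d(params) with shape (B, Y, P)."""
    # Detached copies of the parameters, keyed by name; torch.func differentiates
    # with respect to this dict instead of the module's own leaf tensors.
    params = {name: p.detach() for name, p in model.named_parameters()}

    def f(params, x):
        # Hypothetical helper: run the model on ONE sample with the given parameters.
        # Add a batch dimension of 1 for the forward pass, then remove it -> shape (Y,)
        return functional_call(model, params, (x.unsqueeze(0),)).squeeze(0)

    # jacrev(f) differentiates w.r.t. argument 0 (the params dict) and returns a dict
    # name -> tensor of shape (Y, *param.shape). vmap maps it over dim 0 of X
    # (in_dims=(None, 0): params shared, X batched) -> (B, Y, *param.shape).
    jac = vmap(jacrev(f), in_dims=(None, 0))(params, X)

    # Flatten each parameter's trailing dims and concatenate in model.parameters() order.
    return torch.cat(
        [jac[name].reshape(jac[name].shape[0], jac[name].shape[1], -1) for name in params],
        dim=-1,
    )
```

Usage would be `grads = calculate_gradients(net, X.view(X.shape[0], -1))`. Because `named_parameters()` and `parameters()` iterate in the same order, the columns line up with a loop that concatenates `model.parameters()` gradients. Note that the result itself needs $B \cdot Y \cdot P$ floats, so memory grows with all three.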

Unfortunately, I don't see an obvious way to calculate this efficiently, in particular without aggregating over the data (batch) or output dimensions. The minimal working example below loops over the inputs and over the output dimensions, but surely there is a more efficient way?

```python
import torch
import torch.nn as nn
from torchvision import datasets, transforms

###### SETUP ######
class MLP(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        h = self.fc1(x)
        pred = self.fc2(self.relu(h))
        return pred

train_dataset = datasets.MNIST(
    root='./data', train=True, download=True,
    transform=transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5,), (0.5,)),
    ]),
)
train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=2, shuffle=False)
X, y = next(iter(train_dataloader))  # take the first batch (shuffle=False, so not random)
net = MLP(28 * 28, 20, 10)  # define a network

###### CALCULATE GRADIENTS ######
def calculate_gradients(model, X):
    params = list(model.parameters())
    num_params = sum(p.numel() for p in params)  # P
    num_outputs = model(X[:1]).shape[-1]  # Y, inferred from a single forward pass
    # Create a tensor to hold the gradients, shape (B, Y, P)
    gradients = torch.zeros(X.shape[0], num_outputs, num_params)
    # Calculate the gradients for each input and output dimension
    for i in range(X.shape[0]):
        output = model(X[i])  # one forward pass per input, shape (Y,)
        for j in range(num_outputs):
            # retain_graph=True keeps the graph alive for the next output j;
            # autograd.grad does not touch .grad, so no zero_grad() is needed
            grads = torch.autograd.grad(output[j], params, retain_graph=True)
            # Flatten the gradients and store them
            gradients[i, j, :] = torch.cat([g.reshape(-1) for g in grads])
    return gradients

grads = calculate_gradients(net, X.view(X.shape[0], -1))  # shape (2, 10, 15910)
```
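As an intermediate step, the inner loop over output dimensions in the example above can be collapsed into a single `torch.autograd.grad` call by passing the rows of an identity matrix as `grad_outputs` together with `is_grads_batched=True` (available in recent PyTorch versions). Row $j$ of the identity is the one-hot vector $e_j$, and the vector-Jacobian product $e_j^\top J$ is exactly the gradient of output $j$. The sketch below still loops over the batch, since the parameters are shared across samples; it is a hedged illustration, not a drop-in replacement tested on every model.

```python
import torch
import torch.nn as nn


def calculate_gradients(model, X):
    """Return d(output)/d(params) with shape (B, Y, P), looping only over the batch."""
    params = list(model.parameters())
    per_sample = []
    for x in X:  # one iteration per input; the batch dimension is NOT vectorized here
        output = model(x.unsqueeze(0)).squeeze(0)  # shape (Y,)
        num_outputs = output.shape[0]
        # Row j of the identity is e_j, so the j-th VJP is the gradient of output[j].
        eye = torch.eye(num_outputs, dtype=output.dtype, device=output.device)
        # is_grads_batched=True treats dim 0 of grad_outputs as a batch of Y separate
        # VJPs, so each returned tensor has shape (Y, *param.shape).
        grads = torch.autograd.grad(output, params, grad_outputs=eye, is_grads_batched=True)
        # Flatten parameter dims and concatenate in model.parameters() order -> (Y, P)
        per_sample.append(torch.cat([g.reshape(num_outputs, -1) for g in grads], dim=1))
    return torch.stack(per_sample)  # (B, Y, P)
```

The PyTorch documentation describes `is_grads_batched` as experimental and warns of possible performance cliffs, so it is worth timing this against the plain loop on the actual model.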
