Channel: Active questions tagged python - Stack Overflow

Checking a derivative tensor in PyTorch


In this question on Math StackExchange, people discuss the derivative of the function f(x) = Axx'A / (x'AAx), where x is a vector and A is a symmetric, positive semi-definite square matrix.

The derivative of this function at a point x is a third-order tensor, and when "applied" to another vector h it yields a matrix. The answers under that post give different expressions for this matrix, so I would like to check them numerically using PyTorch or Autograd.

Here is my attempt with PyTorch:

```python
import torch

def P(x, A):
    x = x.unsqueeze(1)  # Convert to column vector
    vector = torch.matmul(A, x)
    denom = (vector.transpose(0, 1) @ vector).squeeze()
    P_matrix = (vector @ vector.transpose(0, 1)) / denom
    return P_matrix.squeeze()

A = torch.tensor([[1.0, 0.5], [0.5, 1.3]], dtype=torch.float32)
x = torch.tensor([1.0, 2.0], dtype=torch.float32, requires_grad=True)
h = torch.tensor([2.0, -1.0], dtype=torch.float32)

Pxh = torch.matmul(P(x, A), h)

# compute gradient
Pxh.backward()
```

But this doesn't work. What am I doing wrong?
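A likely cause of the failure: `Pxh` is a vector with two entries, and `backward()` called without arguments only works on scalar outputs, so PyTorch raises a `RuntimeError`. One possible way around this is to request the full Jacobian with `torch.autograd.functional.jacobian` and contract it with `h`. The sketch below is a minimal illustration under some assumptions: the simplified `P`, the float64 dtype, the finite-difference step `eps`, and the closed-form expression `closed` (obtained here from the product rule, not taken from the linked answers) are all introduced for this example.

```python
import torch
from torch.autograd.functional import jacobian, jvp

def P(x, A):
    # f(x) = A x x' A / (x' A A x), written with v = A x (A is symmetric)
    v = A @ x
    return torch.outer(v, v) / torch.dot(v, v)

# float64 is an assumption made here to keep the numerical check tight
A = torch.tensor([[1.0, 0.5], [0.5, 1.3]], dtype=torch.float64)
x = torch.tensor([1.0, 2.0], dtype=torch.float64)
h = torch.tensor([2.0, -1.0], dtype=torch.float64)

# Full derivative tensor: J[i, j, k] = d P_ij / d x_k, shape (2, 2, 2)
J = jacobian(lambda x_: P(x_, A), x)

# Derivative "applied" to h: contract the last index of J with h
D_h = torch.einsum("ijk,k->ij", J, h)

# Same matrix without building J, via a Jacobian-vector product
_, D_h_jvp = jvp(lambda x_: P(x_, A), x, h)

# Jacobian of the vector P(x) h with respect to x (what the attempt differentiates)
J_Ph = jacobian(lambda x_: P(x_, A) @ h, x)

# Central finite difference along h as an independent check
eps = 1e-6
fd = (P(x + eps * h, A) - P(x - eps * h, A)) / (2 * eps)

# Candidate closed form from the product rule (assumption, derived for this sketch):
# dP[h] = (A h v' + v h' A) / s - 2 (v' A h) v v' / s^2, with v = A x, s = v' v
v, Ah = A @ x, A @ h
s = torch.dot(v, v)
closed = (torch.outer(Ah, v) + torch.outer(v, Ah)) / s \
    - 2 * torch.dot(v, Ah) * torch.outer(v, v) / s**2

print(torch.allclose(D_h, fd, atol=1e-6))
print(torch.allclose(D_h, closed, atol=1e-10))
```

Any candidate formula from the Math StackExchange answers could be checked the same way, by evaluating it at `x` and `h` and comparing against `D_h` with `torch.allclose`.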

JAX

I would also be happy with a JAX solution. I tried jax.grad, but it does not work either.
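On the JAX side, `jax.grad` is only defined for scalar-output functions, which would explain why it fails on a matrix-valued function. A minimal sketch of an alternative, assuming `jax.jacfwd` for the full tensor and `jax.jvp` for the derivative along `h`, is shown here. The simplified `P`, the 64-bit setting, and the finite-difference step `eps` are assumptions made for this example.

```python
import jax
import jax.numpy as jnp

# Enable 64-bit floats (an assumption made here for a tighter numerical check)
jax.config.update("jax_enable_x64", True)

def P(x, A):
    # f(x) = A x x' A / (x' A A x), written with v = A x (A is symmetric)
    v = A @ x
    return jnp.outer(v, v) / jnp.dot(v, v)

A = jnp.array([[1.0, 0.5], [0.5, 1.3]])
x = jnp.array([1.0, 2.0])
h = jnp.array([2.0, -1.0])

# Full derivative tensor with respect to x: J[i, j, k] = d P_ij / d x_k
J = jax.jacfwd(P)(x, A)

# Derivative "applied" to h, computed directly as a Jacobian-vector product
_, D_h = jax.jvp(lambda x_: P(x_, A), (x,), (h,))

# Central finite difference along h as an independent check
eps = 1e-6
fd = (P(x + eps * h, A) - P(x - eps * h, A)) / (2 * eps)

print(jnp.allclose(D_h, jnp.einsum("ijk,k->ij", J, h)))
print(jnp.allclose(D_h, fd, atol=1e-6))
```

Since f is unchanged when x is rescaled, contracting `J` with `x` itself should give the zero matrix, which makes a cheap sanity check on the result.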

