In this question on Math StackExchange, people are discussing the derivative of the function f(x) = Axx'A / (x'AAx), where x is a vector and A is a symmetric, positive semi-definite square matrix.
The derivative of this function at a point x is a third-order tensor, and when "applied" to another vector h it is a matrix. The answers under that post give different expressions for this matrix, so I would like to check them numerically using PyTorch or Autograd.
Here is my attempt with PyTorch:
    import torch

    def P(x, A):
        x = x.unsqueeze(1)  # Convert to column vector
        vector = torch.matmul(A, x)
        denom = (vector.transpose(0, 1) @ vector).squeeze()
        P_matrix = (vector @ vector.transpose(0, 1)) / denom
        return P_matrix.squeeze()

    A = torch.tensor([[1.0, 0.5], [0.5, 1.3]], dtype=torch.float32)
    x = torch.tensor([1.0, 2.0], dtype=torch.float32, requires_grad=True)
    h = torch.tensor([2.0, -1.0], dtype=torch.float32)
    Pxh = torch.matmul(P(x, A), h)

    # compute gradient
    Pxh.backward()

But this doesn't work. What am I doing wrong?
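One thing that stands out in the attempt: `Pxh` is a vector of length 2, and `backward()` called without arguments only works on scalar outputs. A minimal sketch of one way around this, assuming the goal is a full Jacobian rather than a gradient, uses `torch.autograd.functional.jacobian` and `torch.autograd.functional.jvp`. The compact `P` below (built with `torch.outer`) and the switch to `float64` are my own choices for illustration, not part of the original attempt. Because "applied to h" can be read in two ways, the sketch computes both the Jacobian of x -> P(x)h and the directional derivative of P at x along h.

```python
import torch
from torch.autograd.functional import jacobian, jvp

def P(x, A):
    # Same function as in the question: A x x' A / (x' A A x).
    v = A @ x                           # shape (n,)
    return torch.outer(v, v) / (v @ v)  # shape (n, n)

# float64 is an assumption made here to keep numerical comparisons tight.
A = torch.tensor([[1.0, 0.5], [0.5, 1.3]], dtype=torch.float64)
x = torch.tensor([1.0, 2.0], dtype=torch.float64)
h = torch.tensor([2.0, -1.0], dtype=torch.float64)

# Reading 1: Jacobian of the vector-valued map x -> P(x) h, shape (2, 2).
J_Ph = jacobian(lambda x_: P(x_, A) @ h, x)

# Full derivative tensor: dP[i, j, k] = d P_ij / d x_k, shape (2, 2, 2).
dP = jacobian(lambda x_: P(x_, A), x)

# Reading 2: directional derivative of P at x along h, shape (2, 2).
_, DP_h = jvp(lambda x_: P(x_, A), x, h)

print(J_Ph)
print(DP_h)
```

Both matrices are contractions of the same tensor `dP` with `h`, just over different indices, so whichever reading the Math StackExchange answers use, the candidate formula can be compared against the matching result with `torch.allclose`.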
JAX
I am also happy with a JAX solution. I tried jax.grad, but it does not work.
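A possible reason for the failure: `jax.grad` is only defined for scalar-output functions, while both P(x) and P(x)h are non-scalar. A minimal sketch, assuming a Jacobian is what is wanted, uses `jax.jacobian` and `jax.jvp` instead. The compact `P` written with `jnp.outer` is my own illustrative rewrite of the function, and the arrays use JAX's default `float32`.

```python
import jax
import jax.numpy as jnp

def P(x, A):
    # A x x' A / (x' A A x), written for a 1-D vector x.
    v = A @ x
    return jnp.outer(v, v) / (v @ v)

A = jnp.array([[1.0, 0.5], [0.5, 1.3]])
x = jnp.array([1.0, 2.0])
h = jnp.array([2.0, -1.0])

# Jacobian of the vector-valued map x -> P(x) h, shape (2, 2).
J_Ph = jax.jacobian(lambda x_: P(x_, A) @ h)(x)

# Full derivative tensor with respect to x: dP[i, j, k] = d P_ij / d x_k.
dP = jax.jacobian(P)(x, A)

# Directional derivative of P at x along h, shape (2, 2).
_, DP_h = jax.jvp(lambda x_: P(x_, A), (x,), (h,))

print(J_Ph)
print(DP_h)
```

`jax.jvp` computes the directional derivative in a single forward pass without building the full tensor, which is convenient when only the matrix for one particular h is needed.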