J'ai deux variables, x
et theta
. J'essaie de réduire mes pertes à l'égard de theta
seulement, mais comme une partie de ma perte de fonction j'ai besoin de la dérivée d'une fonction différente (f
) à l'égard de x
. Ce dérivé lui-même n'est pas pertinent pour la minimisation, la seule de sa sortie. Toutefois, lorsque la mise en œuvre de cette PyTorch j'obtiens une erreur à l'Exécution.
Un exemple minimal est comme suit:
# minimal example of two different autograds
import torch
from torch.autograd.functional import jacobian
def f(theta, x):
return torch.sum(theta * x ** 2)
def df(theta, x):
J = jacobian(lambda x: f(theta, x), x)
return J
# example evaluations of the autograd gradient
x = torch.tensor([1., 2.])
theta = torch.tensor([1., 1.], requires_grad = True)
# derivative should be 2*theta*x (same as an analytical)
with torch.no_grad():
print(df(theta, x))
print(2*theta*x)
tenseur([2., 4.])
tenseur([2., 4.])
# define some arbitrary loss as a fn of theta
loss = torch.sum(df(theta, x)**2)
loss.backward()
donne l'erreur suivante
RuntimeError: l'élément 0 de tenseurs ne nécessite pas de grad et n'ont pas de grad_fn
Si je fournis de façon analytique dérivés (2*theta*x
), il fonctionne très bien:
loss = torch.sum((2*theta*x)**2)
loss.backward()
Est-il un moyen de le faire dans PyTorch? Ou suis-je limité en quelque sorte?
Laissez-moi savoir si quelqu'un a besoin de plus de détails.
PS
Je m'imagine la solution est quelque chose de similaire à la façon dont JAX ne autograd, comme c'est ce que je suis plus familier avec. Ce que je veux dire ici, c'est que dans JAX je crois que vous faites juste:
from jax import grad
df = grad(lambda x: f(theta, x))
et puis df
ce serait simplement une fonction qui peut être appelée à tout moment. Mais est PyTorch le même? Ou est-il un conflit à l'intérieur de .backward()
que la cause de cette erreur?
create_graph
argument, parce que je ne veux pas qu'il soit inclus dans mon.backward()
appel. Dans ce cas, pourquoi est-il me donner une erreur? Je ne comprends pas le message d'erreur.