提交 ac942196 authored 作者: ricardoV94's avatar ricardoV94 提交者: Ricardo Vieira

Fix gradient of ScipyScalarWrapperOp with repeated outer inputs

上级 9d5f196c
...@@ -35,7 +35,6 @@ from pytensor.tensor.math import tensordot ...@@ -35,7 +35,6 @@ from pytensor.tensor.math import tensordot
from pytensor.tensor.reshape import pack, unpack from pytensor.tensor.reshape import pack, unpack
from pytensor.tensor.slinalg import solve from pytensor.tensor.slinalg import solve
from pytensor.tensor.variable import TensorVariable, Variable from pytensor.tensor.variable import TensorVariable, Variable
from pytensor.utils import unzip
# scipy.optimize can be slow to import, and will not be used by most users # scipy.optimize can be slow to import, and will not be used by most users
...@@ -277,6 +276,7 @@ class ScipyScalarWrapperOp(ScipyWrapperOp): ...@@ -277,6 +276,7 @@ class ScipyScalarWrapperOp(ScipyWrapperOp):
inner_fx = self.inner_outputs[0] inner_fx = self.inner_outputs[0]
if is_minimization: if is_minimization:
# The implicit function in minimization is grad(x, theta) == 0
inner_fx = grad(inner_fx, inner_x) inner_fx = grad(inner_fx, inner_x)
df_dx, *arg_grads = grad( df_dx, *arg_grads = grad(
...@@ -287,32 +287,35 @@ class ScipyScalarWrapperOp(ScipyWrapperOp): ...@@ -287,32 +287,35 @@ class ScipyScalarWrapperOp(ScipyWrapperOp):
return_disconnected="disconnected", return_disconnected="disconnected",
) )
outer_arg_grad_map = dict(zip(args, arg_grads)) args_to_diff: tuple[bool, ...] = tuple(
valid_args_and_grads = [ not isinstance(g.type, DisconnectedType | NullType) for g in arg_grads
(arg, g) )
for arg, g in outer_arg_grad_map.items()
if not isinstance(g.type, DisconnectedType | NullType)
]
if len(valid_args_and_grads) == 0: if not any(args_to_diff):
# No differentiable arguments, return disconnected gradients # No differentiable arguments, return disconnected gradients
return arg_grads return arg_grads
outer_args_to_diff, df_dthetas = unzip(valid_args_and_grads, n=2) df_dthetas = [g for g, to_diff in zip(arg_grads, args_to_diff) if to_diff]
replace = dict(zip(fgraph.inputs, (x_star, *args), strict=True)) # Make gradient an expression of the outer variables
df_dx_star, *df_dthetas_stars = graph_replace( df_dx_star, *df_dthetas_stars = graph_replace(
[df_dx, *df_dthetas], replace=replace [df_dx, *df_dthetas], replace=tuple(zip(fgraph.inputs, (x_star, *args)))
) )
arg_to_grad = dict(zip(outer_args_to_diff, df_dthetas_stars)) grad_wrt_args = []
df_dthetas_iter = iter(df_dthetas_stars)
for i, (arg, to_diff) in enumerate(zip(args, args_to_diff)):
if not to_diff:
# Store the null grad we got from the initial `grad` call
g = arg_grads[i]
assert isinstance(g.type, NullType | DisconnectedType)
else:
# Compute non-null grad and chain with output_grad
df_dtheta_star = next(df_dthetas_iter)
g = (-df_dtheta_star / df_dx_star) * output_grad
grad_wrt_args.append(g)
grad_wrt_args = [ assert next(df_dthetas_iter, None) is None, "Iterator was not exhausted"
(-arg_to_grad[arg] / df_dx_star) * output_grad
if arg in arg_to_grad
else outer_arg_grad_map[arg]
for arg in args
]
return grad_wrt_args return grad_wrt_args
......
...@@ -12,6 +12,7 @@ from pytensor.graph import Apply, Op, Type ...@@ -12,6 +12,7 @@ from pytensor.graph import Apply, Op, Type
from pytensor.tensor import alloc, scalar, scalar_from_tensor, tensor_from_scalar from pytensor.tensor import alloc, scalar, scalar_from_tensor, tensor_from_scalar
from pytensor.tensor.optimize import ( from pytensor.tensor.optimize import (
MinimizeOp, MinimizeOp,
MinimizeScalarOp,
minimize, minimize,
minimize_scalar, minimize_scalar,
root, root,
...@@ -604,7 +605,15 @@ def test_vectorize_root_gradients(): ...@@ -604,7 +605,15 @@ def test_vectorize_root_gradients():
np.testing.assert_allclose(a_grad_grid_val, analytical_a_grad_grid) np.testing.assert_allclose(a_grad_grid_val, analytical_a_grad_grid)
def test_minimize_grad_duplicate_input_connected_and_disconnected(): @pytest.mark.parametrize(
"op_cls, op_kwargs",
[
(MinimizeOp, {"method": "BFGS"}),
(MinimizeScalarOp, {"method": "brent"}),
],
ids=["MinimizeOp", "MinimizeScalarOp"],
)
def test_minimize_grad_duplicate_input_connected_and_disconnected(op_cls, op_kwargs):
"""Regression test: when the same outer variable is passed for both a connected """Regression test: when the same outer variable is passed for both a connected
and a disconnected inner arg, the gradient should not crash. and a disconnected inner arg, the gradient should not crash.
...@@ -618,7 +627,7 @@ def test_minimize_grad_duplicate_input_connected_and_disconnected(): ...@@ -618,7 +627,7 @@ def test_minimize_grad_duplicate_input_connected_and_disconnected():
# 'args[[0, 2]]' are connected, while 'args[1]' is disconnected # 'args[[0, 2]]' are connected, while 'args[1]' is disconnected
objective = (x - (args[0] + args[2])) ** 2 + pt.second(args[1], 0) objective = (x - (args[0] + args[2])) ** 2 + pt.second(args[1], 0)
minimize_op = MinimizeOp(x, *args, objective=objective, method="BFGS") minimize_op = op_cls(x, *args, objective=objective, **op_kwargs)
# Use the same input for each of args (this can happen after rewrites/merging) # Use the same input for each of args (this can happen after rewrites/merging)
a = pt.scalar("a") a = pt.scalar("a")
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论