Commit ac942196 authored by ricardoV94, committed by Ricardo Vieira

Fix gradient of ScipyScalarWrapperOp with repeated outer inputs

Parent 9d5f196c
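For context: the gradient of ScipyScalarWrapperOp used to key per-argument gradients in a dict by the argument variable itself, which silently collapses entries when the same outer variable is passed for more than one inner argument. A minimal plain-Python sketch of that failure mode (illustrative names only, not PyTensor API):

    # Keying per-argument data by the variable collapses duplicates.
    args = ["a", "a", "b"]          # the same outer variable passed twice
    arg_grads = ["g0", "g1", "g2"]  # one gradient expression per inner argument
    by_arg = dict(zip(args, arg_grads))
    print(len(by_arg))              # 2, not 3 -- one entry silently overwrote another

The change below switches to positional bookkeeping so every argument slot keeps its own gradient.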
@@ -35,7 +35,6 @@ from pytensor.tensor.math import tensordot
 from pytensor.tensor.reshape import pack, unpack
 from pytensor.tensor.slinalg import solve
 from pytensor.tensor.variable import TensorVariable, Variable
-from pytensor.utils import unzip

 # scipy.optimize can be slow to import, and will not be used by most users
@@ -277,6 +276,7 @@ class ScipyScalarWrapperOp(ScipyWrapperOp):
         inner_fx = self.inner_outputs[0]
         if is_minimization:
+            # The implicit function in minimization is grad(x, theta) == 0
             inner_fx = grad(inner_fx, inner_x)

         df_dx, *arg_grads = grad(
@@ -287,32 +287,35 @@ class ScipyScalarWrapperOp(ScipyWrapperOp):
             return_disconnected="disconnected",
         )

-        outer_arg_grad_map = dict(zip(args, arg_grads))
-        valid_args_and_grads = [
-            (arg, g)
-            for arg, g in outer_arg_grad_map.items()
-            if not isinstance(g.type, DisconnectedType | NullType)
-        ]
+        args_to_diff: tuple[bool, ...] = tuple(
+            not isinstance(g.type, DisconnectedType | NullType) for g in arg_grads
+        )

-        if len(valid_args_and_grads) == 0:
+        if not any(args_to_diff):
             # No differentiable arguments, return disconnected gradients
             return arg_grads

-        outer_args_to_diff, df_dthetas = unzip(valid_args_and_grads, n=2)
+        df_dthetas = [g for g, to_diff in zip(arg_grads, args_to_diff) if to_diff]

-        replace = dict(zip(fgraph.inputs, (x_star, *args), strict=True))
         # Make gradient an expression of the outer variables
         df_dx_star, *df_dthetas_stars = graph_replace(
-            [df_dx, *df_dthetas], replace=replace
+            [df_dx, *df_dthetas], replace=tuple(zip(fgraph.inputs, (x_star, *args)))
         )

-        arg_to_grad = dict(zip(outer_args_to_diff, df_dthetas_stars))
+        grad_wrt_args = []
+        df_dthetas_iter = iter(df_dthetas_stars)
+        for i, (arg, to_diff) in enumerate(zip(args, args_to_diff)):
+            if not to_diff:
+                # Store the null grad we got from the initial `grad` call
+                g = arg_grads[i]
+                assert isinstance(g.type, NullType | DisconnectedType)
+            else:
+                # Compute non-null grad and chain with output_grad
+                df_dtheta_star = next(df_dthetas_iter)
+                g = (-df_dtheta_star / df_dx_star) * output_grad
+            grad_wrt_args.append(g)

-        grad_wrt_args = [
-            (-arg_to_grad[arg] / df_dx_star) * output_grad
-            if arg in arg_to_grad
-            else outer_arg_grad_map[arg]
-            for arg in args
-        ]
+        assert next(df_dthetas_iter, None) is None, "Iterator was not exhausted"

         return grad_wrt_args
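By the implicit function theorem, if f(x*, theta) = 0 defines x*(theta), then dx*/dtheta = -(df/dtheta) / (df/dx) evaluated at x*, which is the (-df_dtheta_star / df_dx_star) factor chained with output_grad above. The new loop pairs the recomputed gradients back to argument positions rather than variables; a standalone sketch of that pattern (illustrative values only):

    # Positional pairing: repeated variables keep their own slot.
    args_to_diff = (True, False, True)     # per-slot flags, as in the loop above
    diff_grads = iter(["dA", "dC"])        # gradients only for differentiable slots
    out = [next(diff_grads) if flag else None for flag in args_to_diff]
    assert next(diff_grads, None) is None  # every computed gradient was consumed
    print(out)                             # ['dA', None, 'dC']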
@@ -12,6 +12,7 @@ from pytensor.graph import Apply, Op, Type
 from pytensor.tensor import alloc, scalar, scalar_from_tensor, tensor_from_scalar
 from pytensor.tensor.optimize import (
     MinimizeOp,
+    MinimizeScalarOp,
     minimize,
     minimize_scalar,
     root,
@@ -604,7 +605,15 @@ def test_vectorize_root_gradients():
     np.testing.assert_allclose(a_grad_grid_val, analytical_a_grad_grid)


-def test_minimize_grad_duplicate_input_connected_and_disconnected():
+@pytest.mark.parametrize(
+    "op_cls, op_kwargs",
+    [
+        (MinimizeOp, {"method": "BFGS"}),
+        (MinimizeScalarOp, {"method": "brent"}),
+    ],
+    ids=["MinimizeOp", "MinimizeScalarOp"],
+)
+def test_minimize_grad_duplicate_input_connected_and_disconnected(op_cls, op_kwargs):
     """Regression test: when the same outer variable is passed for both a connected
     and a disconnected inner arg, the gradient should not crash.
@@ -618,7 +627,7 @@ def test_minimize_grad_duplicate_input_connected_and_disconnected():
     # 'args[[0, 2]]' are connected, while 'args[1]' is disconnected
     objective = (x - (args[0] + args[2])) ** 2 + pt.second(args[1], 0)

-    minimize_op = MinimizeOp(x, *args, objective=objective, method="BFGS")
+    minimize_op = op_cls(x, *args, objective=objective, **op_kwargs)

     # Use the same input for each of args (this can happen after rewrites/merging)
     a = pt.scalar("a")
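For reference, pt.second(args[1], 0) returns its second argument, so the objective's value does not depend on args[1] even though it stays a symbolic input, and its gradient comes back structurally disconnected. A small standalone sketch of that behaviour, using the same return_disconnected="disconnected" option the Op's inner grad call relies on (illustrative only):

    import pytensor.tensor as pt
    from pytensor.gradient import DisconnectedType, grad

    a = pt.scalar("a")
    expr = pt.second(a, 0) + 1.0  # value does not depend on a
    g = grad(expr, a, disconnected_inputs="ignore", return_disconnected="disconnected")
    print(isinstance(g.type, DisconnectedType))  # True: structurally disconnected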