提交 9d5f196c authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Fix disconnected optimize gradient bug

上级 d84cd641
...@@ -393,32 +393,25 @@ class ScipyVectorWrapperOp(ScipyWrapperOp): ...@@ -393,32 +393,25 @@ class ScipyVectorWrapperOp(ScipyWrapperOp):
return_disconnected="disconnected", return_disconnected="disconnected",
) )
inner_args_to_diff = [ args_to_diff: tuple[bool, ...] = tuple(
arg not isinstance(g.type, DisconnectedType | NullType) for g in arg_grads
for arg, g in zip(inner_args, arg_grads) )
if not isinstance(g.type, DisconnectedType | NullType)
]
if len(inner_args_to_diff) == 0: if not args_to_diff:
# No differentiable arguments, return disconnected/null gradients # No differentiable arguments, return disconnected/null gradients
return arg_grads return arg_grads
outer_args_to_diff = [
arg
for inner_arg, arg in zip(inner_args, args)
if inner_arg in inner_args_to_diff
]
invalid_grad_map = {
arg: g for arg, g in zip(args, arg_grads) if arg not in outer_args_to_diff
}
if is_minimization: if is_minimization:
implicit_f = grad(implicit_f, inner_x) implicit_f = grad(implicit_f, inner_x)
# Gradients are computed using the inner graph of the optimization op, not the actual inputs/outputs of the op. # Gradients are computed using the inner graph of the optimization op, not the actual inputs/outputs of the op.
packed_inner_args, packed_arg_shapes, implicit_f = pack_inputs_of_objective( packed_inner_args, packed_arg_shapes, implicit_f = pack_inputs_of_objective(
implicit_f, implicit_f,
inner_args_to_diff, [
inner_arg
for inner_arg, to_diff in zip(inner_args, args_to_diff)
if to_diff
],
) )
df_dx, df_dtheta = jacobian( df_dx, df_dtheta = jacobian(
...@@ -432,7 +425,7 @@ class ScipyVectorWrapperOp(ScipyWrapperOp): ...@@ -432,7 +425,7 @@ class ScipyVectorWrapperOp(ScipyWrapperOp):
# at the solution point. Inner arguments aren't needed anymore, delete them to avoid accidental references. # at the solution point. Inner arguments aren't needed anymore, delete them to avoid accidental references.
del inner_x del inner_x
del inner_args del inner_args
inner_to_outer_map = dict(zip(fgraph.inputs, (x_star, *args))) inner_to_outer_map = tuple(zip(fgraph.inputs, (x_star, *args)))
df_dx_star, df_dtheta_star = graph_replace( df_dx_star, df_dtheta_star = graph_replace(
[df_dx, df_dtheta], inner_to_outer_map [df_dx, df_dtheta], inner_to_outer_map
) )
...@@ -454,16 +447,18 @@ class ScipyVectorWrapperOp(ScipyWrapperOp): ...@@ -454,16 +447,18 @@ class ScipyVectorWrapperOp(ScipyWrapperOp):
else: else:
grad_wrt_args = [grad_wrt_args_packed] grad_wrt_args = [grad_wrt_args_packed]
arg_to_grad = dict(zip(outer_args_to_diff, grad_wrt_args))
final_grads = [] final_grads = []
for arg in args: grad_wrt_args_iter = iter(grad_wrt_args)
arg_grad = arg_to_grad.get(arg, None) for i, (arg, to_diff) in enumerate(zip(args, args_to_diff)):
if not to_diff:
if arg_grad is None: # Store the null grad we got from the initial `grad` call
final_grads.append(invalid_grad_map[arg]) null_grad = arg_grads[i]
assert isinstance(null_grad.type, NullType | DisconnectedType)
final_grads.append(null_grad)
continue continue
arg_grad = next(grad_wrt_args_iter)
if arg_grad.ndim > 0 and output_grad.ndim > 0: if arg_grad.ndim > 0 and output_grad.ndim > 0:
g = tensordot(output_grad, arg_grad, [[0], [0]]) g = tensordot(output_grad, arg_grad, [[0], [0]])
else: else:
...@@ -472,6 +467,8 @@ class ScipyVectorWrapperOp(ScipyWrapperOp): ...@@ -472,6 +467,8 @@ class ScipyVectorWrapperOp(ScipyWrapperOp):
g = scalar_from_tensor(g) g = scalar_from_tensor(g)
final_grads.append(g) final_grads.append(g)
assert next(grad_wrt_args_iter, None) is None, "Iterator was not exhausted"
return final_grads return final_grads
......
...@@ -10,7 +10,13 @@ from pytensor.gradient import ( ...@@ -10,7 +10,13 @@ from pytensor.gradient import (
) )
from pytensor.graph import Apply, Op, Type from pytensor.graph import Apply, Op, Type
from pytensor.tensor import alloc, scalar, scalar_from_tensor, tensor_from_scalar from pytensor.tensor import alloc, scalar, scalar_from_tensor, tensor_from_scalar
from pytensor.tensor.optimize import minimize, minimize_scalar, root, root_scalar from pytensor.tensor.optimize import (
MinimizeOp,
minimize,
minimize_scalar,
root,
root_scalar,
)
from tests import unittest_tools as utt from tests import unittest_tools as utt
...@@ -596,3 +602,33 @@ def test_vectorize_root_gradients(): ...@@ -596,3 +602,33 @@ def test_vectorize_root_gradients():
np.testing.assert_allclose(solution_grid_val, analytical_solution_grid) np.testing.assert_allclose(solution_grid_val, analytical_solution_grid)
np.testing.assert_allclose(a_grad_grid_val, analytical_a_grad_grid) np.testing.assert_allclose(a_grad_grid_val, analytical_a_grad_grid)
def test_minimize_grad_duplicate_input_connected_and_disconnected():
"""Regression test: when the same outer variable is passed for both a connected
and a disconnected inner arg, the gradient should not crash.
The old code used dict(zip(args, grads)) which silently overwrote entries when
the same outer variable appeared multiple times, returning a valid gradient for
a position that should have been disconnected.
"""
x = pt.scalar("x")
args = pt.scalars("a0", "a1", "a2")
# 'args[[0, 2]]' are connected, while 'args[1]' is disconnected
objective = (x - (args[0] + args[2])) ** 2 + pt.second(args[1], 0)
minimize_op = MinimizeOp(x, *args, objective=objective, method="BFGS")
# Use the same input for each of args (this can happen after rewrites/merging)
a = pt.scalar("a")
solution, _success = minimize_op(x, a, a, a)
assert minimize_op.connection_pattern(minimize_op) == [
[True, False],
[True, False],
[False, False],
[True, False],
]
np.testing.assert_allclose(pt.grad(solution, a).eval({x: np.pi, a: 0}), 2.0)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论