提交 91b73211 authored 作者: ricardoV94's avatar ricardoV94 提交者: Ricardo Vieira

Tweak ScipyVectorWrapperOp compute_implicit_gradients for readability

上级 ac942196
......@@ -361,7 +361,7 @@ class ScipyVectorWrapperOp(ScipyWrapperOp):
Notes
-----
The gradents are computed using the implicit function theorem. Given a fuction `f(x, theta) = 0`, and a function
The gradients are computed using the implicit function theorem. Given a function `f(x, theta) = 0`, and a function
`x_star(theta)` that, given input parameters theta returns `x` such that `f(x_star(theta), theta) = 0`, we can
compute the gradients of `x_star` with respect to `theta` as follows:
......@@ -387,8 +387,12 @@ class ScipyVectorWrapperOp(ScipyWrapperOp):
fgraph = self.fgraph
inner_x, *inner_args = self.inner_inputs
implicit_f = self.inner_outputs[0]
if is_minimization:
# The implicit function in minimization is grad(x, theta) == 0
implicit_f = grad(implicit_f, inner_x)
df_dx, *arg_grads = grad(
# Call grad to see what arguments are connected
_, *arg_grads = grad(
implicit_f.sum(),
[inner_x, *inner_args],
disconnected_inputs="ignore",
......@@ -404,9 +408,6 @@ class ScipyVectorWrapperOp(ScipyWrapperOp):
# No differentiable arguments, return disconnected/null gradients
return arg_grads
if is_minimization:
implicit_f = grad(implicit_f, inner_x)
# Gradients are computed using the inner graph of the optimization op, not the actual inputs/outputs of the op.
packed_inner_args, packed_arg_shapes, implicit_f = pack_inputs_of_objective(
implicit_f,
......@@ -426,11 +427,11 @@ class ScipyVectorWrapperOp(ScipyWrapperOp):
# Replace inner inputs (abstract dummies) with outer inputs (the actual user-provided symbols)
# at the solution point. Inner arguments aren't needed anymore, delete them to avoid accidental references.
del inner_x
del inner_args
del inner_x, inner_args
inner_to_outer_map = tuple(zip(fgraph.inputs, (x_star, *args)))
df_dx_star, df_dtheta_star = graph_replace(
[df_dx, df_dtheta], inner_to_outer_map
[df_dx, df_dtheta],
replace=inner_to_outer_map,
)
if df_dtheta_star.ndim == 0 or df_dx_star.ndim == 0:
......@@ -455,19 +456,17 @@ class ScipyVectorWrapperOp(ScipyWrapperOp):
for i, (arg, to_diff) in enumerate(zip(args, args_to_diff)):
if not to_diff:
# Store the null grad we got from the initial `grad` call
null_grad = arg_grads[i]
assert isinstance(null_grad.type, NullType | DisconnectedType)
final_grads.append(null_grad)
continue
arg_grad = next(grad_wrt_args_iter)
if arg_grad.ndim > 0 and output_grad.ndim > 0:
g = tensordot(output_grad, arg_grad, [[0], [0]])
g = arg_grads[i]
assert isinstance(g.type, NullType | DisconnectedType)
else:
g = arg_grad * output_grad
if isinstance(arg.type, ScalarType) and isinstance(g, TensorVariable):
g = scalar_from_tensor(g)
# Compute non-null grad and chain with output_grad
arg_grad = next(grad_wrt_args_iter)
if arg_grad.ndim > 0 and output_grad.ndim > 0:
g = tensordot(output_grad, arg_grad, [[0], [0]])
else:
g = arg_grad * output_grad
if isinstance(arg.type, ScalarType) and isinstance(g, TensorVariable):
g = scalar_from_tensor(g)
final_grads.append(g)
assert next(grad_wrt_args_iter, None) is None, "Iterator was not exhausted"
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论