提交 91b73211 authored 作者: ricardoV94's avatar ricardoV94 提交者: Ricardo Vieira

Tweak ScipyVectorWrapperOp compute_implicit_gradients for readability

上级 ac942196
...@@ -361,7 +361,7 @@ class ScipyVectorWrapperOp(ScipyWrapperOp): ...@@ -361,7 +361,7 @@ class ScipyVectorWrapperOp(ScipyWrapperOp):
Notes Notes
----- -----
The gradents are computed using the implicit function theorem. Given a fuction `f(x, theta) = 0`, and a function The gradients are computed using the implicit function theorem. Given a function `f(x, theta) = 0`, and a function
`x_star(theta)` that, given input parameters theta returns `x` such that `f(x_star(theta), theta) = 0`, we can `x_star(theta)` that, given input parameters theta returns `x` such that `f(x_star(theta), theta) = 0`, we can
compute the gradients of `x_star` with respect to `theta` as follows: compute the gradients of `x_star` with respect to `theta` as follows:
...@@ -387,8 +387,12 @@ class ScipyVectorWrapperOp(ScipyWrapperOp): ...@@ -387,8 +387,12 @@ class ScipyVectorWrapperOp(ScipyWrapperOp):
fgraph = self.fgraph fgraph = self.fgraph
inner_x, *inner_args = self.inner_inputs inner_x, *inner_args = self.inner_inputs
implicit_f = self.inner_outputs[0] implicit_f = self.inner_outputs[0]
if is_minimization:
# The implicit function in minimization is grad(x, theta) == 0
implicit_f = grad(implicit_f, inner_x)
df_dx, *arg_grads = grad( # Call grad to see what arguments are connected
_, *arg_grads = grad(
implicit_f.sum(), implicit_f.sum(),
[inner_x, *inner_args], [inner_x, *inner_args],
disconnected_inputs="ignore", disconnected_inputs="ignore",
...@@ -404,9 +408,6 @@ class ScipyVectorWrapperOp(ScipyWrapperOp): ...@@ -404,9 +408,6 @@ class ScipyVectorWrapperOp(ScipyWrapperOp):
# No differentiable arguments, return disconnected/null gradients # No differentiable arguments, return disconnected/null gradients
return arg_grads return arg_grads
if is_minimization:
implicit_f = grad(implicit_f, inner_x)
# Gradients are computed using the inner graph of the optimization op, not the actual inputs/outputs of the op. # Gradients are computed using the inner graph of the optimization op, not the actual inputs/outputs of the op.
packed_inner_args, packed_arg_shapes, implicit_f = pack_inputs_of_objective( packed_inner_args, packed_arg_shapes, implicit_f = pack_inputs_of_objective(
implicit_f, implicit_f,
...@@ -426,11 +427,11 @@ class ScipyVectorWrapperOp(ScipyWrapperOp): ...@@ -426,11 +427,11 @@ class ScipyVectorWrapperOp(ScipyWrapperOp):
# Replace inner inputs (abstract dummies) with outer inputs (the actual user-provided symbols) # Replace inner inputs (abstract dummies) with outer inputs (the actual user-provided symbols)
# at the solution point. Innner arguments aren't needed anymore, delete them to avoid accidental references. # at the solution point. Innner arguments aren't needed anymore, delete them to avoid accidental references.
del inner_x del inner_x, inner_args
del inner_args
inner_to_outer_map = tuple(zip(fgraph.inputs, (x_star, *args))) inner_to_outer_map = tuple(zip(fgraph.inputs, (x_star, *args)))
df_dx_star, df_dtheta_star = graph_replace( df_dx_star, df_dtheta_star = graph_replace(
[df_dx, df_dtheta], inner_to_outer_map [df_dx, df_dtheta],
replace=inner_to_outer_map,
) )
if df_dtheta_star.ndim == 0 or df_dx_star.ndim == 0: if df_dtheta_star.ndim == 0 or df_dx_star.ndim == 0:
...@@ -455,19 +456,17 @@ class ScipyVectorWrapperOp(ScipyWrapperOp): ...@@ -455,19 +456,17 @@ class ScipyVectorWrapperOp(ScipyWrapperOp):
for i, (arg, to_diff) in enumerate(zip(args, args_to_diff)): for i, (arg, to_diff) in enumerate(zip(args, args_to_diff)):
if not to_diff: if not to_diff:
# Store the null grad we got from the initial `grad` call # Store the null grad we got from the initial `grad` call
null_grad = arg_grads[i] g = arg_grads[i]
assert isinstance(null_grad.type, NullType | DisconnectedType) assert isinstance(g.type, NullType | DisconnectedType)
final_grads.append(null_grad)
continue
arg_grad = next(grad_wrt_args_iter)
if arg_grad.ndim > 0 and output_grad.ndim > 0:
g = tensordot(output_grad, arg_grad, [[0], [0]])
else: else:
g = arg_grad * output_grad # Compute non-null grad and chain with output_grad
if isinstance(arg.type, ScalarType) and isinstance(g, TensorVariable): arg_grad = next(grad_wrt_args_iter)
g = scalar_from_tensor(g) if arg_grad.ndim > 0 and output_grad.ndim > 0:
g = tensordot(output_grad, arg_grad, [[0], [0]])
else:
g = arg_grad * output_grad
if isinstance(arg.type, ScalarType) and isinstance(g, TensorVariable):
g = scalar_from_tensor(g)
final_grads.append(g) final_grads.append(g)
assert next(grad_wrt_args_iter, None) is None, "Iterator was not exhausted" assert next(grad_wrt_args_iter, None) is None, "Iterator was not exhausted"
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论