提交 d4f87436 authored 作者: ricardoV94's avatar ricardoV94 提交者: Ricardo Vieira

Fix connection_pattern for Minimize ops

上级 91b73211
......@@ -213,22 +213,6 @@ class ScipyWrapperOp(Op, HasInnerGraph):
return Apply(self, inputs, [self.inner_inputs[0].type(), ps.bool("success")])
def connection_pattern(self, node=None):
    """
    Shared connection-pattern logic for every Op inheriting from
    ScipyWrapperOp, since they all expose the same output structure: the
    optimized variable followed by a success flag.  The success flag is
    not differentiable, so it is never connected.  The optimized variable
    is connected only to inputs that are both connected to the objective
    function and of float dtype.
    """
    inner_graph = self.fgraph
    objective = inner_graph.outputs[0]
    io_pattern = io_connection_pattern(inner_graph.inputs, [objective])
    # Every input is disconnected from the second output (success) and
    # may or may not be connected to the first output (opt_x).
    return [[is_connected, False] for [is_connected] in io_pattern]
class ScipyScalarWrapperOp(ScipyWrapperOp):
def build_fn(self):
......@@ -474,6 +458,26 @@ class ScipyVectorWrapperOp(ScipyWrapperOp):
return final_grads
def _optimizer_connection_pattern(fgraph, is_minimization):
    """Compute the connection pattern shared by the scipy optimization Ops.

    Each Op has two outputs: the optimized variable and a success flag.
    The success flag is not differentiable, so every input is disconnected
    from it.  The optimized variable is connected only to inputs that are
    connected to the implicit function used in the gradient computation.
    For minimization that implicit function is grad(objective, x), not the
    objective itself — an input may touch the objective yet be disconnected
    from its gradient (e.g. an additive constant), so the pattern must
    reflect the actual implicit function.

    Parameters
    ----------
    fgraph : FunctionGraph
        Inner graph whose first input is the optimization variable and
        whose first output is the objective (or residual) function.
    is_minimization : bool
        If True, differentiate the objective w.r.t. the variable before
        computing connectivity.

    Returns
    -------
    list[list[bool]]
        One ``[connected_to_opt_x, False]`` row per input of ``fgraph``.
    """
    x = fgraph.inputs[0]
    implicit_fn = fgraph.outputs[0]
    if is_minimization:
        # The gradient, not the objective, defines the implicit function.
        implicit_fn = grad(implicit_fn, x)

    pattern = io_connection_pattern(fgraph.inputs, [implicit_fn])
    return [[row_connected, False] for [row_connected] in pattern]
class MinimizeScalarOp(ScipyScalarWrapperOp):
def __init__(
self,
......@@ -502,6 +506,9 @@ class MinimizeScalarOp(ScipyScalarWrapperOp):
self._fn = None
self._fn_wrapped = None
def connection_pattern(self, node=None):
    """Connection pattern for the (opt_x, success) outputs.

    Minimization case: connectivity is computed through grad(objective, x)
    rather than the objective itself.
    """
    return _optimizer_connection_pattern(self.fgraph, is_minimization=True)
def __str__(self):
    """Readable representation including the configured scipy method."""
    cls_name = type(self).__name__
    return f"{cls_name}(method={self.method})"
......@@ -638,6 +645,9 @@ class MinimizeOp(ScipyVectorWrapperOp):
self._fn = None
self._fn_wrapped = None
def connection_pattern(self, node=None):
    """Connection pattern for the (opt_x, success) outputs.

    Minimization case: connectivity is computed through grad(objective, x)
    rather than the objective itself.
    """
    return _optimizer_connection_pattern(self.fgraph, is_minimization=True)
def __str__(self):
str_args = ", ".join(
[
......@@ -833,6 +843,9 @@ class RootScalarOp(ScipyScalarWrapperOp):
self._fn = None
self._fn_wrapped = None
def connection_pattern(self, node=None):
    """Connection pattern for the (opt_x, success) outputs.

    Root-finding case: the residual function itself is the implicit
    function, so no differentiation is applied before computing
    connectivity.
    """
    return _optimizer_connection_pattern(self.fgraph, is_minimization=False)
def __str__(self):
str_args = ", ".join(
[f"{arg}={getattr(self, arg)}" for arg in ["method", "jac", "hess"]]
......@@ -979,6 +992,9 @@ class RootOp(ScipyVectorWrapperOp):
self._fn = None
self._fn_wrapped = None
def connection_pattern(self, node=None):
    """Connection pattern for the (opt_x, success) outputs.

    Root-finding case: the residual function itself is the implicit
    function, so no differentiation is applied before computing
    connectivity.
    """
    return _optimizer_connection_pattern(self.fgraph, is_minimization=False)
def __str__(self):
str_args = ", ".join(
[f"{arg}={getattr(self, arg)}" for arg in ["method", "jac"]]
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论