Commit 394b355b, authored by Ricardo Vieira, committed by Ricardo Vieira

Fix OpFromGraph L_op with related and/or disconnected outputs

Parent 2eb8fca2
@@ -417,7 +417,10 @@ class OpFromGraph(Op, HasInnerGraph):
                 FutureWarning,
             )
             self._lop_op_interface = False
-        self._lop_op_cache: Callable | None = None
+        # Dictionary where we cache OpFromGraph that represent the L_op
+        # A distinct OpFromGraph is needed to represent each pattern of output_grads connection
+        # It also returns a tuple that indicates which input_gradients are disconnected
+        self._lop_op_cache: dict[tuple[bool, ...], Callable] = {}
         self._rop_op_cache: Callable | None = None
         self._connection_pattern = connection_pattern
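For context on the new cache: the key holds one boolean per output, marking whether that output's gradient is disconnected at the call site. A minimal sketch of two downstream costs that would hit different cache entries (variable names are illustrative, not part of the patch):

import pytensor.tensor as pt
from pytensor.compile.builders import OpFromGraph
from pytensor.gradient import grad

x, y = pt.dscalars("x", "y")
op = OpFromGraph([x, y], [x + y, x * y])
out1, out2 = op(x, y)

# A cost using both outputs builds the L_op cached under key (False, False);
# a cost using only out1 leaves out2's output gradient disconnected, so it
# builds and caches a second, specialized L_op under key (False, True).
g_both = grad(out1 + out2, wrt=x)
g_first = grad(out1, wrt=x)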
@@ -480,24 +483,30 @@ class OpFromGraph(Op, HasInnerGraph):
         return outputs

     @config.change_flags(compute_test_value="off")
-    def _build_and_cache_lop_op(self) -> Callable:
-        """converts lop_overrides (or grad_overrides) from user supplied form to type(self) instance.
+    def _build_and_cache_lop_op(
+        self, disconnected_output_grads: tuple[bool, ...]
+    ) -> Callable:
+        """converts lop_overrides (or grad_overrides) from user supplied form to type(self) instance,
+        specialized for the pattern of disconnected_output_grads

         Results are cached in self._lop_op_cache
         """
-        if self._lop_op_cache is not None:
-            return self._lop_op_cache
+        try:
+            return self._lop_op_cache[disconnected_output_grads]
+        except KeyError:
+            pass

         inner_inputs = self.inner_inputs
         inner_outputs = self.inner_outputs
         nin = len(inner_inputs)
+        nout = len(inner_outputs)
         lop_overrides = (
             self.lop_overrides if self._lop_op_interface else self.grad_overrides
         )
         if isinstance(lop_overrides, OpFromGraph):
             if self._lop_op_interface:
-                self._lop_op_cache = lop_overrides
+                self._lop_op_cache[disconnected_output_grads] = lop_overrides
                 lop_overrides.kwargs["on_unused_input"] = "ignore"
                 return lop_overrides
@@ -507,20 +516,42 @@ class OpFromGraph(Op, HasInnerGraph):
             def lop_overrides(inps, grads):
                 return self.grad_overrides(*inps, *grads)

-        output_grads = [out_t() for out_t in self.output_types]
+        # We try to compute the gradient with respect to connected outputs only
+        connected_inner_outputs = [
+            # We add an identity operation (copy) so that we don't override indirect
+            # gradient contributions to an inner output coming from other inner outputs
+            inner_out.copy()
+            for inner_out, disconnected in zip(
+                inner_outputs, disconnected_output_grads, strict=True
+            )
+            if not disconnected
+        ]
+        connected_output_grads = [
+            out_t()
+            for out_t, disconnected in zip(
+                self.output_types, disconnected_output_grads, strict=True
+            )
+            if not disconnected
+        ]
         fn_grad = partial(
             grad,
             cost=None,
             disconnected_inputs="ignore",
             return_disconnected="disconnected",
             null_gradients="return",
-            known_grads=dict(zip(inner_outputs, output_grads)),
+            known_grads=dict(
+                zip(connected_inner_outputs, connected_output_grads, strict=True)
+            ),
         )

         if self._lop_op_interface:
-            callable_args = (inner_inputs, inner_outputs, output_grads)
+            callable_args = (
+                inner_inputs,
+                connected_inner_outputs,
+                connected_output_grads,
+            )
         else:
-            callable_args = (inner_inputs, output_grads)
+            callable_args = (inner_inputs, connected_output_grads)

         # we need to convert _lop_op into an OfG instance
         if lop_overrides is None:
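The `inner_out.copy()` identity wrapper is the subtle part of this hunk. When inner outputs feed each other (e.g. `out3 = out1 * out2`), `out1` receives both an externally supplied output gradient and a chain-rule contribution flowing back from `out3`; keying `known_grads` on an identity copy of `out1`, rather than on `out1` itself, keeps the external gradient from overriding that indirect contribution. A toy illustration of the dependency shape involved (hypothetical graph, not from the patch):

import pytensor.tensor as pt
from pytensor.gradient import grad

x = pt.dscalar("x")
out1 = 2 * x
out3 = out1 * x  # out3 depends on out1

# d(out1 + out3)/dx must include both the direct path through out1 and
# the indirect path through out3's use of out1:
# d(2x)/dx + d(2x**2)/dx = 2 + 4x
g = grad(out1 + out3, wrt=x)
print(g.eval({x: 3.0}))  # expected 14.0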
@@ -544,14 +575,15 @@ class OpFromGraph(Op, HasInnerGraph):
         else:
             input_grads = self._call_custom_override(lop_overrides, callable_args, nin)

-        # Filter out disconnected input and output gradients
+        # Filter out disconnected/null input gradients generated by the inner graph grad
+        # They are re-appended in the outer wrapper function below
         connected_input_grads = [
             inp_grad
             for inp_grad in input_grads
             if not isinstance(inp_grad.type, DisconnectedType | NullType)
         ]
         lop_op = type(self)(
-            inputs=inner_inputs + inner_outputs + output_grads,
+            inputs=inner_inputs + connected_inner_outputs + connected_output_grads,
             outputs=connected_input_grads,
             inline=self.is_inline,
             name=(None if self.name is None else f"{self.name}_LOp"),
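Why the filtering is needed: `DisconnectedType` and `NullType` variables are symbolic markers, not computable tensors, so they cannot be outputs of the inner `OpFromGraph`; they are stripped here and re-inserted by the wrapper. A sketch of the predicate on made-up gradients (illustrative only):

import pytensor.tensor as pt
from pytensor.gradient import DisconnectedType, disconnected_type
from pytensor.graph.null_type import NullType

real_grad = pt.dscalar("real_grad")
input_grads = [real_grad, disconnected_type()]

# Keep only gradients that are ordinary computable variables
connected = [
    g for g in input_grads
    if not isinstance(g.type, DisconnectedType | NullType)
]
assert connected == [real_grad]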
@@ -559,9 +591,27 @@ class OpFromGraph(Op, HasInnerGraph):
             on_unused_input="ignore",
         )

-        # Return a wrapper that combines connected and disconnected input gradients
+        # Return a wrapper that combines connected and disconnected/null input gradients
+        # And also filters out disconnected/null output gradients
         def wrapper(*inputs: Variable, **kwargs) -> list[Variable]:
-            connected_input_grads = iter(lop_op(*inputs, **kwargs))
+            inputs, outputs, output_grads = (
+                inputs[: -nout * 2],
+                inputs[-nout * 2 : -nout],
+                inputs[-nout:],
+            )
+            connected_outputs = [
+                output
+                for output, output_grad in zip(outputs, output_grads, strict=True)
+                if not isinstance(output_grad.type, DisconnectedType | NullType)
+            ]
+            connected_output_grads = [
+                output_grad
+                for output_grad in output_grads
+                if not isinstance(output_grad.type, DisconnectedType)
+            ]
+            connected_input_grads = iter(
+                lop_op(*inputs, *connected_outputs, *connected_output_grads, **kwargs)
+            )
             return [
                 input_grad
                 if isinstance(input_grad.type, DisconnectedType | NullType)
@@ -569,7 +619,7 @@ class OpFromGraph(Op, HasInnerGraph):
                 for input_grad in input_grads
             ]

-        self._lop_op_cache = wrapper
+        self._lop_op_cache[disconnected_output_grads] = wrapper
         return wrapper

     @config.change_flags(compute_test_value="off")
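The positional slicing inside `wrapper` relies on `L_op` calling it as `lop_op(*inputs, *outputs, *output_grads)`, so the last `2 * nout` arguments always split evenly into outputs and output gradients. A quick plain-Python check of that arithmetic (toy strings standing in for variables):

nout = 3
args = ["x", "y", "o1", "o2", "o3", "g1", "g2", "g3"]

inputs, outputs, output_grads = (
    args[: -nout * 2],
    args[-nout * 2 : -nout],
    args[-nout:],
)
assert inputs == ["x", "y"]
assert outputs == ["o1", "o2", "o3"]
assert output_grads == ["g1", "g2", "g3"]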
@@ -652,7 +702,10 @@ class OpFromGraph(Op, HasInnerGraph):
         return wrapper

     def L_op(self, inputs, outputs, output_grads):
-        lop_op = self._build_and_cache_lop_op()
+        disconnected_output_grads = tuple(
+            isinstance(og.type, DisconnectedType) for og in output_grads
+        )
+        lop_op = self._build_and_cache_lop_op(disconnected_output_grads)
         return lop_op(*inputs, *outputs, *output_grads, return_list=True)

     def R_op(self, inputs, eval_points):
...
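To see what key `L_op` actually builds: PyTensor represents a missing output gradient as a variable whose type is `DisconnectedType`. A hedged sketch of the key for a three-output op where only the middle gradient is live (made-up variables):

import pytensor.tensor as pt
from pytensor.gradient import DisconnectedType, disconnected_type

g = pt.dscalar("g")
output_grads = [disconnected_type(), g, disconnected_type()]

# One boolean per output gradient, True where it is disconnected
key = tuple(isinstance(og.type, DisconnectedType) for og in output_grads)
assert key == (True, False, True)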
@@ -8,7 +8,13 @@ from pytensor.compile import shared
 from pytensor.compile.builders import OpFromGraph
 from pytensor.compile.function import function
 from pytensor.configdefaults import config
-from pytensor.gradient import DisconnectedType, Rop, disconnected_type, grad
+from pytensor.gradient import (
+    DisconnectedType,
+    Rop,
+    disconnected_type,
+    grad,
+    verify_grad,
+)
 from pytensor.graph.basic import equal_computations
 from pytensor.graph.fg import FunctionGraph
 from pytensor.graph.null_type import NullType, null_type
@@ -22,7 +28,15 @@ from pytensor.tensor.math import sum as pt_sum
 from pytensor.tensor.random.utils import RandomStream
 from pytensor.tensor.rewriting.shape import ShapeOptimizer
 from pytensor.tensor.shape import specify_shape
-from pytensor.tensor.type import TensorType, matrices, matrix, scalar, vector, vectors
+from pytensor.tensor.type import (
+    TensorType,
+    dscalars,
+    matrices,
+    matrix,
+    scalar,
+    vector,
+    vectors,
+)
 from tests import unittest_tools
 from tests.graph.utils import MyVariable
@@ -638,6 +652,34 @@ class TestOpFromGraph(unittest_tools.InferShapeTester):
         out = test_ofg(y, y)
         assert out.eval() == 4

+    def test_L_op_disconnected_output_grad(self):
+        x, y = dscalars("x", "y")
+
+        rng = np.random.default_rng(594)
+        point = list(rng.normal(size=(2,)))
+
+        out1 = x + y
+        out2 = x * y
+        out3 = out1 * out2  # Create dependency between outputs
+        op = OpFromGraph([x, y], [out1, out2, out3])
+
+        verify_grad(lambda x, y: pt.add(*op(x, y)), point, rng=rng)
+        verify_grad(lambda x, y: pt.add(*op(x, y)[:-1]), point, rng=rng)
+        verify_grad(lambda x, y: pt.add(*op(x, y)[1:]), point, rng=rng)
+        verify_grad(lambda x, y: pt.add(*op(x, y)[::2]), point, rng=rng)
+        verify_grad(lambda x, y: op(x, y)[0], point, rng=rng)
+        verify_grad(lambda x, y: op(x, y)[1], point, rng=rng)
+        verify_grad(lambda x, y: op(x, y)[2], point, rng=rng)
+
+        # Test disconnected graphs are handled correctly
+        op = OpFromGraph([x, y], [x**2, y**3])
+        with pytest.warns(UserWarning):
+            grad_x_wrt_y = grad(
+                op(x, y)[0],
+                wrt=y,
+                return_disconnected="disconnected",
+                disconnected_inputs="warn",
+            )
+        assert isinstance(grad_x_wrt_y.type, DisconnectedType)
+
     def test_repeated_inputs(self):
         x = pt.dscalar("x")
         y = pt.dscalar("y")
...
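Condensed from the new test, this is the end-to-end scenario the commit addresses: differentiating a cost that touches only a subset of `OpFromGraph` outputs, including outputs that feed each other. Per the commit message, this pattern previously misbehaved; the sketch below assumes a build with this patch applied:

import pytensor.tensor as pt
from pytensor.compile.builders import OpFromGraph
from pytensor.gradient import grad

x, y = pt.dscalars("x", "y")
out1 = x + y
out2 = x * y
op = OpFromGraph([x, y], [out1, out2, out1 * out2])

o1, o2, o3 = op(x, y)
# Only o1 and o3 enter the cost, so o2's output gradient is disconnected
# and the L_op specialized for that pattern is built and cached.
g = grad(o1 + o3, wrt=x)
print(g.eval({x: 1.0, y: 2.0}))  # d(o1)/dx + d(o3)/dx = 1 + (2xy + y**2) = 9.0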