Commit 394b355b authored by Ricardo Vieira, committed by Ricardo Vieira

Fix OpFromGraph L_op with interdependent and/or disconnected outputs

Parent 2eb8fca2
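Before the diff, a minimal sketch (not part of the commit, adapted from the new test case added below) of the situation this change handles: when only some outputs of a multi-output OpFromGraph feed into the cost, the output gradients for the remaining outputs reach OpFromGraph.L_op as DisconnectedType instances, and the outputs may also depend on each other.

import pytensor.tensor as pt
from pytensor.compile.builders import OpFromGraph
from pytensor.gradient import grad

x, y = pt.dscalars("x", "y")
out1 = x + y
out2 = x * y
out3 = out1 * out2  # third output depends on the first two
op = OpFromGraph([x, y], [out1, out2, out3])

# Only the first output feeds the cost, so the output gradients passed to
# OpFromGraph.L_op for out2 and out3 are DisconnectedType instances; the
# commit builds and caches a separate inner L_op per connectivity pattern.
gx, gy = grad(op(x, y)[0], wrt=[x, y])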
@@ -417,7 +417,10 @@ class OpFromGraph(Op, HasInnerGraph):
FutureWarning,
)
self._lop_op_interface = False
self._lop_op_cache: Callable | None = None
# Dictionary where we cache the OpFromGraphs that represent the L_op
# A distinct OpFromGraph is needed for each pattern of output_grads connectivity
# It also returns a tuple that indicates which input_gradients are disconnected
self._lop_op_cache: dict[tuple[bool, ...], Callable] = {}
self._rop_op_cache: Callable | None = None
self._connection_pattern = connection_pattern
@@ -480,24 +483,30 @@ class OpFromGraph(Op, HasInnerGraph):
return outputs
@config.change_flags(compute_test_value="off")
def _build_and_cache_lop_op(self) -> Callable:
"""converts lop_overrides (or grad_overrides) from user supplied form to type(self) instance.
def _build_and_cache_lop_op(
self, disconnected_output_grads: tuple[bool, ...]
) -> Callable:
"""converts lop_overrides (or grad_overrides) from user supplied form to type(self) instance,
specialized for the pattern of disconnected_output_grads
Results are cached in self._lop_op_cache
"""
if self._lop_op_cache is not None:
return self._lop_op_cache
try:
return self._lop_op_cache[disconnected_output_grads]
except KeyError:
pass
inner_inputs = self.inner_inputs
inner_outputs = self.inner_outputs
nin = len(inner_inputs)
nout = len(inner_outputs)
lop_overrides = (
self.lop_overrides if self._lop_op_interface else self.grad_overrides
)
if isinstance(lop_overrides, OpFromGraph):
if self._lop_op_interface:
self._lop_op_cache = lop_overrides
self._lop_op_cache[disconnected_output_grads] = lop_overrides
lop_overrides.kwargs["on_unused_input"] = "ignore"
return lop_overrides
@@ -507,20 +516,42 @@ class OpFromGraph(Op, HasInnerGraph):
def lop_overrides(inps, grads):
return self.grad_overrides(*inps, *grads)
output_grads = [out_t() for out_t in self.output_types]
# We try to compute the gradient with respect to connected outputs only
connected_inner_outputs = [
# We add an identity operation (copy) so that we don't override indirect
# gradient contributions to an inner output coming from other inner outputs
inner_out.copy()
for inner_out, disconnected in zip(
inner_outputs, disconnected_output_grads, strict=True
)
if not disconnected
]
connected_output_grads = [
out_t()
for out_t, disconnected in zip(
self.output_types, disconnected_output_grads, strict=True
)
if not disconnected
]
fn_grad = partial(
grad,
cost=None,
disconnected_inputs="ignore",
return_disconnected="disconnected",
null_gradients="return",
known_grads=dict(zip(inner_outputs, output_grads)),
known_grads=dict(
zip(connected_inner_outputs, connected_output_grads, strict=True)
),
)
if self._lop_op_interface:
callable_args = (inner_inputs, inner_outputs, output_grads)
callable_args = (
inner_inputs,
connected_inner_outputs,
connected_output_grads,
)
else:
callable_args = (inner_inputs, output_grads)
callable_args = (inner_inputs, connected_output_grads)
# we need to convert _lop_op into an OfG instance
if lop_overrides is None:
@@ -544,14 +575,15 @@ class OpFromGraph(Op, HasInnerGraph):
else:
input_grads = self._call_custom_override(lop_overrides, callable_args, nin)
# Filter out disconnected input and output gradients
# Filter out disconnected/null input generated from the inner graph grad
# We append them in the outer wrapper function below
connected_input_grads = [
inp_grad
for inp_grad in input_grads
if not isinstance(inp_grad.type, DisconnectedType | NullType)
]
lop_op = type(self)(
inputs=inner_inputs + inner_outputs + output_grads,
inputs=inner_inputs + connected_inner_outputs + connected_output_grads,
outputs=connected_input_grads,
inline=self.is_inline,
name=(None if self.name is None else f"{self.name}_LOp"),
@@ -559,9 +591,27 @@ class OpFromGraph(Op, HasInnerGraph):
on_unused_input="ignore",
)
# Return a wrapper that combines connected and disconnected input gradients
# Return a wrapper that combines connected and disconnected/null input gradients
# And also filters out disconnected/null output gradients
def wrapper(*inputs: Variable, **kwargs) -> list[Variable]:
connected_input_grads = iter(lop_op(*inputs, **kwargs))
inputs, outputs, output_grads = (
inputs[: -nout * 2],
inputs[-nout * 2 : -nout],
inputs[-nout:],
)
connected_outputs = [
output
for output, output_grad in zip(outputs, output_grads, strict=True)
if not isinstance(output_grad.type, DisconnectedType | NullType)
]
connected_output_grads = [
output_grad
for output_grad in output_grads
if not isinstance(output_grad.type, DisconnectedType)
]
connected_input_grads = iter(
lop_op(*inputs, *connected_outputs, *connected_output_grads, **kwargs)
)
return [
input_grad
if isinstance(input_grad.type, DisconnectedType | NullType)
@@ -569,7 +619,7 @@ class OpFromGraph(Op, HasInnerGraph):
for input_grad in input_grads
]
self._lop_op_cache = wrapper
self._lop_op_cache[disconnected_output_grads] = wrapper
return wrapper
@config.change_flags(compute_test_value="off")
@@ -652,7 +702,10 @@ class OpFromGraph(Op, HasInnerGraph):
return wrapper
def L_op(self, inputs, outputs, output_grads):
lop_op = self._build_and_cache_lop_op()
disconnected_output_grads = tuple(
isinstance(og.type, DisconnectedType) for og in output_grads
)
lop_op = self._build_and_cache_lop_op(disconnected_output_grads)
return lop_op(*inputs, *outputs, *output_grads, return_list=True)
def R_op(self, inputs, eval_points):
@@ -8,7 +8,13 @@ from pytensor.compile import shared
from pytensor.compile.builders import OpFromGraph
from pytensor.compile.function import function
from pytensor.configdefaults import config
from pytensor.gradient import DisconnectedType, Rop, disconnected_type, grad
from pytensor.gradient import (
DisconnectedType,
Rop,
disconnected_type,
grad,
verify_grad,
)
from pytensor.graph.basic import equal_computations
from pytensor.graph.fg import FunctionGraph
from pytensor.graph.null_type import NullType, null_type
@@ -22,7 +28,15 @@ from pytensor.tensor.math import sum as pt_sum
from pytensor.tensor.random.utils import RandomStream
from pytensor.tensor.rewriting.shape import ShapeOptimizer
from pytensor.tensor.shape import specify_shape
from pytensor.tensor.type import TensorType, matrices, matrix, scalar, vector, vectors
from pytensor.tensor.type import (
TensorType,
dscalars,
matrices,
matrix,
scalar,
vector,
vectors,
)
from tests import unittest_tools
from tests.graph.utils import MyVariable
@@ -638,6 +652,34 @@ class TestOpFromGraph(unittest_tools.InferShapeTester):
out = test_ofg(y, y)
assert out.eval() == 4
def test_L_op_disconnected_output_grad(self):
x, y = dscalars("x", "y")
rng = np.random.default_rng(594)
point = list(rng.normal(size=(2,)))
out1 = x + y
out2 = x * y
out3 = out1 * out2 # Create dependency between outputs
op = OpFromGraph([x, y], [out1, out2, out3])
verify_grad(lambda x, y: pt.add(*op(x, y)), point, rng=rng)
verify_grad(lambda x, y: pt.add(*op(x, y)[:-1]), point, rng=rng)
verify_grad(lambda x, y: pt.add(*op(x, y)[1:]), point, rng=rng)
verify_grad(lambda x, y: pt.add(*op(x, y)[::2]), point, rng=rng)
verify_grad(lambda x, y: op(x, y)[0], point, rng=rng)
verify_grad(lambda x, y: op(x, y)[1], point, rng=rng)
verify_grad(lambda x, y: op(x, y)[2], point, rng=rng)
# Test disconnected graphs are handled correctly
op = OpFromGraph([x, y], [x**2, y**3])
with pytest.warns(UserWarning):
grad_x_wrt_y = grad(
op(x, y)[0],
wrt=y,
return_disconnected="disconnected",
disconnected_inputs="warn",
)
assert isinstance(grad_x_wrt_y.type, DisconnectedType)
def test_repeated_inputs(self):
x = pt.dscalar("x")
y = pt.dscalar("y")