Commit e65b0c51, authored by Brandon T. Willard, committed by Brandon T. Willard

Make OpFromGraph.make_node interface consistent with its Apply nodes

Parent commit: 6ef1452a
...@@ -375,6 +375,11 @@ class OpFromGraph(Op, HasInnerGraph): ...@@ -375,6 +375,11 @@ class OpFromGraph(Op, HasInnerGraph):
self.kwargs = kwargs self.kwargs = kwargs
self.input_types = [inp.type for inp in inputs] self.input_types = [inp.type for inp in inputs]
self.output_types = [out.type for out in outputs] self.output_types = [out.type for out in outputs]
self.lop_overrides = lop_overrides
self.grad_overrides = grad_overrides
self.rop_overrides = rop_overrides
if lop_overrides != "default": if lop_overrides != "default":
if grad_overrides != "default": if grad_overrides != "default":
raise ValueError( raise ValueError(
...@@ -732,19 +737,71 @@ class OpFromGraph(Op, HasInnerGraph): ...@@ -732,19 +737,71 @@ class OpFromGraph(Op, HasInnerGraph):
] ]
return ret_l return ret_l
def __call__(self, *inputs, **kwargs):
    """Build an ``Apply`` node from user-facing inputs.

    The user interface doesn't expect the shared-variable inputs of the
    inner-graph, but ``Op.make_node`` does (and ``Op.__call__`` dispatches
    to ``Op.make_node``), so when the caller omits the shared inputs they
    are appended here before dispatching.
    """
    num_expected_inps = len(self.inner_inputs) - len(self.shared_inputs)
    n_given = len(inputs)

    if n_given == num_expected_inps:
        # Only the user-facing inputs were given; append the shared ones.
        actual_inputs = inputs + tuple(self.shared_inputs)
        return super().__call__(*actual_inputs, **kwargs)

    if n_given == len(self.inner_inputs):
        # The caller already supplied the shared inputs explicitly.
        return super().__call__(*inputs, **kwargs)

    raise ValueError(f"Expected at least {num_expected_inps} input(s)")
def make_node(self, *inputs):
    """Construct an ``Apply`` node from ``inputs``.

    The ``inputs`` received here should correspond exactly to the inputs
    of the ``Apply`` nodes this method produces: i.e. the non-shared
    inputs followed by the shared-variable inputs of the inner-graph.

    If the trailing ``inputs`` differ from this `Op`'s original shared
    variables, a new `OpFromGraph` is created whose inner-graph uses the
    replacement shared variables; otherwise ``self`` is reused.

    Raises
    ------
    ValueError
        If the number of ``inputs`` does not match the inner-graph's
        input count.
    """
    if len(inputs) != len(self.inner_inputs):
        raise ValueError(f"Expected {len(self.inner_inputs)} input(s)")

    num_expected_inps = len(self.inner_inputs) - len(self.shared_inputs)
    non_shared_inputs = inputs[:num_expected_inps]

    # Make the explicit (non-shared) inputs conform to the declared types.
    non_shared_inputs = [
        inp_t.filter_variable(inp)
        for inp, inp_t in zip(non_shared_inputs, self.input_types)
    ]

    shared_inputs = inputs[num_expected_inps:]
    local_shared_inputs = self.inner_inputs[num_expected_inps:]

    inner_and_input_shareds = list(zip(local_shared_inputs, shared_inputs))

    if not all(inp_s == inn_s for inn_s, inp_s in inner_and_input_shareds):
        # The shared variables are not equal to the original shared
        # variables, so we construct a new `Op` that uses the new shared
        # variables instead
        replace = {
            old_inp: new_inp for old_inp, new_inp in zip(self.inner_inputs, inputs)
        }
        replace.update(inner_and_input_shareds)

        # If the new shared variables are inconsistent with the inner-graph,
        # such errors should arise in this step
        new_outputs = clone_replace(
            self.inner_outputs, replace=replace, share_inputs=True
        )

        new_op = type(self)(
            inputs=non_shared_inputs,
            outputs=new_outputs,
            inline=self.is_inline,
            lop_overrides=self.lop_overrides,
            grad_overrides=self.grad_overrides,
            rop_overrides=self.rop_overrides,
            connection_pattern=self._connection_pattern,
            name=self.name,
        )
    else:
        new_op = self

    apply_node = Apply(
        new_op,
        list(non_shared_inputs) + new_op.shared_inputs,
        [type() for type in new_op.output_types],
    )
    return apply_node
......
...@@ -3,6 +3,7 @@ from functools import partial ...@@ -3,6 +3,7 @@ from functools import partial
import numpy as np import numpy as np
import pytest import pytest
import aesara.tensor as at
from aesara.compile import shared from aesara.compile import shared
from aesara.compile.builders import OpFromGraph from aesara.compile.builders import OpFromGraph
from aesara.compile.function import function from aesara.compile.function import function
...@@ -29,6 +30,12 @@ class TestOpFromGraph(unittest_tools.InferShapeTester): ...@@ -29,6 +30,12 @@ class TestOpFromGraph(unittest_tools.InferShapeTester):
def test_valid_input(self): def test_valid_input(self):
x, y, z = matrices("xyz") x, y, z = matrices("xyz")
with pytest.raises(ValueError, match="Expected at least.*"):
OpFromGraph([x], [x])()
with pytest.raises(ValueError, match=r"Expected 1 input\(s\)"):
OpFromGraph([x], [x]).make_node()
with pytest.raises(TypeError): with pytest.raises(TypeError):
OpFromGraph((x,), (x,)) OpFromGraph((x,), (x,))
...@@ -451,6 +458,39 @@ class TestOpFromGraph(unittest_tools.InferShapeTester): ...@@ -451,6 +458,39 @@ class TestOpFromGraph(unittest_tools.InferShapeTester):
grad_f = grad(f, y) grad_f = grad(f, y)
assert grad_f.tag.test_value is not None assert grad_f.tag.test_value is not None
def test_make_node_shared(self):
    """`OpFromGraph.make_node` should accept replacement shared inputs and yield a new, valid `OpFromGraph`."""
    x = at.scalar("x")
    y = shared(1.0, name="y")

    ofg = OpFromGraph([x], [x + y])
    assert ofg.inputs == [x]
    assert ofg.shared_inputs == [y]

    out = ofg(x)

    # Clone the shared variable so that `make_node` is handed a *different*
    # shared input than the one the inner-graph was built with.
    y_clone = y.clone()
    assert y_clone != y
    y_clone.name = "y_clone"

    new_node = ofg.make_node(*(out.owner.inputs[:1] + [y_clone]))
    out_new = new_node.outputs[0]

    # The resulting `Op` must reference the cloned shared variable.
    assert out_new.owner.op.inputs == [x]
    assert out_new.owner.op.shared_inputs == [y_clone]

    out_fn = function([x], out_new)
    assert np.array_equal(out_fn(1.0), 2.0)

    # Updating the clone's value must be visible to the compiled function.
    y_clone.set_value(2.0)
    assert np.array_equal(out_fn(1.0), 3.0)

    # This should also work, because the containers are the same:
    # y.set_value(1.0)
    # assert np.array_equal(out_fn(1.0), 2.0)
def test_debugprint(): def test_debugprint():
x, y, z = matrices("xyz") x, y, z = matrices("xyz")
......
Markdown is supported
0%
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Register or sign in to comment