提交 e65b0c51 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Make OpFromGraph.make_node interface consistent with its Apply nodes

上级 6ef1452a
......@@ -375,6 +375,11 @@ class OpFromGraph(Op, HasInnerGraph):
self.kwargs = kwargs
self.input_types = [inp.type for inp in inputs]
self.output_types = [out.type for out in outputs]
self.lop_overrides = lop_overrides
self.grad_overrides = grad_overrides
self.rop_overrides = rop_overrides
if lop_overrides != "default":
if grad_overrides != "default":
raise ValueError(
......@@ -732,19 +737,71 @@ class OpFromGraph(Op, HasInnerGraph):
]
return ret_l
def __call__(self, *inputs, **kwargs):
# The user interface doesn't expect the shared variable inputs of the
# inner-graph, but, since `Op.make_node` does (and `Op.__call__`
# dispatches to `Op.make_node`), we need to compensate here
num_expected_inps = len(self.inner_inputs) - len(self.shared_inputs)
if len(inputs) == num_expected_inps:
actual_inputs = inputs + tuple(self.shared_inputs)
return super().__call__(*actual_inputs, **kwargs)
elif len(inputs) == len(self.inner_inputs):
return super().__call__(*inputs, **kwargs)
else:
raise ValueError(f"Expected at least {num_expected_inps} input(s)")
def make_node(self, *inputs):
# The `inputs` received here should correspond to the inputs in the
# `Apply` nodes we produce below
if len(inputs) != len(self.inner_inputs):
raise ValueError(f"Expected {len(self.inner_inputs)} input(s)")
num_expected_inps = len(self.inner_inputs) - len(self.shared_inputs)
if len(inputs) != num_expected_inps:
raise ValueError(
f"Expected {int(num_expected_inps)} inputs, got {len(inputs)}"
)
inputs = [
inp_t.filter_variable(inp) for inp, inp_t in zip(inputs, self.input_types)
non_shared_inputs = inputs[:num_expected_inps]
non_shared_inputs = [
inp_t.filter_variable(inp)
for inp, inp_t in zip(non_shared_inputs, self.input_types)
]
shared_inputs = inputs[num_expected_inps:]
local_shared_inputs = self.inner_inputs[num_expected_inps:]
inner_and_input_shareds = list(zip(local_shared_inputs, shared_inputs))
if not all(inp_s == inn_s for inn_s, inp_s in inner_and_input_shareds):
# The shared variables are not equal to the original shared
# variables, so we construct a new `Op` that uses the new shared
# variables instead
replace = {
old_inp: new_inp for old_inp, new_inp in zip(self.inner_inputs, inputs)
}
replace.update(inner_and_input_shareds)
# If the new shared variables are inconsistent with the inner-graph,
# such errors should arise in this step
new_outputs = clone_replace(
self.inner_outputs, replace=replace, share_inputs=True
)
new_op = type(self)(
inputs=non_shared_inputs,
outputs=new_outputs,
inline=self.is_inline,
lop_overrides=self.lop_overrides,
grad_overrides=self.grad_overrides,
rop_overrides=self.rop_overrides,
connection_pattern=self._connection_pattern,
name=self.name,
)
else:
new_op = self
apply_node = Apply(
self,
list(inputs) + self.shared_inputs,
[type() for type in self.output_types],
new_op,
list(non_shared_inputs) + new_op.shared_inputs,
[type() for type in new_op.output_types],
)
return apply_node
......
......@@ -3,6 +3,7 @@ from functools import partial
import numpy as np
import pytest
import aesara.tensor as at
from aesara.compile import shared
from aesara.compile.builders import OpFromGraph
from aesara.compile.function import function
......@@ -29,6 +30,12 @@ class TestOpFromGraph(unittest_tools.InferShapeTester):
def test_valid_input(self):
x, y, z = matrices("xyz")
with pytest.raises(ValueError, match="Expected at least.*"):
OpFromGraph([x], [x])()
with pytest.raises(ValueError, match=r"Expected 1 input\(s\)"):
OpFromGraph([x], [x]).make_node()
with pytest.raises(TypeError):
OpFromGraph((x,), (x,))
......@@ -451,6 +458,39 @@ class TestOpFromGraph(unittest_tools.InferShapeTester):
grad_f = grad(f, y)
assert grad_f.tag.test_value is not None
def test_make_node_shared(self):
"""Make sure we can provide `OpFromGraph.make_node` new shared inputs and get a valid `OpFromGraph`."""
x = at.scalar("x")
y = shared(1.0, name="y")
test_ofg = OpFromGraph([x], [x + y])
assert test_ofg.inputs == [x]
assert test_ofg.shared_inputs == [y]
out = test_ofg(x)
y_clone = y.clone()
assert y_clone != y
y_clone.name = "y_clone"
out_new = test_ofg.make_node(*(out.owner.inputs[:1] + [y_clone])).outputs[0]
assert out_new.owner.op.inputs == [x]
assert out_new.owner.op.shared_inputs == [y_clone]
out_fn = function([x], out_new)
assert np.array_equal(out_fn(1.0), 2.0)
y_clone.set_value(2.0)
assert np.array_equal(out_fn(1.0), 3.0)
# This should also work, because the containers are the same:
# y.set_value(1.0)
# assert np.array_equal(out_fn(1.0), 2.0)
def test_debugprint():
x, y, z = matrices("xyz")
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论