Commit 9a65fcd7 authored by Brandon T. Willard, committed by Brandon T. Willard

Allow shared variable changes in OpFromGraph inputs

Parent: e7d77b00
......@@ -771,9 +771,8 @@ class OpFromGraph(Op, HasInnerGraph):
for inp, inp_t in zip(non_shared_inputs, self.input_types)
]
inner_and_input_shareds = list(
zip(self.shared_inputs, inputs[num_expected_inps:])
)
new_shared_inputs = inputs[num_expected_inps:]
inner_and_input_shareds = list(zip(self.shared_inputs, new_shared_inputs))
if not all(inp_s == inn_s for inn_s, inp_s in inner_and_input_shareds):
# The shared variables are not equal to the original shared
......@@ -789,13 +788,23 @@ class OpFromGraph(Op, HasInnerGraph):
# If the new shared variables are inconsistent with the inner-graph,
# such errors should arise in this step
new_outputs = clone_replace(
new_inner_outputs = clone_replace(
self.outputs, replace=replace, share_inputs=True
)
# `self.inputs` should not contain any shared variables, so we know
# that those are inputs to `new_outputs`, because we chose not to
# clone inputs; however, it's possible that the new shared variable
# inputs aren't actually shared variables. When they aren't we
# need to add them as new inputs.
unshared_inputs = [
inp for inp in new_shared_inputs if not isinstance(inp, SharedVariable)
]
new_inner_inputs = self.inputs + unshared_inputs
new_op = type(self)(
inputs=self.inputs,
outputs=new_outputs,
inputs=new_inner_inputs,
outputs=new_inner_outputs,
inline=self.is_inline,
lop_overrides=self.lop_overrides,
grad_overrides=self.grad_overrides,
......@@ -803,12 +812,16 @@ class OpFromGraph(Op, HasInnerGraph):
connection_pattern=self._connection_pattern,
name=self.name,
)
new_inputs = (
list(non_shared_inputs) + unshared_inputs + new_op.shared_inputs
)
else:
new_op = self
new_inputs = list(non_shared_inputs) + new_op.shared_inputs
apply_node = Apply(
new_op,
list(non_shared_inputs) + new_op.shared_inputs,
new_inputs,
[type() for type in new_op.output_types],
)
return apply_node
......
......@@ -510,6 +510,29 @@ class TestOpFromGraph(unittest_tools.InferShapeTester):
with pytest.raises(MissingInputError):
OpFromGraph([], [x])
def test_shared_to_nonshared_input(self):
    """Shared variables in an `OpFromGraph` can be swapped for non-shared inputs."""
    x = at.scalar("x")
    y = shared(1.0, name="y")

    # The inner graph consists of nothing but the shared variable `y`.
    ofg = OpFromGraph([], [y])
    assert ofg.inputs == []
    assert ofg.shared_inputs == [y]

    # Evaluating the original op yields the shared variable's value.
    compiled_orig = function([], ofg())
    assert np.array_equal(compiled_orig(), 1.0)

    # Re-making the node with an ordinary scalar replaces the shared input.
    new_node = ofg.make_node(x)
    assert new_node.op.inputs == [x]
    assert new_node.op.shared_inputs == []

    compiled_new = function([x], new_node.outputs[0])
    assert np.array_equal(compiled_new(np.array(1.0, dtype=config.floatX)), 1.0)
def test_debugprint():
x, y, z = matrices("xyz")
......
......@@ -2,7 +2,7 @@ import numpy as np
import pytest
import aesara
import aesara.tensor.basic as at
import aesara.tensor as at
from aesara import function, scan, shared
from aesara.compile.builders import OpFromGraph
from aesara.compile.io import In
......@@ -573,6 +573,24 @@ class TestPushOutNonSeqScan:
res = out_fn()
assert np.array_equal(res, np.repeat(3.0, 10))
def test_nested_OpFromGraph_shared(self):
    """An `OpFromGraph` wrapping only a shared variable works inside nested `scan`s."""
    y = aesara.shared(1.0, name="y")
    shared_ofg = OpFromGraph([], [y])

    def step(n_steps):
        # Inner scan repeatedly evaluates the shared-only op.
        result, _ = aesara.scan(lambda: shared_ofg(), n_steps=n_steps)
        return result

    outer_out, _ = aesara.scan(step, sequences=[at.arange(1, 2)])

    # Compiling the op on its own must not interfere with the nested use.
    _ = aesara.function([], shared_ofg())

    compiled = aesara.function([], outer_out)
    assert np.array_equal(compiled(), [[1.0]])
class TestPushOutAddScan:
"""
......
Markdown format
0%
You are adding 0 people to this discussion. Please proceed with caution.
Please finish editing this comment first!
Register or sign in to post a comment