提交 f43b92f3 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Fix shared variable comparisons in OpFromGraph.make_node

上级 bc10e2b9
...@@ -765,28 +765,30 @@ class OpFromGraph(Op, HasInnerGraph): ...@@ -765,28 +765,30 @@ class OpFromGraph(Op, HasInnerGraph):
for inp, inp_t in zip(non_shared_inputs, self.input_types) for inp, inp_t in zip(non_shared_inputs, self.input_types)
] ]
shared_inputs = inputs[num_expected_inps:] inner_and_input_shareds = list(
local_shared_inputs = self.inner_inputs[num_expected_inps:] zip(self.shared_inputs, inputs[num_expected_inps:])
)
inner_and_input_shareds = list(zip(local_shared_inputs, shared_inputs))
if not all(inp_s == inn_s for inn_s, inp_s in inner_and_input_shareds): if not all(inp_s == inn_s for inn_s, inp_s in inner_and_input_shareds):
# The shared variables are not equal to the original shared # The shared variables are not equal to the original shared
# variables, so we construct a new `Op` that uses the new shared # variables, so we construct a new `Op` that uses the new shared
# variables instead # variables instead.
replace = { # All this is really doing is making the unused (internally, at
old_inp: new_inp for old_inp, new_inp in zip(self.inner_inputs, inputs) # least) `self.outputs` and `self.shared_inputs` consistent.
} # We could just as easily `copy` this `Op`, update
replace.update(inner_and_input_shareds) # `self.shared_inputs`, and avoid cloning anything, but this is a
# more "change-proof" approach, because it still work when/if those
# attributes end up being used.
replace = dict(inner_and_input_shareds)
# If the new shared variables are inconsistent with the inner-graph, # If the new shared variables are inconsistent with the inner-graph,
# such errors should arise in this step # such errors should arise in this step
new_outputs = clone_replace( new_outputs = clone_replace(
self.inner_outputs, replace=replace, share_inputs=True self.outputs, replace=replace, share_inputs=True
) )
new_op = type(self)( new_op = type(self)(
inputs=non_shared_inputs, inputs=self.inputs,
outputs=new_outputs, outputs=new_outputs,
inline=self.is_inline, inline=self.is_inline,
lop_overrides=self.lop_overrides, lop_overrides=self.lop_overrides,
......
...@@ -480,17 +480,29 @@ class TestOpFromGraph(unittest_tools.InferShapeTester): ...@@ -480,17 +480,29 @@ class TestOpFromGraph(unittest_tools.InferShapeTester):
assert out_new.owner.op.shared_inputs == [y_clone] assert out_new.owner.op.shared_inputs == [y_clone]
out_fn = function([x], out_new) out_fn = function([x], out_new)
assert np.array_equal(out_fn(1.0), 2.0) assert np.array_equal(out_fn(1.0), 2.0)
y_clone.set_value(2.0) y_clone.set_value(2.0)
assert np.array_equal(out_fn(1.0), 3.0) assert np.array_equal(out_fn(1.0), 3.0)
# This should also work, because the containers are the same: # This should also work, because the containers are the same:
# y.set_value(1.0) # y.set_value(1.0)
# assert np.array_equal(out_fn(1.0), 2.0) # assert np.array_equal(out_fn(1.0), 2.0)
def test_shared_with_constant_input(self):
    """Make sure that a constant input can be given to an `OpFromGraph` instance."""
    x = at.scalar("x")
    # `y` is a shared variable referenced by the inner graph; it is picked up
    # implicitly as a shared input rather than listed among the explicit inputs.
    y = shared(1.0, name="y")
    test_ofg = OpFromGraph([x], [x + y])
    # Only the explicit symbolic input appears in `inputs`; the shared
    # variable is tracked separately in `shared_inputs`.
    assert test_ofg.inputs == [x]
    assert test_ofg.shared_inputs == [y]
    # Apply the `Op` to a constant tensor instead of the symbolic `x`.
    out = test_ofg(at.as_tensor(1.0, dtype=config.floatX))
    out_fn = function([], out)
    # constant 1.0 + shared value 1.0 == 2.0
    assert np.array_equal(out_fn(), 2.0)
def test_debugprint(): def test_debugprint():
x, y, z = matrices("xyz") x, y, z = matrices("xyz")
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论