提交 f43b92f3 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Fix shared variable comparisons in OpFromGraph.make_node

上级 bc10e2b9
......@@ -765,28 +765,30 @@ class OpFromGraph(Op, HasInnerGraph):
for inp, inp_t in zip(non_shared_inputs, self.input_types)
]
shared_inputs = inputs[num_expected_inps:]
local_shared_inputs = self.inner_inputs[num_expected_inps:]
inner_and_input_shareds = list(zip(local_shared_inputs, shared_inputs))
inner_and_input_shareds = list(
zip(self.shared_inputs, inputs[num_expected_inps:])
)
if not all(inp_s == inn_s for inn_s, inp_s in inner_and_input_shareds):
# The shared variables are not equal to the original shared
# variables, so we construct a new `Op` that uses the new shared
# variables instead
replace = {
old_inp: new_inp for old_inp, new_inp in zip(self.inner_inputs, inputs)
}
replace.update(inner_and_input_shareds)
# variables instead.
# All this is really doing is making the unused (internally, at
# least) `self.outputs` and `self.shared_inputs` consistent.
# We could just as easily `copy` this `Op`, update
# `self.shared_inputs`, and avoid cloning anything, but this is a
# more "change-proof" approach, because it still work when/if those
# attributes end up being used.
replace = dict(inner_and_input_shareds)
# If the new shared variables are inconsistent with the inner-graph,
# such errors should arise in this step
new_outputs = clone_replace(
self.inner_outputs, replace=replace, share_inputs=True
self.outputs, replace=replace, share_inputs=True
)
new_op = type(self)(
inputs=non_shared_inputs,
inputs=self.inputs,
outputs=new_outputs,
inline=self.is_inline,
lop_overrides=self.lop_overrides,
......
......@@ -480,17 +480,29 @@ class TestOpFromGraph(unittest_tools.InferShapeTester):
assert out_new.owner.op.shared_inputs == [y_clone]
out_fn = function([x], out_new)
assert np.array_equal(out_fn(1.0), 2.0)
y_clone.set_value(2.0)
assert np.array_equal(out_fn(1.0), 3.0)
# This should also work, because the containers are the same:
# y.set_value(1.0)
# assert np.array_equal(out_fn(1.0), 2.0)
def test_shared_with_constant_input(self):
"""Make sure that a constant input can be given to an `OpFromGraph` instance."""
x = at.scalar("x")
y = shared(1.0, name="y")
test_ofg = OpFromGraph([x], [x + y])
assert test_ofg.inputs == [x]
assert test_ofg.shared_inputs == [y]
out = test_ofg(at.as_tensor(1.0, dtype=config.floatX))
out_fn = function([], out)
assert np.array_equal(out_fn(), 2.0)
def test_debugprint():
x, y, z = matrices("xyz")
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论