提交 1a85652e authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Fix type consistency issue in Apply.clone_with_new_inputs

上级 ece0fb1d
......@@ -249,10 +249,14 @@ class Apply(Node):
for i, (curr, new) in enumerate(zip(self.inputs, new_inputs)):
if curr.type != new.type:
if strict:
# If compatible, casts new into curr.type
new_inputs[i] = curr.type.filter_variable(new)
new_i = curr.type.filter_variable(new)
new_inputs[i] = new_i
if curr.type != new_i.type:
remake_node = True
else:
remake_node = True
if remake_node:
new_node = self.op.make_node(*new_inputs)
new_node.tag = copy(self.tag).__update__(new_node.tag)
......
......@@ -662,3 +662,35 @@ class TestCloneReplace:
utt.assert_allclose(
test(x, at.sum((x + 1) ** 2), mention_y=True), 1.21000003815
)
def test_clone_new_inputs():
"""Make sure that `Apply.clone_with_new_inputs` properly handles `Type` changes."""
x = at.tensor(np.float64, shape=(None,))
y = at.tensor(np.float64, shape=(1,))
z = at.add(x, y)
assert z.type.shape == (None,)
x_new = at.tensor(np.float64, shape=(1,))
# The output nodes should be reconstructed, because the input types' static
# shape information increased in specificity
z_node_new = z.owner.clone_with_new_inputs([x_new, y])
assert z_node_new.outputs[0].type.shape == (1,)
assert z_node_new.inputs[0].type.shape == (1,)
assert z_node_new.inputs[1].type.shape == (1,)
# Now, attempt to decrease the specificity of the first input's static
# shape information, but, because we're using strict conversion, we
# shouldn't lose any information
z = at.add(x_new, y)
assert z.type.shape == (1,)
z_node_new = z.owner.clone_with_new_inputs([x, y], strict=True)
assert z_node_new.outputs[0].type.shape == (1,)
assert z_node_new.inputs[0].type.shape == (1,)
assert z_node_new.inputs[1].type.shape == (1,)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论