提交 e1809275 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Fix bug where ShapeFeature would create circular shape graph

上级 f951743d
...@@ -423,7 +423,17 @@ class ShapeFeature(Feature): ...@@ -423,7 +423,17 @@ class ShapeFeature(Feature):
# This mean the shape is equivalent # This mean the shape is equivalent
# We do not want to do the ancestor check in those cases # We do not want to do the ancestor check in those cases
merged_shape.append(r_shape[i]) merged_shape.append(r_shape[i])
elif r_shape[i] in ancestors([other_shape[i]]): elif any(
(
r_shape[i] == anc
or (
anc.owner
and isinstance(anc.owner.op, Shape)
and anc.owner.inputs[0] == r
)
)
for anc in ancestors([other_shape[i]])
):
# Another case where we want to use r_shape[i] is when # Another case where we want to use r_shape[i] is when
# other_shape[i] actually depends on r_shape[i]. In that case, # other_shape[i] actually depends on r_shape[i]. In that case,
# we do not want to substitute an expression with another that # we do not want to substitute an expression with another that
......
...@@ -15,7 +15,7 @@ from pytensor.graph.op import Op ...@@ -15,7 +15,7 @@ from pytensor.graph.op import Op
from pytensor.graph.rewriting.basic import check_stack_trace, node_rewriter, out2in from pytensor.graph.rewriting.basic import check_stack_trace, node_rewriter, out2in
from pytensor.graph.rewriting.utils import rewrite_graph from pytensor.graph.rewriting.utils import rewrite_graph
from pytensor.graph.type import Type from pytensor.graph.type import Type
from pytensor.tensor.basic import as_tensor_variable from pytensor.tensor.basic import alloc, as_tensor_variable
from pytensor.tensor.elemwise import DimShuffle, Elemwise from pytensor.tensor.elemwise import DimShuffle, Elemwise
from pytensor.tensor.math import add, exp, maximum from pytensor.tensor.math import add, exp, maximum
from pytensor.tensor.rewriting.basic import register_specialize from pytensor.tensor.rewriting.basic import register_specialize
...@@ -239,6 +239,25 @@ class TestShapeRewriter: ...@@ -239,6 +239,25 @@ class TestShapeRewriter:
# FIXME: This is not a good test. # FIXME: This is not a good test.
f([[1, 2], [2, 3]]) f([[1, 2], [2, 3]])
def test_shape_of_useless_alloc(self):
"""Test that local_shape_to_shape_i does not create circular graph.
Regression test for #565
"""
alpha = vector(shape=(None,), dtype="float64")
channel = vector(shape=(None,), dtype="float64")
broadcast_channel = alloc(
channel,
maximum(
shape(alpha)[0],
shape(channel)[0],
),
)
out = shape(broadcast_channel)
fn = function([alpha, channel], out)
assert fn([1.0, 2, 3], [1.0, 2, 3]) == (3,)
class TestReshape: class TestReshape:
def setup_method(self): def setup_method(self):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论