提交 fe5865ef authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Remove assert in local_useless_alloc

Rewrite was already tagged as "shape_unsafe"
上级 c855a6d8
...@@ -67,9 +67,7 @@ from pytensor.tensor.basic import ( ...@@ -67,9 +67,7 @@ from pytensor.tensor.basic import (
from pytensor.tensor.elemwise import DimShuffle, Elemwise from pytensor.tensor.elemwise import DimShuffle, Elemwise
from pytensor.tensor.exceptions import NotScalarConstantError from pytensor.tensor.exceptions import NotScalarConstantError
from pytensor.tensor.extra_ops import broadcast_arrays from pytensor.tensor.extra_ops import broadcast_arrays
from pytensor.tensor.math import Sum, add from pytensor.tensor.math import Sum, add, eq
from pytensor.tensor.math import all as at_all
from pytensor.tensor.math import eq
from pytensor.tensor.shape import Shape_i, shape_padleft from pytensor.tensor.shape import Shape_i, shape_padleft
from pytensor.tensor.sort import TopKOp from pytensor.tensor.sort import TopKOp
from pytensor.tensor.type import DenseTensorType, TensorType from pytensor.tensor.type import DenseTensorType, TensorType
...@@ -266,6 +264,7 @@ def local_elemwise_alloc(fgraph, node): ...@@ -266,6 +264,7 @@ def local_elemwise_alloc(fgraph, node):
introduces them as a canonicalization of `Alloc`'s with leading introduces them as a canonicalization of `Alloc`'s with leading
broadcastable dimensions. broadcastable dimensions.
""" """
# This is handled by local_alloc_unary
if len(node.inputs) == 1: if len(node.inputs) == 1:
return None return None
...@@ -465,14 +464,7 @@ def local_useless_alloc(fgraph, node): ...@@ -465,14 +464,7 @@ def local_useless_alloc(fgraph, node):
inp.type.dtype == output.type.dtype inp.type.dtype == output.type.dtype
and inp.type.broadcastable == output.type.broadcastable and inp.type.broadcastable == output.type.broadcastable
): ):
if inp.ndim == 0: return [inp]
return [inp]
else:
return [
Assert("Shapes must be equal")(
inp, at_all(eq(inp.shape, node.inputs[1:]))
)
]
@register_specialize @register_specialize
......
...@@ -272,21 +272,36 @@ class TestLocalCanonicalizeAlloc: ...@@ -272,21 +272,36 @@ class TestLocalCanonicalizeAlloc:
def setup_method(self): def setup_method(self):
self.rng = np.random.default_rng(utt.fetch_seed()) self.rng = np.random.default_rng(utt.fetch_seed())
def test_inconsistent_shared(self): @pytest.mark.parametrize("shape_unsafe", (True, False))
def test_inconsistent_shared(self, shape_unsafe):
# These shapes don't match! # These shapes don't match!
x = shared(self.rng.standard_normal((3, 7))) x = shared(self.rng.standard_normal((3, 7)))
a = at.alloc(x, 6, 7) a = at.alloc(x, 6, 7)
assert a.owner and isinstance(a.owner.op, Alloc) assert a.owner and isinstance(a.owner.op, Alloc)
f = function([], a, mode=rewrite_mode) mode = rewrite_mode if shape_unsafe else rewrite_mode.excluding("shape_unsafe")
f = function([], a, mode=mode)
# The rewrite should then be applied, and remove Alloc has_alloc = any(
assert not any(isinstance(node.op, Alloc) for node in f.maker.fgraph.toposort()) isinstance(node.op, Alloc) for node in f.maker.fgraph.toposort()
assert any(isinstance(node.op, Assert) for node in f.maker.fgraph.toposort()) )
if shape_unsafe:
with pytest.raises(AssertionError): assert not has_alloc
f() # Error raised by SpecifyShape that is introduced due to static shape inference
with pytest.raises(
AssertionError,
match="SpecifyShape: dim 0 of input has shape 3, expected 6.",
):
f()
else:
assert has_alloc
# Error raised by Alloc Op
with pytest.raises(
ValueError,
match=r"could not broadcast input array from shape \(3,7\) into shape \(6,7\)",
):
f()
good_x_val = self.rng.standard_normal((6, 7)) good_x_val = self.rng.standard_normal((6, 7))
x.set_value(good_x_val) x.set_value(good_x_val)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论