提交 1798404b authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Add shape assertion to local_useless_alloc

上级 dc7cd4c9
...@@ -74,6 +74,7 @@ from aesara.tensor.basic import ( ...@@ -74,6 +74,7 @@ from aesara.tensor.basic import (
from aesara.tensor.elemwise import DimShuffle, Elemwise from aesara.tensor.elemwise import DimShuffle, Elemwise
from aesara.tensor.exceptions import NotScalarConstantError, ShapeError from aesara.tensor.exceptions import NotScalarConstantError, ShapeError
from aesara.tensor.extra_ops import broadcast_shape from aesara.tensor.extra_ops import broadcast_shape
from aesara.tensor.math import all as at_all
from aesara.tensor.math import eq from aesara.tensor.math import eq
from aesara.tensor.shape import Reshape, Shape, Shape_i, shape, shape_padleft from aesara.tensor.shape import Reshape, Shape, Shape_i, shape, shape_padleft
from aesara.tensor.sort import TopKOp from aesara.tensor.sort import TopKOp
...@@ -1782,25 +1783,28 @@ def local_useless_fill(fgraph, node): ...@@ -1782,25 +1783,28 @@ def local_useless_fill(fgraph, node):
@register_stabilize
@register_canonicalize
@register_useless
@local_optimizer([Alloc])
def local_useless_alloc(fgraph, node):
    """Remove `Alloc`s that do not change the type of their input.

    If the input type is the same as the output type (dtype and broadcast
    pattern), the `Alloc` is just a copy of its input and can be removed.
    A 0-d input is returned directly; for non-scalar inputs an `Assert` is
    inserted so that a runtime mismatch between the input's actual shape and
    the shape requested by the `Alloc` is still detected after the `Alloc`
    node has been removed.
    """
    if not isinstance(node.op, Alloc):
        return False

    input = node.inputs[0]
    output = node.outputs[0]

    # Check if dtype and broadcast pattern remain the same.
    if input.type == output.type:
        if input.ndim == 0:
            # A scalar cannot have a shape mismatch; drop the Alloc outright.
            return [input]
        else:
            # node.inputs[1:] are the requested shape entries; keep a runtime
            # check that they equal the input's actual shape.
            return [
                Assert("Shapes must be equal")(
                    input, at_all(eq(input.shape, node.inputs[1:]))
                )
            ]
@register_specialize @register_specialize
......
...@@ -1434,62 +1434,77 @@ class TestLocalCanonicalizeAlloc: ...@@ -1434,62 +1434,77 @@ class TestLocalCanonicalizeAlloc:
def setup_method(self):
    # Seeded RNG so every test in this class is reproducible across runs.
    self.rng = np.random.default_rng(utt.fetch_seed())
def test_inconsistent_constant(self):
    """A constant whose static shape conflicts with the requested `Alloc`
    shape must be rejected at compile time by the new `Assert`."""
    x = aet.as_tensor(self.rng.standard_normal((3, 7)))
    a = aet.alloc(x, 6, 7)

    assert a.owner and isinstance(a.owner.op, Alloc)

    # (3, 7) vs. (6, 7) is a static mismatch, so compilation itself fails.
    # with aesara.config.change_flags(optimizer_verbose=True):
    with pytest.raises(AssertionError):
        f = function([], a, mode=mode_opt)

    x = aet.as_tensor(self.rng.standard_normal((6, 7)))
    a = aet.alloc(x, 6, 7)

    f = function([], a, mode=mode_opt)

    # The optimization should then be applied, and remove Alloc
    assert not any(
        [isinstance(node.op, (Alloc, Assert)) for node in f.maker.fgraph.toposort()]
    )
def test_inconsistent_shared(self):
    # These shapes don't match!
    x = shared(self.rng.standard_normal((3, 7)))
    a = aet.alloc(x, 6, 7)

    assert a.owner and isinstance(a.owner.op, Alloc)

    # A shared variable's shape is only known at runtime, so the rewrite
    # must replace the `Alloc` with an `Assert` rather than removing it
    # unconditionally.
    f = function([], a, mode=mode_opt)

    # The optimization should then be applied, and remove Alloc
    assert not any(
        [isinstance(node.op, Alloc) for node in f.maker.fgraph.toposort()]
    )
    assert any([isinstance(node.op, Assert) for node in f.maker.fgraph.toposort()])

    # The mismatched value is only caught when the function is run.
    with pytest.raises(AssertionError):
        f()

    good_x_val = self.rng.standard_normal((6, 7))
    x.set_value(good_x_val)

    # With a matching shape the Assert passes and the value flows through.
    assert np.array_equal(f(), good_x_val)
def test_basic_fill(self):
    x = matrix("x")
    y = aet.fill(x, x)

    # The optimization 'locall_fill_to_alloc' should call aet.alloc,
    # which should return x and not alloc(x, ...)
    mode = mode_opt.excluding("local_canonicalize_alloc")
    f = function([x], [y], mode=mode)
    assert not any(
        [isinstance(node.op, Alloc) for node in f.maker.fgraph.toposort()]
    )
def test_basic_tile(self):
    x = matrix("x")
    y = aet.tile(x, (1,) * 2)

    mode = mode_opt.including("local_canonicalize_alloc")
    f = function([x], [y], mode=mode)

    # Tiling by (1, 1) is a no-op, so the rewrite should leave no `Alloc`
    # in the compiled graph.
    [node.op.__class__ for node in f.maker.fgraph.toposort()]
    assert not any(
        [isinstance(node.op, Alloc) for node in f.maker.fgraph.toposort()]
    )
def test_useless_alloc_with_shape_one(self): def test_useless_alloc_with_shape_one(self):
"""
TODO FIXME: Remove/replace the string output comparisons.
"""
alloc_lift = out2in(local_canonicalize_alloc) alloc_lift = out2in(local_canonicalize_alloc)
x = shared(self.rng.standard_normal((2,))) x = shared(self.rng.standard_normal((2,)))
y = shared(self.rng.standard_normal()) y = shared(self.rng.standard_normal())
...@@ -2406,8 +2421,9 @@ class TestLocalUselessSwitch: ...@@ -2406,8 +2421,9 @@ class TestLocalUselessSwitch:
z = aet.switch(1, x, y) z = aet.switch(1, x, y)
f = function([x, y], z, mode=self.mode) f = function([x, y], z, mode=self.mode)
assert isinstance(f.maker.fgraph.outputs[0].owner.op, Elemwise) start_var = f.maker.fgraph.outputs[0].owner.inputs[0]
assert isinstance(f.maker.fgraph.outputs[0].owner.op.scalar_op, aes.basic.Cast) assert isinstance(start_var.owner.op, Elemwise)
assert isinstance(start_var.owner.op.scalar_op, aes.basic.Cast)
assert not any(node.op == aet.switch for node in f.maker.fgraph.toposort()) assert not any(node.op == aet.switch for node in f.maker.fgraph.toposort())
vx = np.array([[1, 2, 3], [4, 5, 6]], dtype="int32") vx = np.array([[1, 2, 3], [4, 5, 6]], dtype="int32")
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论