提交 5f809cfe 作者: Ricardo Vieira 提交者: Ricardo Vieira

Simplify rewrites by assuming Elemwise / Alloc shapes are correct

上级 2c4a3e7b
......@@ -1013,7 +1013,7 @@ class TestLocalUselessSwitch:
z = at.switch(1, x, y)
f = function([x, y], z, mode=self.mode)
start_var = f.maker.fgraph.outputs[0].owner.inputs[0]
start_var = f.maker.fgraph.outputs[0]
assert isinstance(start_var.owner.op, Elemwise)
assert isinstance(start_var.owner.op.scalar_op, aes.basic.Cast)
assert not any(node.op == at.switch for node in f.maker.fgraph.toposort())
......@@ -1698,45 +1698,50 @@ class TestLocalElemwiseAlloc:
)
@pytest.mark.parametrize(
"expr, x_shape, y_shape",
"expr, x_shape, y_shape, needs_alloc",
[
(lambda x, y: at.mul(at.alloc(1, *y.shape), x), (1, 2), (3, 2)),
(lambda x, y: at.mul(at.alloc(1, *y.shape), x), (1, 1), (1, 1)),
(lambda x, y: at.mul(x, at.alloc(y, 2, 3)), (1, 3), (2, 3)),
(lambda x, y: at.mul(at.alloc(1, *y.shape), x), (1, 2), (3, 2), True),
(lambda x, y: at.mul(at.alloc(1, *y.shape), x), (1, 1), (1, 1), False),
(lambda x, y: at.mul(x, at.alloc(y, 2, 3)), (1, 3), (2, 3), False),
(
lambda x, y: at.mul(
at.alloc(x, 3).dimshuffle("x", 0), y.dimshuffle("x", "x")
),
(),
(),
True,
),
(lambda x, y: at.mul(y, at.alloc(1, x)), (), ()),
(lambda x, y: at.mul(at.alloc(x, 15, 1), y), (15, 1), (15, 1)),
(lambda x, y: at.mul(at.alloc(x, 15, 2), y), (15, 2), (15, 2)),
(lambda x, y: at.mul(y, at.alloc(1, x)), (), (), True),
(lambda x, y: at.mul(at.alloc(x, 15, 1), y), (15, 1), (15, 1), False),
(lambda x, y: at.mul(at.alloc(x, 15, 2), y), (15, 2), (15, 2), False),
(
lambda x, y: at.mul(at.alloc(x, 15, 1), at.alloc(y, 15, 1)),
(15, 1),
(15, 1),
False,
),
(
lambda x, y: at.mul(at.alloc(x, 15, 2), at.alloc(y, 15, 2)),
(15, 2),
(15, 2),
False,
),
(
lambda x, y: at.mul(at.alloc(x, 15, 2).dimshuffle(1, 0), y),
(15, 2),
(2, 15),
False,
),
(lambda x, y: at.mul(at.alloc(x, 1, 15, 2), y), (15, 2), (15, 2)),
(lambda x, y: at.mul(at.alloc(x, 1, 15, 2), y), (15, 2), (15, 2), False),
(
lambda x, y: at.mul(at.alloc(x, 1, 15, 2).dimshuffle(0, 2, 1), y),
(15, 2),
(2, 15),
False,
),
],
)
def test_basic(self, expr, x_shape, y_shape):
def test_basic(self, expr, x_shape, y_shape, needs_alloc):
x = at.tensor(
dtype="int64", shape=(1 if val == 1 else None for val in x_shape), name="x"
)
......@@ -1752,10 +1757,16 @@ class TestLocalElemwiseAlloc:
on_unused_input="ignore",
)
assert not any(
isinstance(node.op, Alloc) for node in z_opt.maker.fgraph.toposort()
)
nodes = z_opt.maker.fgraph.toposort()
if needs_alloc:
# When the final result needs an Alloc, this should be the last node
# x = scalar; y = vector; mul(x, ones_like(y)) -> alloc(x, y.shape)
assert isinstance(nodes[-1].op, Alloc)
nodes = nodes[:-1]
assert not any(isinstance(node.op, Alloc) for node in nodes)
# Check results are the same without the optimization
z_no_opt = pytensor.function(
[x, y],
z,
......@@ -1799,7 +1810,7 @@ class TestLocalElemwiseAlloc:
[self.vec, self.mat], self.alloc_wo_dep + self.mat, mode=self.fast_run_mode
)
self.verify_op_count(func, 0, Alloc)
self.verify_op_count(func, 2, Assert)
self.verify_op_count(func, 1, SpecifyShape)
func = function(
[self.vec, self.mat],
......@@ -1807,7 +1818,7 @@ class TestLocalElemwiseAlloc:
mode=self.fast_run_mode,
)
self.verify_op_count(func, 0, Alloc)
self.verify_op_count(func, 1, Assert)
self.verify_op_count(func, 1, SpecifyShape)
# No optimization on alloc without assert
func = function(
......@@ -1839,7 +1850,10 @@ class TestLocalElemwiseAlloc:
self.alloc_w_dep_broad2 + self.mat,
mode=self.fast_run_mode,
)
self.verify_op_count(func, 0, Alloc)
# This graph requires one outer Alloc and one Assert: the Assert makes
# sure `mat` is square, since we end up doing
# broadcast_to(x, mat[..., None].shape) + mat[None, ...]
self.verify_op_count(func, 1, Alloc)
self.verify_op_count(func, 1, Assert)
def test_remove_alloc_w_dimshuffle(self):
......@@ -1851,16 +1865,13 @@ class TestLocalElemwiseAlloc:
self.verify_op_count(func, 1, Alloc)
self.verify_op_count(func, 0, Assert)
# TODO FIXME: The `BroadcastTo` shapes should use the constants
# provided by the first/`Alloc` term, and not the unknown values from
# the `tens` term.
func = function(
[self.vec, self.tens],
self.alloc_wo_dep.dimshuffle(0, 1, "x") + self.tens,
mode=self.fast_run_mode,
)
self.verify_op_count(func, 0, Alloc)
self.verify_op_count(func, 2, Assert)
self.verify_op_count(func, 1, SpecifyShape)
func = function(
[self.vec, self.tens],
......@@ -1888,16 +1899,13 @@ class TestLocalElemwiseAlloc:
self.verify_op_count(func, 2, Alloc)
self.verify_op_count(func, 0, Assert)
# Optimization on dimshuffle with assert
# TODO: When we support static shape constraints like `shape[i] != 1`,
# reproduce this with such a constraint on `mat` and make sure the
# `BroadcastTo` is removed.
func = function(
[self.vec, self.mat],
self.tv_wo_dep + self.tm_wo_dep,
mode=self.fast_run_mode,
)
self.verify_op_count(func, 0, Alloc)
# It still needs an outer alloc to broadcast final shape
self.verify_op_count(func, 1, Alloc)
self.verify_op_count(func, 0, Assert)
# No optimization on dimshuffle without assert
......@@ -1909,25 +1917,24 @@ class TestLocalElemwiseAlloc:
self.verify_op_count(func, 2, Alloc)
self.verify_op_count(func, 0, Assert)
# Optimization on dimshuffle without assert
func = function(
[self.vec, self.mat, self.s],
self.tv_w_dep + self.tm_w_dep,
mode=self.fast_run_mode,
)
self.verify_op_count(func, 0, Alloc)
# The second assert is from the shape check...
self.verify_op_count(func, 2, Assert)
# It still needs an outer alloc to broadcast final shape
self.verify_op_count(func, 1, Alloc)
self.verify_op_count(func, 0, Assert)
def test_misc(self):
x = row(dtype=self.dtype)
y = tensor(dtype=self.dtype, shape=(None, None, 1))
x = row("x", dtype=self.dtype)
y = tensor("y", dtype=self.dtype, shape=(None, None, 1))
out = at.alloc(x, 5, 5).dimshuffle(0, 1, "x") + y
func = function([y, x], out, mode=self.fast_run_mode)
self.verify_op_count(func, 0, Alloc)
self.verify_op_count(func, 2, Assert)
self.verify_op_count(func, 1, SpecifyShape)
y_val = np.random.random((5, 5, 1)).astype(self.dtype)
x_val = np.random.random((1, 5)).astype(self.dtype)
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论