提交 0f1f5beb 作者: Brandon T. Willard 提交者: Brandon T. Willard

Refactor local_canonicalize_alloc into local_alloc_sink_dimshuffle

上级 240827cf
......@@ -1809,19 +1809,9 @@ def local_useless_alloc(fgraph, node):
@register_specialize
@register_stabilize
@register_canonicalize
@local_optimizer([alloc])
def local_canonicalize_alloc(fgraph, node):
"""If the input type is the same as the output type (dtype and broadcast)
there is no change in the shape of the input. So this is just a simple copy
of the input. This is not needed. (as local_useless_alloc)
Also, it will canonicalize alloc by creating Dimshuffle after the
alloc to introduce the dimensions of constant size 1.
See https://github.com/Theano/Theano/issues/4072 to know why this
is needed.
"""
@local_optimizer([Alloc])
def local_alloc_sink_dimshuffle(fgraph, node):
r"""Convert broadcastable leading dimensions in an `Alloc` to `DimShuffle`\s."""
op = node.op
if not isinstance(op, Alloc):
return False
......@@ -1829,22 +1819,7 @@ def local_canonicalize_alloc(fgraph, node):
inp = node.inputs[0]
output = node.outputs[0]
# Check if dtype and broadcast remain the same.
if (
inp.type.dtype == output.type.dtype
and inp.type.broadcastable == output.type.broadcastable
):
# We don't need to copy over any stack traces here
return [inp]
# Allow local_merge_alloc to do its work first
clients = fgraph.clients[output]
for client, i in clients:
if client != "output" and isinstance(client.op, Alloc):
return
# Check if alloc adds a broadcastable dimension with shape 1.
output_shape = node.inputs[1:]
num_dims_with_size_1_added_to_left = 0
for i in range(len(output_shape) - inp.ndim):
......@@ -1852,6 +1827,7 @@ def local_canonicalize_alloc(fgraph, node):
num_dims_with_size_1_added_to_left += 1
else:
break
new_output_shape = output_shape[num_dims_with_size_1_added_to_left:]
if num_dims_with_size_1_added_to_left > 0 and len(new_output_shape) >= inp.ndim:
if (
......
......@@ -40,7 +40,7 @@ from aesara.tensor.basic_opt import (
ShapeFeature,
apply_rebroadcast_opt,
assert_op,
local_canonicalize_alloc,
local_alloc_sink_dimshuffle,
local_dimshuffle_lift,
local_merge_alloc,
local_reshape_to_dimshuffle,
......@@ -1423,8 +1423,7 @@ class TestLocalCanonicalizeAlloc:
# The optimization 'local_fill_to_alloc' should call at.alloc,
# which should return x and not alloc(x, ...)
mode = mode_opt.excluding("local_canonicalize_alloc")
f = function([x], [y], mode=mode)
f = function([x], [y], mode=mode_opt.including("local_fill_to_alloc"))
assert not any(
[isinstance(node.op, Alloc) for node in f.maker.fgraph.toposort()]
)
......@@ -1433,9 +1432,12 @@ class TestLocalCanonicalizeAlloc:
x = matrix("x")
y = at.tile(x, (1,) * 2)
mode = mode_opt.including("local_canonicalize_alloc")
mode = mode_opt.including(
"local_dimshuffle_lift",
"local_useless_dimshuffle_in_reshape",
"local_alloc_sink_dimshuffle",
)
f = function([x], [y], mode=mode)
[node.op.__class__ for node in f.maker.fgraph.toposort()]
assert not any(
[isinstance(node.op, Alloc) for node in f.maker.fgraph.toposort()]
......@@ -1454,7 +1456,7 @@ class TestLocalCanonicalizeAlloc:
g = FunctionGraph(outputs=[x])
assert any(isinstance(node.op, Alloc) for node in g.toposort())
alloc_lift = out2in(local_canonicalize_alloc)
alloc_lift = out2in(local_alloc_sink_dimshuffle)
alloc_lift.optimize(g)
if has_alloc:
......@@ -3217,7 +3219,7 @@ def test_local_Unique_Alloc_lift(
# The remaining exclusions simply allow us to perform the check below that
# makes sure the original `Alloc` is present in our reference (sub)graph.
opt_mode = default_mode.excluding(
"local_useless_alloc", "local_canonicalize_alloc", "local_Unique_Alloc_lift"
"local_useless_alloc", "local_alloc_sink_dimshuffle", "local_Unique_Alloc_lift"
)
y_fn = function([x], [y, y_opt], mode=opt_mode)
# Make sure that the original `Alloc` is used to compute the reference `y`
......
......@@ -1860,7 +1860,7 @@ class TestLocalElemwiseAlloc:
# Exclude local_useless_alloc, since it does not introduce
# assert in all the same cases.
self.fast_run_mode = self.fast_run_mode.excluding(
"local_useless_alloc", "local_canonicalize_alloc"
"local_useless_alloc", "local_alloc_sink_dimshuffle"
)
# No optimization on alloc
func = function(
......
Markdown 格式
0%
您即将添加 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
请 注册 或者 登录 后发表评论