Commit 7c1558ad authored by Ricardo, committed by Brandon T. Willard

Refactor and fix local_elemwise_alloc

The rewrite would sometimes return a new graph identical to the original, resulting in divergence.
Parent 0f1f5beb
...@@ -1486,28 +1486,34 @@ aesara.compile.mode.optdb.register("UnShapeOpt", UnShapeOptimizer(), 10) ...@@ -1486,28 +1486,34 @@ aesara.compile.mode.optdb.register("UnShapeOpt", UnShapeOptimizer(), 10)
@register_specialize("local_alloc_elemwise") @register_specialize("local_alloc_elemwise")
@local_optimizer([Elemwise]) @local_optimizer([Elemwise])
def local_elemwise_alloc(fgraph, node): def local_elemwise_alloc(fgraph, node):
""" r"""Remove unnecessary `Alloc`\s that occur as inputs of `Elemwise` `Op`\s.
elemwise(alloc(x, shp), ..., y.TensorType(BROADCAST CONDITION))
-> elemwise(x, y.TensorType(BROADCAST CONDITION))
elemwise(dimshuffle(alloc(x, shp)),... ,y.TensorType(BROADCAST CONDITION))
-> elemwise(x.dimshuffle(...), y.TensorType(BROADCAST CONDITION))
BROADCAST CONDITION: the condition is that the one input that are `Alloc`\s are effectively a type of `Elemwise` operation
not to be optimized to have the same broadcast pattern as the (e.g. ``Elemwise{second}(y, x)`` is the same as ``Alloc(x, *y.shape)``), so
output. this rewrite uses that fact to reduce `Elemwise`\s on `Alloc`\s to
`Elemwise`\s of the `Alloc`\s first/value input (i.e. the value it
broadcasts).
We can change the `Alloc` by a `DimShuffle` as the `Elemwise` already have In other words, this rewrite causes `Elemwise` `Op`\s to "absorb" redundant
the shape info. The `DimShuffle` will be faster to exec. `Alloc`\s.
TODO: Global optimizer that lifts the assert to the beginning of the graph? The rewrite essentially performs the following replacement:
TODO: Optimize all inputs when possible -- currently when all inputs have ``Elemwise{op}(..., Alloc(x, s), ..., y, ...) -> Elemwise{op}(..., x, ..., y, ...)``,
an `Alloc` all but one is optimized. when ``y.shape`` for some input ``y`` (or the combined shapes of the
non-`Alloc`\s) is sufficient to maintain the same/correct output shape.
In it's current form, it also explicitly accounts for `DimShuffle`\s of
`Alloc`\s. This is largely due to `local_alloc_sink_dimshuffle`, which
introduces them as a canonicalization of `Alloc`'s with leading
broadcastable dimensions.
""" """
if not isinstance(node.op, Elemwise): if not isinstance(node.op, Elemwise):
return False return False
# Rewrite is only applicable when there are at least two inputs
if len(node.inputs) == 1:
return None
if len(node.outputs) > 1: if len(node.outputs) > 1:
# Ensure all outputs have the same broadcast pattern # Ensure all outputs have the same broadcast pattern
# This is a supposition that I'm not sure is always true. # This is a supposition that I'm not sure is always true.
...@@ -1546,8 +1552,9 @@ def local_elemwise_alloc(fgraph, node): ...@@ -1546,8 +1552,9 @@ def local_elemwise_alloc(fgraph, node):
): ):
return False return False
# Search for input that we can use as a baseline for the dimensions. # Search for a non `Alloc` or `DimShuffle` of `Alloc` input that we can use as a
assert_op_idx = -1 # baseline for the dimensions.
assert_op_idx = None
for idx, i in enumerate(node.inputs): for idx, i in enumerate(node.inputs):
if i.type.broadcastable == node.outputs[0].type.broadcastable: if i.type.broadcastable == node.outputs[0].type.broadcastable:
# Prefer an input that is not a `Alloc` nor a `DimShuffle` of a # Prefer an input that is not a `Alloc` nor a `DimShuffle` of a
...@@ -1558,31 +1565,14 @@ def local_elemwise_alloc(fgraph, node): ...@@ -1558,31 +1565,14 @@ def local_elemwise_alloc(fgraph, node):
assert_op_idx = idx assert_op_idx = idx
break break
# It may be the case that only `Alloc` and `DimShuffle` of `Alloc` exist. # If only `Alloc` and `DimShuffle` of `Alloc` exist, we pick the first suitable one
if assert_op_idx < 0: if assert_op_idx is None:
# We want to optimize as many `Alloc`s as possible. When for idx, i in enumerate(node.inputs):
# there is more than one then do all but one. number of if (i.type.broadcastable == node.outputs[0].type.broadcastable) and (
# inputs with `Alloc` or `DimShuffle` `Alloc` i.owner and (isinstance(i.owner.op, Alloc) or dimshuffled_alloc(i))
l2 = [ ):
i assert_op_idx = idx
for i in node.inputs break
if (i.owner and (isinstance(i.owner.op, Alloc) or dimshuffled_alloc(i)))
]
# If only one `Alloc` or `DimShuffle` `Alloc`, it is the one we
# will use for the shape. So no `Alloc` would be removed.
if len(l2) > 1:
# One contains inputs with `Alloc` or `DimShuffle` `Alloc`
# only. Its length will always be at least one, as we
# checked that before
l = [
idx
for idx, i in enumerate(node.inputs)
if i.broadcastable == node.outputs[0].broadcastable
]
assert_op_idx = l[0] # The first one is as good as any to use.
else:
# Nothing would be optimized!
return False
assert_op_in = node.inputs[assert_op_idx] assert_op_in = node.inputs[assert_op_idx]
cmp_op = assert_op_in cmp_op = assert_op_in
...@@ -1590,13 +1580,7 @@ def local_elemwise_alloc(fgraph, node): ...@@ -1590,13 +1580,7 @@ def local_elemwise_alloc(fgraph, node):
same_shape = fgraph.shape_feature.same_shape same_shape = fgraph.shape_feature.same_shape
for i in node.inputs: for i in node.inputs:
# Remove `Alloc` # Remove `Alloc`
if ( if i.owner and isinstance(i.owner.op, Alloc):
i.owner
and isinstance(i.owner.op, Alloc)
and not i.owner.inputs[0].type.is_super(i.owner.outputs[0].type)
):
# when `i.owner.inputs[0].type.is_super(i.owner.outputs[0].type)` we
# will remove that `Alloc` later
assert i.type.ndim == cmp_op.ndim assert i.type.ndim == cmp_op.ndim
if config.experimental__local_alloc_elemwise_assert: if config.experimental__local_alloc_elemwise_assert:
get_shape = fgraph.shape_feature.get_shape get_shape = fgraph.shape_feature.get_shape
...@@ -1610,7 +1594,16 @@ def local_elemwise_alloc(fgraph, node): ...@@ -1610,7 +1594,16 @@ def local_elemwise_alloc(fgraph, node):
cond.append(eq(i_shp, cmp_shp)) cond.append(eq(i_shp, cmp_shp))
if cond: if cond:
assert_op_in = assert_op(assert_op_in, *cond) assert_op_in = assert_op(assert_op_in, *cond)
new_i.append(i.owner.inputs[0]) alloc_input = i.owner.inputs[0]
if alloc_input.ndim != i.ndim:
# The `Alloc` can add dimensions to the value.
# We replace those cases with a `DimShuffle` here.
nb_dim_to_add = i.ndim - alloc_input.ndim
alloc_input = alloc_input.dimshuffle(
["x"] * nb_dim_to_add + list(range(alloc_input.ndim))
)
copy_stack_trace(i, alloc_input)
new_i.append(alloc_input)
# Remove `Alloc` in `DimShuffle` # Remove `Alloc` in `DimShuffle`
elif i.owner and dimshuffled_alloc(i): elif i.owner and dimshuffled_alloc(i):
...@@ -1626,28 +1619,30 @@ def local_elemwise_alloc(fgraph, node): ...@@ -1626,28 +1619,30 @@ def local_elemwise_alloc(fgraph, node):
assert_op_in = assert_op(assert_op_in, *assert_cond) assert_op_in = assert_op(assert_op_in, *assert_cond)
alloc_input = i.owner.inputs[0].owner.inputs[0] alloc_input = i.owner.inputs[0].owner.inputs[0]
if alloc_input.ndim != i.owner.inputs[0].ndim: if alloc_input.ndim != i.owner.inputs[0].ndim:
# The `Alloc` can add dimension to the value # The `Alloc` can add dimensions to the value.
# We add a `DimShuffle` to add them. # We replace those cases with a `DimShuffle` here.
# We let later optimization merge the multiple `DimShuffle` # We let later optimizations merge the nested `DimShuffle`s
nb_dim_to_add = i.owner.inputs[0].ndim - alloc_input.ndim nb_dim_to_add = i.owner.inputs[0].ndim - alloc_input.ndim
alloc_input = alloc_input.dimshuffle( alloc_input = alloc_input.dimshuffle(
["x"] * nb_dim_to_add + list(range(alloc_input.ndim)) ["x"] * nb_dim_to_add + list(range(alloc_input.ndim))
) )
# We need to keep the `DimShuffle`. It could swap axes or # We need to keep the old `DimShuffle`. It could swap axes or
# add dimensions anywhere. # add dimensions anywhere.
r_i = i.owner.op(alloc_input) r_i = i.owner.op(alloc_input)
# Copy stack trace from i to new_i
copy_stack_trace(i, r_i) copy_stack_trace(i, r_i)
new_i.append(r_i) new_i.append(r_i)
else: else:
new_i.append(i) new_i.append(i)
new_i[assert_op_idx] = assert_op_in new_i[assert_op_idx] = assert_op_in
ret = node.op(*new_i, return_list=True) # If this assert is triggered, it means we are recreating an equivalent graph
# which would result in a cyclical merge optimization.
if all(new is old for new, old in zip(new_i, node.inputs)):
return
# Copy over stack trace from previous outputs to new outputs. ret = node.op(*new_i, return_list=True)
copy_stack_trace(node.outputs, ret) copy_stack_trace(node.outputs, ret)
return ret return ret
......
...@@ -3507,3 +3507,65 @@ def test_Shape_i_canonicalize(): ...@@ -3507,3 +3507,65 @@ def test_Shape_i_canonicalize():
assert isinstance(y_opt.owner.op, Shape_i) assert isinstance(y_opt.owner.op, Shape_i)
assert y_opt.owner.op.i == 0 assert y_opt.owner.op.i == 0
assert y_opt.owner.inputs[0] == x assert y_opt.owner.inputs[0] == x
@pytest.mark.parametrize(
    "expr, x_shape, y_shape",
    [
        pytest.param(
            lambda x, y: at.mul(y, at.alloc(1, x)),
            (),
            (),
            marks=pytest.mark.xfail(reason="Not implemented"),
        ),
        (lambda x, y: at.mul(at.alloc(x, 15, 1), y), (15, 1), (15, 1)),
        (lambda x, y: at.mul(at.alloc(x, 15, 2), y), (15, 2), (15, 2)),
        (lambda x, y: at.mul(at.alloc(x, 15, 1), at.alloc(y, 15, 1)), (15, 1), (15, 1)),
        (lambda x, y: at.mul(at.alloc(x, 15, 2), at.alloc(y, 15, 2)), (15, 2), (15, 2)),
        (lambda x, y: at.mul(at.alloc(x, 15, 2).dimshuffle(1, 0), y), (15, 2), (2, 15)),
        (lambda x, y: at.mul(at.alloc(x, 1, 15, 2), y), (15, 2), (15, 2)),
        (
            lambda x, y: at.mul(at.alloc(x, 1, 15, 2).dimshuffle(0, 2, 1), y),
            (15, 2),
            (2, 15),
        ),
    ],
)
def test_local_elemwise_alloc(expr, x_shape, y_shape):
    """Check that `local_elemwise_alloc` removes `Alloc`\\s and preserves results.

    The rewrite should strip every `Alloc` (and `DimShuffle`-of-`Alloc`) from
    the compiled graph, and the optimized function must compute the same
    values as a function compiled with the rewrite disabled.
    """
    x = at.tensor("int64", (False,) * len(x_shape))
    y = at.tensor("int64", (False,) * len(y_shape))
    out = expr(x, y)

    # Compile twice: once with the rewrite enabled, once with it excluded,
    # so the un-rewritten graph serves as the reference implementation.
    mode_with = get_default_mode().including("local_elemwise_alloc")
    mode_without = get_default_mode().excluding("local_elemwise_alloc")

    fn_opt = aesara.function(
        [x, y],
        out,
        mode=mode_with,
        on_unused_input="ignore",
    )
    # No `Alloc` node may survive in the rewritten graph.
    opt_nodes = fn_opt.maker.fgraph.toposort()
    assert not any(isinstance(apply_node.op, Alloc) for apply_node in opt_nodes)

    fn_ref = aesara.function(
        [x, y],
        out,
        mode=mode_without,
        on_unused_input="ignore",
    )

    # Deterministic integer test data shaped to match the symbolic inputs.
    x_val = np.arange(np.prod(x_shape), dtype=np.int64).reshape(x_shape)
    y_val = np.arange(np.prod(y_shape), dtype=np.int64).reshape(y_shape)

    assert np.array_equal(fn_opt(x_val, y_val), fn_ref(x_val, y_val))
def test_local_elemwise_alloc_single_input():
    """The rewrite must not fire on an `Elemwise` with a single `Alloc` input.

    With only one input there is no other input whose shape could justify
    dropping the `Alloc`, so the `Alloc` has to remain in the rewritten graph.
    """
    x = at.matrix("x")
    out = at.exp(at.alloc(x, 15, 1))

    # Build a `FunctionGraph` with shape tracking and apply only the
    # rewrite under test.
    fgraph = FunctionGraph(outputs=[out], copy_inputs=False, features=[ShapeFeature()])
    rewritten_fgraph = optimize_graph(
        fgraph, clone=False, include=["local_elemwise_alloc"]
    )

    # The `Alloc` should have been left untouched.
    alloc_count = sum(
        isinstance(apply_node.op, Alloc) for apply_node in rewritten_fgraph.apply_nodes
    )
    assert alloc_count > 0
Markdown formatting supported
0%
You are adding 0 people to this discussion. Please proceed with caution.
Finish editing this comment first!
Register or sign in to comment