提交 c86a717b authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Simplify construction of aesara.tensor.basic_opt.local_elemwise_alloc

上级 91d149ab
......@@ -1470,8 +1470,9 @@ aesara.compile.mode.optdb.register(
aesara.compile.mode.optdb.register("UnShapeOpt", UnShapeOptimizer(), 10)
def local_elemwise_alloc_op(ElemwiseOP, AllocOP, DimShuffleOP):
def local_elemwise_alloc(fgraph, node):
@register_specialize("local_alloc_elemwise")
@local_optimizer([Elemwise])
def local_elemwise_alloc(fgraph, node):
"""
elemwise(alloc(x, shp), ..., y.TensorType(BROADCAST CONDITION))
-> elemwise(x, y.TensorType(BROADCAST CONDITION))
......@@ -1483,12 +1484,15 @@ def local_elemwise_alloc_op(ElemwiseOP, AllocOP, DimShuffleOP):
not to be optimized to have the same broadcast pattern as the
output.
We can change the alloc by a dimshuffle as the elemwise
already have the shape info. The dimshuffle will be faster
to exec.
We can change the `Alloc` by a `DimShuffle` as the `Elemwise` already have
the shape info. The `DimShuffle` will be faster to exec.
TODO: Global optimizer that lifts the assert to the beginning of the graph?
TODO: Optimize all inputs when possible -- currently when all inputs have
an `Alloc` all but one is optimized.
"""
if not isinstance(node.op, ElemwiseOP):
if not isinstance(node.op, Elemwise):
return False
if len(node.outputs) > 1:
......@@ -1513,17 +1517,17 @@ def local_elemwise_alloc_op(ElemwiseOP, AllocOP, DimShuffleOP):
def dimshuffled_alloc(i):
return (
isinstance(i.owner.op, DimShuffleOP)
isinstance(i.owner.op, DimShuffle)
and i.owner.inputs[0].owner
and isinstance(i.owner.inputs[0].owner.op, AllocOP)
and isinstance(i.owner.inputs[0].owner.op, Alloc)
)
# At least one input must have an owner that is either a AllocOP or a
# DimShuffleOP with an owner that is a AllocOP -- otherwise there is
# At least one input must have an owner that is either a `Alloc` or a
# `DimShuffle` with an owner that is a `Alloc` -- otherwise there is
# nothing to optimize.
if not any(
[
i.owner and (isinstance(i.owner.op, AllocOP) or dimshuffled_alloc(i))
i.owner and (isinstance(i.owner.op, Alloc) or dimshuffled_alloc(i))
for i in node.inputs
]
):
......@@ -1533,32 +1537,28 @@ def local_elemwise_alloc_op(ElemwiseOP, AllocOP, DimShuffleOP):
assert_op_idx = -1
for idx, i in enumerate(node.inputs):
if i.type.broadcastable == node.outputs[0].type.broadcastable:
# Prefer an input that is not a AllocOP nor a DimShuffleOP of a
# AllocOP so that all allocs can be optimized.
# Prefer an input that is not a `Alloc` nor a `DimShuffle` of a
# `Alloc` so that all `Alloc`s can be optimized.
if not (
i.owner
and (isinstance(i.owner.op, AllocOP) or dimshuffled_alloc(i))
i.owner and (isinstance(i.owner.op, Alloc) or dimshuffled_alloc(i))
):
assert_op_idx = idx
break
# It may be the case that only AllocOP and DimShuffleOP of AllocOP exist.
# It may be the case that only `Alloc` and `DimShuffle` of `Alloc` exist.
if assert_op_idx < 0:
# We want to optimize as many allocs as possible. When
# We want to optimize as many `Alloc`s as possible. When
# there is more than one then do all but one. number of
# inputs with alloc or dimshuffle alloc
# inputs with `Alloc` or `DimShuffle` `Alloc`
l2 = [
i
for i in node.inputs
if (
i.owner
and (isinstance(i.owner.op, AllocOP) or dimshuffled_alloc(i))
)
if (i.owner and (isinstance(i.owner.op, Alloc) or dimshuffled_alloc(i)))
]
# If only 1 alloc or dimshuffle alloc, it is the one we
# will use for the shape. So no alloc would be removed.
# If only one `Alloc` or `DimShuffle` `Alloc`, it is the one we
# will use for the shape. So no `Alloc` would be removed.
if len(l2) > 1:
# l contains inputs with alloc or dimshuffle alloc
# One contains inputs with `Alloc` or `DimShuffle` `Alloc`
# only. Its length will always be at least one, as we
# checked that before
l = [
......@@ -1576,14 +1576,14 @@ def local_elemwise_alloc_op(ElemwiseOP, AllocOP, DimShuffleOP):
new_i = []
same_shape = fgraph.shape_feature.same_shape
for i in node.inputs:
# Remove alloc
# Remove `Alloc`
if (
i.owner
and isinstance(i.owner.op, AllocOP)
and isinstance(i.owner.op, Alloc)
and i.owner.inputs[0].type != i.owner.outputs[0].type
):
# when i.owner.inputs[0].type == i.owner.outputs[0].type we
# will remove that alloc later
# when `i.owner.inputs[0].type == i.owner.outputs[0].type` we
# will remove that `Alloc` later
assert i.type.ndim == cmp_op.ndim
if config.experimental__local_alloc_elemwise_assert:
get_shape = fgraph.shape_feature.get_shape
......@@ -1599,7 +1599,7 @@ def local_elemwise_alloc_op(ElemwiseOP, AllocOP, DimShuffleOP):
assert_op_in = assert_op(assert_op_in, *cond)
new_i.append(i.owner.inputs[0])
# Remove Alloc in DimShuffle
# Remove `Alloc` in `DimShuffle`
elif i.owner and dimshuffled_alloc(i):
assert i.type.ndim == cmp_op.type.ndim
if config.experimental__local_alloc_elemwise_assert:
......@@ -1613,15 +1613,15 @@ def local_elemwise_alloc_op(ElemwiseOP, AllocOP, DimShuffleOP):
assert_op_in = assert_op(assert_op_in, *assert_cond)
alloc_input = i.owner.inputs[0].owner.inputs[0]
if alloc_input.ndim != i.owner.inputs[0].ndim:
# The alloc can add dimension to the value
# We add a dimshuffle to add them.
# We let later optimization merge the multiple dimshuffle
# The `Alloc` can add dimension to the value
# We add a `DimShuffle` to add them.
# We let later optimization merge the multiple `DimShuffle`
nb_dim_to_add = i.owner.inputs[0].ndim - alloc_input.ndim
alloc_input = alloc_input.dimshuffle(
["x"] * nb_dim_to_add + list(range(alloc_input.ndim))
)
# We need to keep the dimshuffle. It could swap axes or
# We need to keep the `DimShuffle`. It could swap axes or
# add dimensions anywhere.
r_i = i.owner.op(alloc_input)
......@@ -1638,18 +1638,6 @@ def local_elemwise_alloc_op(ElemwiseOP, AllocOP, DimShuffleOP):
copy_stack_trace(node.outputs, ret)
return ret
return local_elemwise_alloc
# TODO, global optimizer that lift the assert to the beginning of the graph.
# TODO, optimize all inputs when possible -- currently when all inputs have
# an alloc all but one is optimized.
local_elemwise_alloc = register_specialize(
local_optimizer([Elemwise])(local_elemwise_alloc_op(Elemwise, Alloc, DimShuffle)),
"local_alloc_elemwise",
)
@local_optimizer([Elemwise])
def local_fill_sink(fgraph, node):
......
Markdown 格式
0%
您将添加 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论