提交 c86a717b authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Simplify construction of aesara.tensor.basic_opt.local_elemwise_alloc

上级 91d149ab
...@@ -1470,8 +1470,9 @@ aesara.compile.mode.optdb.register( ...@@ -1470,8 +1470,9 @@ aesara.compile.mode.optdb.register(
aesara.compile.mode.optdb.register("UnShapeOpt", UnShapeOptimizer(), 10) aesara.compile.mode.optdb.register("UnShapeOpt", UnShapeOptimizer(), 10)
def local_elemwise_alloc_op(ElemwiseOP, AllocOP, DimShuffleOP): @register_specialize("local_alloc_elemwise")
def local_elemwise_alloc(fgraph, node): @local_optimizer([Elemwise])
def local_elemwise_alloc(fgraph, node):
""" """
elemwise(alloc(x, shp), ..., y.TensorType(BROADCAST CONDITION)) elemwise(alloc(x, shp), ..., y.TensorType(BROADCAST CONDITION))
-> elemwise(x, y.TensorType(BROADCAST CONDITION)) -> elemwise(x, y.TensorType(BROADCAST CONDITION))
...@@ -1483,12 +1484,15 @@ def local_elemwise_alloc_op(ElemwiseOP, AllocOP, DimShuffleOP): ...@@ -1483,12 +1484,15 @@ def local_elemwise_alloc_op(ElemwiseOP, AllocOP, DimShuffleOP):
not to be optimized to have the same broadcast pattern as the not to be optimized to have the same broadcast pattern as the
output. output.
We can change the alloc by a dimshuffle as the elemwise We can change the `Alloc` by a `DimShuffle` as the `Elemwise` already have
already have the shape info. The dimshuffle will be faster the shape info. The `DimShuffle` will be faster to exec.
to exec.
TODO: Global optimizer that lifts the assert to the beginning of the graph?
TODO: Optimize all inputs when possible -- currently when all inputs have
an `Alloc` all but one is optimized.
""" """
if not isinstance(node.op, ElemwiseOP): if not isinstance(node.op, Elemwise):
return False return False
if len(node.outputs) > 1: if len(node.outputs) > 1:
...@@ -1513,17 +1517,17 @@ def local_elemwise_alloc_op(ElemwiseOP, AllocOP, DimShuffleOP): ...@@ -1513,17 +1517,17 @@ def local_elemwise_alloc_op(ElemwiseOP, AllocOP, DimShuffleOP):
def dimshuffled_alloc(i): def dimshuffled_alloc(i):
return ( return (
isinstance(i.owner.op, DimShuffleOP) isinstance(i.owner.op, DimShuffle)
and i.owner.inputs[0].owner and i.owner.inputs[0].owner
and isinstance(i.owner.inputs[0].owner.op, AllocOP) and isinstance(i.owner.inputs[0].owner.op, Alloc)
) )
# At least one input must have an owner that is either a AllocOP or a # At least one input must have an owner that is either a `Alloc` or a
# DimShuffleOP with an owner that is a AllocOP -- otherwise there is # `DimShuffle` with an owner that is a `Alloc` -- otherwise there is
# nothing to optimize. # nothing to optimize.
if not any( if not any(
[ [
i.owner and (isinstance(i.owner.op, AllocOP) or dimshuffled_alloc(i)) i.owner and (isinstance(i.owner.op, Alloc) or dimshuffled_alloc(i))
for i in node.inputs for i in node.inputs
] ]
): ):
...@@ -1533,32 +1537,28 @@ def local_elemwise_alloc_op(ElemwiseOP, AllocOP, DimShuffleOP): ...@@ -1533,32 +1537,28 @@ def local_elemwise_alloc_op(ElemwiseOP, AllocOP, DimShuffleOP):
assert_op_idx = -1 assert_op_idx = -1
for idx, i in enumerate(node.inputs): for idx, i in enumerate(node.inputs):
if i.type.broadcastable == node.outputs[0].type.broadcastable: if i.type.broadcastable == node.outputs[0].type.broadcastable:
# Prefer an input that is not a AllocOP nor a DimShuffleOP of a # Prefer an input that is not a `Alloc` nor a `DimShuffle` of a
# AllocOP so that all allocs can be optimized. # `Alloc` so that all `Alloc`s can be optimized.
if not ( if not (
i.owner i.owner and (isinstance(i.owner.op, Alloc) or dimshuffled_alloc(i))
and (isinstance(i.owner.op, AllocOP) or dimshuffled_alloc(i))
): ):
assert_op_idx = idx assert_op_idx = idx
break break
# It may be the case that only AllocOP and DimShuffleOP of AllocOP exist. # It may be the case that only `Alloc` and `DimShuffle` of `Alloc` exist.
if assert_op_idx < 0: if assert_op_idx < 0:
# We want to optimize as many allocs as possible. When # We want to optimize as many `Alloc`s as possible. When
            # there is more than one then do all but one. number of             # there is more than one, optimize all but one; count the number of
# inputs with alloc or dimshuffle alloc # inputs with `Alloc` or `DimShuffle` `Alloc`
l2 = [ l2 = [
i i
for i in node.inputs for i in node.inputs
if ( if (i.owner and (isinstance(i.owner.op, Alloc) or dimshuffled_alloc(i)))
i.owner
and (isinstance(i.owner.op, AllocOP) or dimshuffled_alloc(i))
)
] ]
# If only 1 alloc or dimshuffle alloc, it is the one we # If only one `Alloc` or `DimShuffle` `Alloc`, it is the one we
# will use for the shape. So no alloc would be removed. # will use for the shape. So no `Alloc` would be removed.
if len(l2) > 1: if len(l2) > 1:
                # l contains inputs with alloc or dimshuffle alloc                 # `l` contains inputs with `Alloc` or `DimShuffle` `Alloc`
# only. Its length will always be at least one, as we # only. Its length will always be at least one, as we
# checked that before # checked that before
l = [ l = [
...@@ -1576,14 +1576,14 @@ def local_elemwise_alloc_op(ElemwiseOP, AllocOP, DimShuffleOP): ...@@ -1576,14 +1576,14 @@ def local_elemwise_alloc_op(ElemwiseOP, AllocOP, DimShuffleOP):
new_i = [] new_i = []
same_shape = fgraph.shape_feature.same_shape same_shape = fgraph.shape_feature.same_shape
for i in node.inputs: for i in node.inputs:
# Remove alloc # Remove `Alloc`
if ( if (
i.owner i.owner
and isinstance(i.owner.op, AllocOP) and isinstance(i.owner.op, Alloc)
and i.owner.inputs[0].type != i.owner.outputs[0].type and i.owner.inputs[0].type != i.owner.outputs[0].type
): ):
# when i.owner.inputs[0].type == i.owner.outputs[0].type we # when `i.owner.inputs[0].type == i.owner.outputs[0].type` we
# will remove that alloc later # will remove that `Alloc` later
assert i.type.ndim == cmp_op.ndim assert i.type.ndim == cmp_op.ndim
if config.experimental__local_alloc_elemwise_assert: if config.experimental__local_alloc_elemwise_assert:
get_shape = fgraph.shape_feature.get_shape get_shape = fgraph.shape_feature.get_shape
...@@ -1599,7 +1599,7 @@ def local_elemwise_alloc_op(ElemwiseOP, AllocOP, DimShuffleOP): ...@@ -1599,7 +1599,7 @@ def local_elemwise_alloc_op(ElemwiseOP, AllocOP, DimShuffleOP):
assert_op_in = assert_op(assert_op_in, *cond) assert_op_in = assert_op(assert_op_in, *cond)
new_i.append(i.owner.inputs[0]) new_i.append(i.owner.inputs[0])
# Remove Alloc in DimShuffle # Remove `Alloc` in `DimShuffle`
elif i.owner and dimshuffled_alloc(i): elif i.owner and dimshuffled_alloc(i):
assert i.type.ndim == cmp_op.type.ndim assert i.type.ndim == cmp_op.type.ndim
if config.experimental__local_alloc_elemwise_assert: if config.experimental__local_alloc_elemwise_assert:
...@@ -1613,15 +1613,15 @@ def local_elemwise_alloc_op(ElemwiseOP, AllocOP, DimShuffleOP): ...@@ -1613,15 +1613,15 @@ def local_elemwise_alloc_op(ElemwiseOP, AllocOP, DimShuffleOP):
assert_op_in = assert_op(assert_op_in, *assert_cond) assert_op_in = assert_op(assert_op_in, *assert_cond)
alloc_input = i.owner.inputs[0].owner.inputs[0] alloc_input = i.owner.inputs[0].owner.inputs[0]
if alloc_input.ndim != i.owner.inputs[0].ndim: if alloc_input.ndim != i.owner.inputs[0].ndim:
# The alloc can add dimension to the value # The `Alloc` can add dimension to the value
# We add a dimshuffle to add them. # We add a `DimShuffle` to add them.
# We let later optimization merge the multiple dimshuffle # We let later optimization merge the multiple `DimShuffle`
nb_dim_to_add = i.owner.inputs[0].ndim - alloc_input.ndim nb_dim_to_add = i.owner.inputs[0].ndim - alloc_input.ndim
alloc_input = alloc_input.dimshuffle( alloc_input = alloc_input.dimshuffle(
["x"] * nb_dim_to_add + list(range(alloc_input.ndim)) ["x"] * nb_dim_to_add + list(range(alloc_input.ndim))
) )
# We need to keep the dimshuffle. It could swap axes or # We need to keep the `DimShuffle`. It could swap axes or
# add dimensions anywhere. # add dimensions anywhere.
r_i = i.owner.op(alloc_input) r_i = i.owner.op(alloc_input)
...@@ -1638,18 +1638,6 @@ def local_elemwise_alloc_op(ElemwiseOP, AllocOP, DimShuffleOP): ...@@ -1638,18 +1638,6 @@ def local_elemwise_alloc_op(ElemwiseOP, AllocOP, DimShuffleOP):
copy_stack_trace(node.outputs, ret) copy_stack_trace(node.outputs, ret)
return ret return ret
return local_elemwise_alloc
# TODO, global optimizer that lift the assert to the beginning of the graph.
# TODO, optimize all inputs when possible -- currently when all inputs have
# an alloc all but one is optimized.
local_elemwise_alloc = register_specialize(
local_optimizer([Elemwise])(local_elemwise_alloc_op(Elemwise, Alloc, DimShuffle)),
"local_alloc_elemwise",
)
@local_optimizer([Elemwise]) @local_optimizer([Elemwise])
def local_fill_sink(fgraph, node): def local_fill_sink(fgraph, node):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论