提交 c86a717b 作者: Brandon T. Willard 提交者: Brandon T. Willard

Simplify construction of aesara.tensor.basic_opt.local_elemwise_alloc

上级 91d149ab
...@@ -1470,185 +1470,173 @@ aesara.compile.mode.optdb.register( ...@@ -1470,185 +1470,173 @@ aesara.compile.mode.optdb.register(
aesara.compile.mode.optdb.register("UnShapeOpt", UnShapeOptimizer(), 10) aesara.compile.mode.optdb.register("UnShapeOpt", UnShapeOptimizer(), 10)
def local_elemwise_alloc_op(ElemwiseOP, AllocOP, DimShuffleOP): @register_specialize("local_alloc_elemwise")
def local_elemwise_alloc(fgraph, node): @local_optimizer([Elemwise])
""" def local_elemwise_alloc(fgraph, node):
elemwise(alloc(x, shp), ..., y.TensorType(BROADCAST CONDITION)) """
-> elemwise(x, y.TensorType(BROADCAST CONDITION)) elemwise(alloc(x, shp), ..., y.TensorType(BROADCAST CONDITION))
-> elemwise(x, y.TensorType(BROADCAST CONDITION))
elemwise(dimshuffle(alloc(x, shp)),... ,y.TensorType(BROADCAST CONDITION)) elemwise(dimshuffle(alloc(x, shp)),... ,y.TensorType(BROADCAST CONDITION))
-> elemwise(x.dimshuffle(...), y.TensorType(BROADCAST CONDITION)) -> elemwise(x.dimshuffle(...), y.TensorType(BROADCAST CONDITION))
BROADCAST CONDITION: the condition is that the one input that are BROADCAST CONDITION: the condition is that the one input that are
not to be optimized to have the same broadcast pattern as the not to be optimized to have the same broadcast pattern as the
output. output.
We can change the alloc by a dimshuffle as the elemwise We can change the `Alloc` by a `DimShuffle` as the `Elemwise` already have
already have the shape info. The dimshuffle will be faster the shape info. The `DimShuffle` will be faster to exec.
to exec.
""" TODO: Global optimizer that lifts the assert to the beginning of the graph?
if not isinstance(node.op, ElemwiseOP): TODO: Optimize all inputs when possible -- currently when all inputs have
return False an `Alloc` all but one is optimized.
if len(node.outputs) > 1: """
# Ensure all outputs have the same broadcast pattern if not isinstance(node.op, Elemwise):
# This is a supposition that I'm not sure is always true. return False
assert all(
[
o.type.broadcastable == node.outputs[0].type.broadcastable
for o in node.outputs[1:]
]
)
# The broadcast pattern of the output must match the broadcast if len(node.outputs) > 1:
# pattern of at least one of the inputs. # Ensure all outputs have the same broadcast pattern
if not any( # This is a supposition that I'm not sure is always true.
assert all(
[ [
i.type.broadcastable == node.outputs[0].type.broadcastable o.type.broadcastable == node.outputs[0].type.broadcastable
for i in node.inputs for o in node.outputs[1:]
] ]
): )
return False
def dimshuffled_alloc(i):
return (
isinstance(i.owner.op, DimShuffleOP)
and i.owner.inputs[0].owner
and isinstance(i.owner.inputs[0].owner.op, AllocOP)
)
# At least one input must have an owner that is either a AllocOP or a # The broadcast pattern of the output must match the broadcast
# DimShuffleOP with an owner that is a AllocOP -- otherwise there is # pattern of at least one of the inputs.
# nothing to optimize. if not any(
if not any( [
[ i.type.broadcastable == node.outputs[0].type.broadcastable
i.owner and (isinstance(i.owner.op, AllocOP) or dimshuffled_alloc(i)) for i in node.inputs
for i in node.inputs ]
] ):
): return False
return False
# Search for input that we can use as a baseline for the dimensions. def dimshuffled_alloc(i):
assert_op_idx = -1 return (
for idx, i in enumerate(node.inputs): isinstance(i.owner.op, DimShuffle)
if i.type.broadcastable == node.outputs[0].type.broadcastable: and i.owner.inputs[0].owner
# Prefer an input that is not a AllocOP nor a DimShuffleOP of a and isinstance(i.owner.inputs[0].owner.op, Alloc)
# AllocOP so that all allocs can be optimized. )
if not (
i.owner
and (isinstance(i.owner.op, AllocOP) or dimshuffled_alloc(i))
):
assert_op_idx = idx
break
# It may be the case that only AllocOP and DimShuffleOP of AllocOP exist. # At least one input must have an owner that is either a `Alloc` or a
if assert_op_idx < 0: # `DimShuffle` with an owner that is a `Alloc` -- otherwise there is
# We want to optimize as many allocs as possible. When # nothing to optimize.
# there is more than one then do all but one. number of if not any(
# inputs with alloc or dimshuffle alloc [
l2 = [ i.owner and (isinstance(i.owner.op, Alloc) or dimshuffled_alloc(i))
i for i in node.inputs
for i in node.inputs ]
if ( ):
i.owner return False
and (isinstance(i.owner.op, AllocOP) or dimshuffled_alloc(i))
)
]
# If only 1 alloc or dimshuffle alloc, it is the one we
# will use for the shape. So no alloc would be removed.
if len(l2) > 1:
# l contains inputs with alloc or dimshuffle alloc
# only. Its length will always be at least one, as we
# checked that before
l = [
idx
for idx, i in enumerate(node.inputs)
if i.broadcastable == node.outputs[0].broadcastable
]
assert_op_idx = l[0] # The first one is as good as any to use.
else:
# Nothing would be optimized!
return False
assert_op_in = node.inputs[assert_op_idx] # Search for input that we can use as a baseline for the dimensions.
cmp_op = assert_op_in assert_op_idx = -1
new_i = [] for idx, i in enumerate(node.inputs):
same_shape = fgraph.shape_feature.same_shape if i.type.broadcastable == node.outputs[0].type.broadcastable:
for i in node.inputs: # Prefer an input that is not a `Alloc` nor a `DimShuffle` of a
# Remove alloc # `Alloc` so that all `Alloc`s can be optimized.
if ( if not (
i.owner i.owner and (isinstance(i.owner.op, Alloc) or dimshuffled_alloc(i))
and isinstance(i.owner.op, AllocOP)
and i.owner.inputs[0].type != i.owner.outputs[0].type
): ):
# when i.owner.inputs[0].type == i.owner.outputs[0].type we assert_op_idx = idx
# will remove that alloc later break
assert i.type.ndim == cmp_op.ndim
if config.experimental__local_alloc_elemwise_assert:
get_shape = fgraph.shape_feature.get_shape
cond = []
for idx in range(i.type.ndim):
if not i.type.broadcastable[idx] and not same_shape(
i, cmp_op, idx, idx
):
i_shp = get_shape(i, idx)
cmp_shp = get_shape(cmp_op, idx)
cond.append(eq(i_shp, cmp_shp))
if cond:
assert_op_in = assert_op(assert_op_in, *cond)
new_i.append(i.owner.inputs[0])
# Remove Alloc in DimShuffle
elif i.owner and dimshuffled_alloc(i):
assert i.type.ndim == cmp_op.type.ndim
if config.experimental__local_alloc_elemwise_assert:
assert_cond = [
eq(i.shape[idx], cmp_op.shape[idx])
for idx in range(i.type.ndim)
if not i.type.broadcastable[idx]
and not same_shape(i, cmp_op, idx, idx)
]
if assert_cond:
assert_op_in = assert_op(assert_op_in, *assert_cond)
alloc_input = i.owner.inputs[0].owner.inputs[0]
if alloc_input.ndim != i.owner.inputs[0].ndim:
# The alloc can add dimension to the value
# We add a dimshuffle to add them.
# We let later optimization merge the multiple dimshuffle
nb_dim_to_add = i.owner.inputs[0].ndim - alloc_input.ndim
alloc_input = alloc_input.dimshuffle(
["x"] * nb_dim_to_add + list(range(alloc_input.ndim))
)
# We need to keep the dimshuffle. It could swap axes or
# add dimensions anywhere.
r_i = i.owner.op(alloc_input)
# Copy stack trace from i to new_i
copy_stack_trace(i, r_i)
new_i.append(r_i)
else:
new_i.append(i)
new_i[assert_op_idx] = assert_op_in
ret = node.op(*new_i, return_list=True) # It may be the case that only `Alloc` and `DimShuffle` of `Alloc` exist.
if assert_op_idx < 0:
# We want to optimize as many `Alloc`s as possible. When
# there is more than one then do all but one. number of
# inputs with `Alloc` or `DimShuffle` `Alloc`
l2 = [
i
for i in node.inputs
if (i.owner and (isinstance(i.owner.op, Alloc) or dimshuffled_alloc(i)))
]
# If only one `Alloc` or `DimShuffle` `Alloc`, it is the one we
# will use for the shape. So no `Alloc` would be removed.
if len(l2) > 1:
# One contains inputs with `Alloc` or `DimShuffle` `Alloc`
# only. Its length will always be at least one, as we
# checked that before
l = [
idx
for idx, i in enumerate(node.inputs)
if i.broadcastable == node.outputs[0].broadcastable
]
assert_op_idx = l[0] # The first one is as good as any to use.
else:
# Nothing would be optimized!
return False
# Copy over stack trace from previous outputs to new outputs. assert_op_in = node.inputs[assert_op_idx]
copy_stack_trace(node.outputs, ret) cmp_op = assert_op_in
return ret new_i = []
same_shape = fgraph.shape_feature.same_shape
for i in node.inputs:
# Remove `Alloc`
if (
i.owner
and isinstance(i.owner.op, Alloc)
and i.owner.inputs[0].type != i.owner.outputs[0].type
):
# when `i.owner.inputs[0].type == i.owner.outputs[0].type` we
# will remove that `Alloc` later
assert i.type.ndim == cmp_op.ndim
if config.experimental__local_alloc_elemwise_assert:
get_shape = fgraph.shape_feature.get_shape
cond = []
for idx in range(i.type.ndim):
if not i.type.broadcastable[idx] and not same_shape(
i, cmp_op, idx, idx
):
i_shp = get_shape(i, idx)
cmp_shp = get_shape(cmp_op, idx)
cond.append(eq(i_shp, cmp_shp))
if cond:
assert_op_in = assert_op(assert_op_in, *cond)
new_i.append(i.owner.inputs[0])
# Remove `Alloc` in `DimShuffle`
elif i.owner and dimshuffled_alloc(i):
assert i.type.ndim == cmp_op.type.ndim
if config.experimental__local_alloc_elemwise_assert:
assert_cond = [
eq(i.shape[idx], cmp_op.shape[idx])
for idx in range(i.type.ndim)
if not i.type.broadcastable[idx]
and not same_shape(i, cmp_op, idx, idx)
]
if assert_cond:
assert_op_in = assert_op(assert_op_in, *assert_cond)
alloc_input = i.owner.inputs[0].owner.inputs[0]
if alloc_input.ndim != i.owner.inputs[0].ndim:
# The `Alloc` can add dimension to the value
# We add a `DimShuffle` to add them.
# We let later optimization merge the multiple `DimShuffle`
nb_dim_to_add = i.owner.inputs[0].ndim - alloc_input.ndim
alloc_input = alloc_input.dimshuffle(
["x"] * nb_dim_to_add + list(range(alloc_input.ndim))
)
return local_elemwise_alloc # We need to keep the `DimShuffle`. It could swap axes or
# add dimensions anywhere.
r_i = i.owner.op(alloc_input)
# Copy stack trace from i to new_i
copy_stack_trace(i, r_i)
new_i.append(r_i)
else:
new_i.append(i)
new_i[assert_op_idx] = assert_op_in
# TODO, global optimizer that lift the assert to the beginning of the graph. ret = node.op(*new_i, return_list=True)
# TODO, optimize all inputs when possible -- currently when all inputs have
# an alloc all but one is optimized.
local_elemwise_alloc = register_specialize( # Copy over stack trace from previous outputs to new outputs.
local_optimizer([Elemwise])(local_elemwise_alloc_op(Elemwise, Alloc, DimShuffle)), copy_stack_trace(node.outputs, ret)
"local_alloc_elemwise", return ret
)
@local_optimizer([Elemwise]) @local_optimizer([Elemwise])
......
Markdown 格式
0%
您即将添加 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论