Commit c86a717b authored by Brandon T. Willard, committed by Brandon T. Willard

Simplify construction of aesara.tensor.basic_opt.local_elemwise_alloc

Parent 91d149ab

@@ -1470,185 +1470,173 @@ aesara.compile.mode.optdb.register(
 aesara.compile.mode.optdb.register("UnShapeOpt", UnShapeOptimizer(), 10)
-def local_elemwise_alloc_op(ElemwiseOP, AllocOP, DimShuffleOP):
-    def local_elemwise_alloc(fgraph, node):
-        """
-        elemwise(alloc(x, shp), ..., y.TensorType(BROADCAST CONDITION))
-          -> elemwise(x, y.TensorType(BROADCAST CONDITION))
-
-        elemwise(dimshuffle(alloc(x, shp)), ..., y.TensorType(BROADCAST CONDITION))
-          -> elemwise(x.dimshuffle(...), y.TensorType(BROADCAST CONDITION))
-
-        BROADCAST CONDITION: the one input that is not to be optimized must
-        have the same broadcast pattern as the output.
-
-        We can replace the alloc with a dimshuffle, as the elemwise already
-        has the shape info. The dimshuffle will be faster to execute.
-        """
-        if not isinstance(node.op, ElemwiseOP):
-            return False
+@register_specialize("local_alloc_elemwise")
+@local_optimizer([Elemwise])
+def local_elemwise_alloc(fgraph, node):
+    """
+    elemwise(alloc(x, shp), ..., y.TensorType(BROADCAST CONDITION))
+      -> elemwise(x, y.TensorType(BROADCAST CONDITION))
+
+    elemwise(dimshuffle(alloc(x, shp)), ..., y.TensorType(BROADCAST CONDITION))
+      -> elemwise(x.dimshuffle(...), y.TensorType(BROADCAST CONDITION))
+
+    BROADCAST CONDITION: the one input that is not to be optimized must
+    have the same broadcast pattern as the output.
+
+    We can replace the `Alloc` with a `DimShuffle`, as the `Elemwise` already
+    has the shape info. The `DimShuffle` will be faster to execute.
+
+    TODO: Global optimizer that lifts the assert to the beginning of the graph?
+    TODO: Optimize all inputs when possible -- currently, when all inputs have
+    an `Alloc`, all but one are optimized.
+    """
+    if not isinstance(node.op, Elemwise):
+        return False
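
As a quick illustration of the transformation the docstring describes (a minimal sketch, assuming a recent Aesara install; all variable names are made up):

    import aesara
    import aesara.tensor as at

    x = at.scalar("x")
    y = at.matrix("y")

    # `at.alloc(x, 3, 4)` materializes `x` as a (3, 4) tensor before the add.
    z = at.alloc(x, 3, 4) + y

    # After compilation, the rewrite should have dropped the `Alloc`: the
    # `Elemwise` add can broadcast `x` by itself.
    f = aesara.function([x, y], z)
    aesara.printing.debugprint(f)
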
-        if len(node.outputs) > 1:
-            # Ensure all outputs have the same broadcast pattern.
-            # This is a supposition that I'm not sure is always true.
-            assert all(
-                [
-                    o.type.broadcastable == node.outputs[0].type.broadcastable
-                    for o in node.outputs[1:]
-                ]
-            )
-
-        # The broadcast pattern of the output must match the broadcast
-        # pattern of at least one of the inputs.
-        if not any(
-            [
-                i.type.broadcastable == node.outputs[0].type.broadcastable
-                for i in node.inputs
-            ]
-        ):
-            return False
+    if len(node.outputs) > 1:
+        # Ensure all outputs have the same broadcast pattern.
+        # This is a supposition that I'm not sure is always true.
+        assert all(
+            [
+                o.type.broadcastable == node.outputs[0].type.broadcastable
+                for o in node.outputs[1:]
+            ]
+        )
+
+    # The broadcast pattern of the output must match the broadcast
+    # pattern of at least one of the inputs.
+    if not any(
+        [
+            i.type.broadcastable == node.outputs[0].type.broadcastable
+            for i in node.inputs
+        ]
+    ):
+        return False
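
The "broadcast pattern" being compared is the per-dimension `broadcastable` tuple of a `TensorType`. A small sketch of what the check looks at (assuming Aesara's stock tensor constructors):

    import aesara.tensor as at

    r = at.row("r")     # TensorType with broadcastable == (True, False)
    m = at.matrix("m")  # TensorType with broadcastable == (False, False)
    out = r + m         # output broadcastable == (False, False)

    # `m` matches the output pattern exactly, so it could serve as the
    # baseline input; `r` does not.
    print(r.type.broadcastable, m.type.broadcastable, out.type.broadcastable)
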
-        def dimshuffled_alloc(i):
-            return (
-                isinstance(i.owner.op, DimShuffleOP)
-                and i.owner.inputs[0].owner
-                and isinstance(i.owner.inputs[0].owner.op, AllocOP)
-            )
-
-        # At least one input must have an owner that is either an AllocOP or a
-        # DimShuffleOP with an owner that is an AllocOP -- otherwise there is
-        # nothing to optimize.
-        if not any(
-            [
-                i.owner and (isinstance(i.owner.op, AllocOP) or dimshuffled_alloc(i))
-                for i in node.inputs
-            ]
-        ):
-            return False
+    def dimshuffled_alloc(i):
+        return (
+            isinstance(i.owner.op, DimShuffle)
+            and i.owner.inputs[0].owner
+            and isinstance(i.owner.inputs[0].owner.op, Alloc)
+        )
+
+    # At least one input must have an owner that is either an `Alloc` or a
+    # `DimShuffle` with an owner that is an `Alloc` -- otherwise there is
+    # nothing to optimize.
+    if not any(
+        [
+            i.owner and (isinstance(i.owner.op, Alloc) or dimshuffled_alloc(i))
+            for i in node.inputs
+        ]
+    ):
+        return False
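
`dimshuffled_alloc` only follows `owner` links in the graph: `i.owner` is the `Apply` node that produced `i`, and `owner.inputs[0]` is that node's first argument. A hypothetical standalone version of the same check:

    from aesara.tensor.basic import Alloc
    from aesara.tensor.elemwise import DimShuffle

    def is_dimshuffled_alloc(var):
        # `var.owner` is None for graph inputs and constants.
        node = var.owner
        return (
            node is not None
            and isinstance(node.op, DimShuffle)
            and node.inputs[0].owner is not None
            and isinstance(node.inputs[0].owner.op, Alloc)
        )
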
-        # Search for an input that we can use as a baseline for the dimensions.
-        assert_op_idx = -1
-        for idx, i in enumerate(node.inputs):
-            if i.type.broadcastable == node.outputs[0].type.broadcastable:
-                # Prefer an input that is not an AllocOP nor a DimShuffleOP of
-                # an AllocOP, so that all allocs can be optimized.
-                if not (
-                    i.owner
-                    and (isinstance(i.owner.op, AllocOP) or dimshuffled_alloc(i))
-                ):
-                    assert_op_idx = idx
-                    break
+    # Search for an input that we can use as a baseline for the dimensions.
+    assert_op_idx = -1
+    for idx, i in enumerate(node.inputs):
+        if i.type.broadcastable == node.outputs[0].type.broadcastable:
+            # Prefer an input that is not an `Alloc` nor a `DimShuffle` of an
+            # `Alloc`, so that all `Alloc`s can be optimized.
+            if not (
+                i.owner and (isinstance(i.owner.op, Alloc) or dimshuffled_alloc(i))
+            ):
+                assert_op_idx = idx
+                break
-        # It may be the case that only AllocOPs and DimShuffleOPs of AllocOPs
-        # exist.
-        if assert_op_idx < 0:
-            # We want to optimize as many allocs as possible. When there is
-            # more than one, optimize all but one.
-            l2 = [
-                i
-                for i in node.inputs
-                if (
-                    i.owner
-                    and (isinstance(i.owner.op, AllocOP) or dimshuffled_alloc(i))
-                )
-            ]
-            # If there is only one alloc or dimshuffled alloc, it is the one
-            # we will use for the shape, so no alloc would be removed.
-            if len(l2) > 1:
-                # `l` contains the indices of the inputs whose broadcast
-                # pattern matches the output's. It will always have at least
-                # one element, as we checked that above.
-                l = [
-                    idx
-                    for idx, i in enumerate(node.inputs)
-                    if i.broadcastable == node.outputs[0].broadcastable
-                ]
-                assert_op_idx = l[0]  # The first one is as good as any to use.
-            else:
-                # Nothing would be optimized!
-                return False
+    # It may be the case that only `Alloc`s and `DimShuffle`s of `Alloc`s
+    # exist.
+    if assert_op_idx < 0:
+        # We want to optimize as many `Alloc`s as possible. When there is
+        # more than one, optimize all but one.
+        l2 = [
+            i
+            for i in node.inputs
+            if (i.owner and (isinstance(i.owner.op, Alloc) or dimshuffled_alloc(i)))
+        ]
+        # If there is only one `Alloc` or dimshuffled `Alloc`, it is the one
+        # we will use for the shape, so no `Alloc` would be removed.
+        if len(l2) > 1:
+            # `l` contains the indices of the inputs whose broadcast pattern
+            # matches the output's. It will always have at least one element,
+            # as we checked that above.
+            l = [
+                idx
+                for idx, i in enumerate(node.inputs)
+                if i.broadcastable == node.outputs[0].broadcastable
+            ]
+            assert_op_idx = l[0]  # The first one is as good as any to use.
+        else:
+            # Nothing would be optimized!
+            return False
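
Put differently: when every matching input is an `Alloc` (or a `DimShuffle` of one), one of them must survive to carry the output shape, so only the others can be removed. A sketch of the two situations (variable names are illustrative):

    import aesara.tensor as at

    x = at.scalar("x")
    y = at.scalar("y")
    z = at.matrix("z")

    # Only `Alloc` inputs: one is kept as the shape baseline, the other can go.
    g1 = at.alloc(x, 3, 3) + at.alloc(y, 3, 3)

    # A non-`Alloc` input (`z`) matches the output pattern, so it becomes the
    # baseline and both `Alloc`s are candidates for removal (though see the
    # second TODO in the docstring).
    g2 = at.alloc(x, 3, 3) + at.alloc(y, 3, 3) + z
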
-        assert_op_in = node.inputs[assert_op_idx]
-        cmp_op = assert_op_in
-        new_i = []
-        same_shape = fgraph.shape_feature.same_shape
-        for i in node.inputs:
-            # Remove the alloc.
-            if (
-                i.owner
-                and isinstance(i.owner.op, AllocOP)
-                and i.owner.inputs[0].type != i.owner.outputs[0].type
-            ):
-                # When i.owner.inputs[0].type == i.owner.outputs[0].type, we
-                # will remove that alloc later.
-                assert i.type.ndim == cmp_op.ndim
-                if config.experimental__local_alloc_elemwise_assert:
-                    get_shape = fgraph.shape_feature.get_shape
-                    cond = []
-                    for idx in range(i.type.ndim):
-                        if not i.type.broadcastable[idx] and not same_shape(
-                            i, cmp_op, idx, idx
-                        ):
-                            i_shp = get_shape(i, idx)
-                            cmp_shp = get_shape(cmp_op, idx)
-                            cond.append(eq(i_shp, cmp_shp))
-                    if cond:
-                        assert_op_in = assert_op(assert_op_in, *cond)
-                new_i.append(i.owner.inputs[0])
+    assert_op_in = node.inputs[assert_op_idx]
+    cmp_op = assert_op_in
+    new_i = []
+    same_shape = fgraph.shape_feature.same_shape
+    for i in node.inputs:
+        # Remove the `Alloc`.
+        if (
+            i.owner
+            and isinstance(i.owner.op, Alloc)
+            and i.owner.inputs[0].type != i.owner.outputs[0].type
+        ):
+            # When `i.owner.inputs[0].type == i.owner.outputs[0].type`, we
+            # will remove that `Alloc` later.
+            assert i.type.ndim == cmp_op.ndim
+            if config.experimental__local_alloc_elemwise_assert:
+                get_shape = fgraph.shape_feature.get_shape
+                cond = []
+                for idx in range(i.type.ndim):
+                    if not i.type.broadcastable[idx] and not same_shape(
+                        i, cmp_op, idx, idx
+                    ):
+                        i_shp = get_shape(i, idx)
+                        cmp_shp = get_shape(cmp_op, idx)
+                        cond.append(eq(i_shp, cmp_shp))
+                if cond:
+                    assert_op_in = assert_op(assert_op_in, *cond)
+            new_i.append(i.owner.inputs[0])
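
`assert_op` here wraps the baseline input so that a removed `Alloc`'s shape is still checked at run time whenever `same_shape` cannot prove the match statically. Roughly the same guard written by hand (a sketch; it assumes `Assert` is importable from `aesara.raise_op`, and the variable names are made up):

    import aesara.tensor as at
    from aesara.raise_op import Assert

    x = at.vector("x")
    y = at.vector("y")

    # `guard(x, cond)` returns `x` unchanged, but raises at run time if
    # `cond` evaluates to False.
    guard = Assert("x and y must have the same length")
    x_checked = guard(x, at.eq(x.shape[0], y.shape[0]))
    z = x_checked + y
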
-            # Remove the alloc in the dimshuffle.
-            elif i.owner and dimshuffled_alloc(i):
-                assert i.type.ndim == cmp_op.type.ndim
-                if config.experimental__local_alloc_elemwise_assert:
-                    assert_cond = [
-                        eq(i.shape[idx], cmp_op.shape[idx])
-                        for idx in range(i.type.ndim)
-                        if not i.type.broadcastable[idx]
-                        and not same_shape(i, cmp_op, idx, idx)
-                    ]
-                    if assert_cond:
-                        assert_op_in = assert_op(assert_op_in, *assert_cond)
-                alloc_input = i.owner.inputs[0].owner.inputs[0]
-                if alloc_input.ndim != i.owner.inputs[0].ndim:
-                    # The alloc can add dimensions to the value; we add a
-                    # dimshuffle to add them, and let later optimizations
-                    # merge the multiple dimshuffles.
-                    nb_dim_to_add = i.owner.inputs[0].ndim - alloc_input.ndim
-                    alloc_input = alloc_input.dimshuffle(
-                        ["x"] * nb_dim_to_add + list(range(alloc_input.ndim))
-                    )
-
-                # We need to keep the dimshuffle: it could swap axes or
-                # add dimensions anywhere.
-                r_i = i.owner.op(alloc_input)
-
-                # Copy the stack trace from i to r_i.
-                copy_stack_trace(i, r_i)
-                new_i.append(r_i)
-
-            else:
-                new_i.append(i)
-        new_i[assert_op_idx] = assert_op_in
+        # Remove the `Alloc` in the `DimShuffle`.
+        elif i.owner and dimshuffled_alloc(i):
+            assert i.type.ndim == cmp_op.type.ndim
+            if config.experimental__local_alloc_elemwise_assert:
+                assert_cond = [
+                    eq(i.shape[idx], cmp_op.shape[idx])
+                    for idx in range(i.type.ndim)
+                    if not i.type.broadcastable[idx]
+                    and not same_shape(i, cmp_op, idx, idx)
+                ]
+                if assert_cond:
+                    assert_op_in = assert_op(assert_op_in, *assert_cond)
+            alloc_input = i.owner.inputs[0].owner.inputs[0]
+            if alloc_input.ndim != i.owner.inputs[0].ndim:
+                # The `Alloc` can add dimensions to the value; we add a
+                # `DimShuffle` to add them, and let later optimizations
+                # merge the multiple `DimShuffle`s.
+                nb_dim_to_add = i.owner.inputs[0].ndim - alloc_input.ndim
+                alloc_input = alloc_input.dimshuffle(
+                    ["x"] * nb_dim_to_add + list(range(alloc_input.ndim))
+                )
+
+            # We need to keep the `DimShuffle`: it could swap axes or
+            # add dimensions anywhere.
+            r_i = i.owner.op(alloc_input)
+
+            # Copy the stack trace from `i` to `r_i`.
+            copy_stack_trace(i, r_i)
+            new_i.append(r_i)
+
+        else:
+            new_i.append(i)
+    new_i[assert_op_idx] = assert_op_in
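
The `["x"] * nb_dim_to_add` trick above pads the `Alloc`'s value with broadcastable axes on the left before the original `DimShuffle` is re-applied. For instance (illustrative only):

    import aesara.tensor as at

    v = at.vector("v")                    # broadcastable == (False,)
    padded = v.dimshuffle(["x", "x", 0])  # broadcastable == (True, True, False)
    print(padded.type.broadcastable)
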
-        ret = node.op(*new_i, return_list=True)
-
-        # Copy over the stack trace from the previous outputs to the new ones.
-        copy_stack_trace(node.outputs, ret)
-        return ret
-
-    return local_elemwise_alloc
-
-
-# TODO, global optimizer that lifts the assert to the beginning of the graph.
-# TODO, optimize all inputs when possible -- currently, when all inputs have
-# an alloc, all but one are optimized.
-local_elemwise_alloc = register_specialize(
-    local_optimizer([Elemwise])(local_elemwise_alloc_op(Elemwise, Alloc, DimShuffle)),
-    "local_alloc_elemwise",
-)
+    ret = node.op(*new_i, return_list=True)
+
+    # Copy over the stack trace from the previous outputs to the new ones.
+    copy_stack_trace(node.outputs, ret)
+    return ret
 @local_optimizer([Elemwise])
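
This is the whole point of the commit: the `local_elemwise_alloc_op` factory and the manual `register_specialize(local_optimizer(...)(...))` call above are replaced by direct decoration. The same registration pattern for a hypothetical rewrite (`local_my_rewrite` is made up; the import paths assume the Aesara layout at the time of this commit):

    from aesara.graph.opt import local_optimizer
    from aesara.tensor.basic_opt import register_specialize
    from aesara.tensor.elemwise import Elemwise

    @register_specialize("local_my_rewrite")
    @local_optimizer([Elemwise])
    def local_my_rewrite(fgraph, node):
        # Return a list of replacement outputs, or False to leave the
        # node unchanged.
        return False
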