提交 595b29c0 authored 作者: Frederic's avatar Frederic

pep8

上级 3f04b1e4
......@@ -1606,6 +1606,7 @@ compile.optdb['specialize'].register('local_remove_all_assert',
local_remove_all_assert,
use_db_name_as_tag=False)
def local_elemwise_alloc_op(ElemwiseOP, AllocOP, DimShuffleOP):
def local_elemwise_alloc(node):
"""
......@@ -1633,8 +1634,8 @@ def local_elemwise_alloc_op(ElemwiseOP, AllocOP, DimShuffleOP):
node.outputs[0].type.broadcastable for o in
node.outputs[1:]])
# The broadcast pattern of the ouptut must match the broadcast pattern of
# at least one of the inputs.
# The broadcast pattern of the ouptut must match the broadcast
# pattern of at least one of the inputs.
if not any([i.type.broadcastable ==
node.outputs[0].type.broadcastable for i in node.inputs]):
return False
......@@ -1648,11 +1649,12 @@ def local_elemwise_alloc_op(ElemwiseOP, AllocOP, DimShuffleOP):
# DimShuffleOP with an owner that is a AllocOP -- otherwise there is
# nothing to optimize.
if not any([i.owner
and (isinstance(i.owner.op, AllocOP) or dimshuffled_alloc(i))
and (isinstance(i.owner.op, AllocOP) or
dimshuffled_alloc(i))
for i in node.inputs]):
return False
## Search for input that we can use as a baseline for the dimensions.
# Search for input that we can use as a baseline for the dimensions.
assert_op_idx = -1
for idx, i in enumerate(node.inputs):
if i.type.broadcastable == node.outputs[0].type.broadcastable:
......@@ -1666,19 +1668,20 @@ def local_elemwise_alloc_op(ElemwiseOP, AllocOP, DimShuffleOP):
# It may be the case that only AllocOP and DimShuffleOP of AllocOP exist.
if assert_op_idx < 0:
# We want to optimize as many allocs as possible. When there is more
# than one then do all but one.
# number of inputs with alloc or dimshuffle alloc
# We want to optimize as many allocs as possible. When
# there is more than one then do all but one. number of
# inputs with alloc or dimshuffle alloc
l2 = [i for i in node.inputs
if (i.owner and (isinstance(i.owner.op, AllocOP)
or dimshuffled_alloc(i)))]
# If only 1 alloc or dimshuffle alloc, it is the one we will use for the shape
# So no alloc would be removed.
# If only 1 alloc or dimshuffle alloc, it is the one we
# will use for the shape. So no alloc would be removed.
if len(l2) > 1:
# l containt inputs with alloc or dimshuffle alloc only.
# Its length will always be at least one, as we checked that before
# l containt inputs with alloc or dimshuffle alloc
# only. Its length will always be at least one, as we
# checked that before
l = [idx for idx, i in enumerate(node.inputs)
if i.type.broadcastable == node.outputs[0].type.broadcastable]
if i.broadcastable == node.outputs[0].broadcastable]
assert_op_idx = l[0] # The first one is as good as any to use.
else:
# Nothing would be optimized!
......@@ -1719,7 +1722,8 @@ def local_elemwise_alloc_op(ElemwiseOP, AllocOP, DimShuffleOP):
# We add a dimshuffle to add them.
# We let later optimization merge the multiple dimshuffle
nb_dim_to_add = i.owner.inputs[0].ndim - alloc_input.ndim
alloc_input = alloc_input.dimshuffle(['x'] * nb_dim_to_add +
alloc_input = alloc_input.dimshuffle(
['x'] * nb_dim_to_add +
range(alloc_input.ndim))
# We need to keep the dimshuffle. It could swap axes or
......@@ -1733,8 +1737,8 @@ def local_elemwise_alloc_op(ElemwiseOP, AllocOP, DimShuffleOP):
return local_elemwise_alloc
#TODO, global optimizer that lift the assert to the beginning of the graph.
#TODO, optimize all inputs when possible -- currently when all inputs have
# TODO, global optimizer that lift the assert to the beginning of the graph.
# TODO, optimize all inputs when possible -- currently when all inputs have
# an alloc all but one is optimized.
local_elemwise_alloc = register_specialize(
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论