提交 595b29c0 authored 作者: Frederic's avatar Frederic

pep8

上级 3f04b1e4
...@@ -1606,6 +1606,7 @@ compile.optdb['specialize'].register('local_remove_all_assert', ...@@ -1606,6 +1606,7 @@ compile.optdb['specialize'].register('local_remove_all_assert',
local_remove_all_assert, local_remove_all_assert,
use_db_name_as_tag=False) use_db_name_as_tag=False)
def local_elemwise_alloc_op(ElemwiseOP, AllocOP, DimShuffleOP): def local_elemwise_alloc_op(ElemwiseOP, AllocOP, DimShuffleOP):
def local_elemwise_alloc(node): def local_elemwise_alloc(node):
""" """
...@@ -1633,8 +1634,8 @@ def local_elemwise_alloc_op(ElemwiseOP, AllocOP, DimShuffleOP): ...@@ -1633,8 +1634,8 @@ def local_elemwise_alloc_op(ElemwiseOP, AllocOP, DimShuffleOP):
node.outputs[0].type.broadcastable for o in node.outputs[0].type.broadcastable for o in
node.outputs[1:]]) node.outputs[1:]])
# The broadcast pattern of the ouptut must match the broadcast pattern of # The broadcast pattern of the ouptut must match the broadcast
# at least one of the inputs. # pattern of at least one of the inputs.
if not any([i.type.broadcastable == if not any([i.type.broadcastable ==
node.outputs[0].type.broadcastable for i in node.inputs]): node.outputs[0].type.broadcastable for i in node.inputs]):
return False return False
...@@ -1648,11 +1649,12 @@ def local_elemwise_alloc_op(ElemwiseOP, AllocOP, DimShuffleOP): ...@@ -1648,11 +1649,12 @@ def local_elemwise_alloc_op(ElemwiseOP, AllocOP, DimShuffleOP):
# DimShuffleOP with an owner that is a AllocOP -- otherwise there is # DimShuffleOP with an owner that is a AllocOP -- otherwise there is
# nothing to optimize. # nothing to optimize.
if not any([i.owner if not any([i.owner
and (isinstance(i.owner.op, AllocOP) or dimshuffled_alloc(i)) and (isinstance(i.owner.op, AllocOP) or
dimshuffled_alloc(i))
for i in node.inputs]): for i in node.inputs]):
return False return False
## Search for input that we can use as a baseline for the dimensions. # Search for input that we can use as a baseline for the dimensions.
assert_op_idx = -1 assert_op_idx = -1
for idx, i in enumerate(node.inputs): for idx, i in enumerate(node.inputs):
if i.type.broadcastable == node.outputs[0].type.broadcastable: if i.type.broadcastable == node.outputs[0].type.broadcastable:
...@@ -1666,19 +1668,20 @@ def local_elemwise_alloc_op(ElemwiseOP, AllocOP, DimShuffleOP): ...@@ -1666,19 +1668,20 @@ def local_elemwise_alloc_op(ElemwiseOP, AllocOP, DimShuffleOP):
# It may be the case that only AllocOP and DimShuffleOP of AllocOP exist. # It may be the case that only AllocOP and DimShuffleOP of AllocOP exist.
if assert_op_idx < 0: if assert_op_idx < 0:
# We want to optimize as many allocs as possible. When there is more # We want to optimize as many allocs as possible. When
# than one then do all but one. # there is more than one then do all but one. number of
# number of inputs with alloc or dimshuffle alloc # inputs with alloc or dimshuffle alloc
l2 = [i for i in node.inputs l2 = [i for i in node.inputs
if (i.owner and (isinstance(i.owner.op, AllocOP) if (i.owner and (isinstance(i.owner.op, AllocOP)
or dimshuffled_alloc(i)))] or dimshuffled_alloc(i)))]
# If only 1 alloc or dimshuffle alloc, it is the one we will use for the shape # If only 1 alloc or dimshuffle alloc, it is the one we
# So no alloc would be removed. # will use for the shape. So no alloc would be removed.
if len(l2) > 1: if len(l2) > 1:
# l containt inputs with alloc or dimshuffle alloc only. # l containt inputs with alloc or dimshuffle alloc
# Its length will always be at least one, as we checked that before # only. Its length will always be at least one, as we
# checked that before
l = [idx for idx, i in enumerate(node.inputs) l = [idx for idx, i in enumerate(node.inputs)
if i.type.broadcastable == node.outputs[0].type.broadcastable] if i.broadcastable == node.outputs[0].broadcastable]
assert_op_idx = l[0] # The first one is as good as any to use. assert_op_idx = l[0] # The first one is as good as any to use.
else: else:
# Nothing would be optimized! # Nothing would be optimized!
...@@ -1719,7 +1722,8 @@ def local_elemwise_alloc_op(ElemwiseOP, AllocOP, DimShuffleOP): ...@@ -1719,7 +1722,8 @@ def local_elemwise_alloc_op(ElemwiseOP, AllocOP, DimShuffleOP):
# We add a dimshuffle to add them. # We add a dimshuffle to add them.
# We let later optimization merge the multiple dimshuffle # We let later optimization merge the multiple dimshuffle
nb_dim_to_add = i.owner.inputs[0].ndim - alloc_input.ndim nb_dim_to_add = i.owner.inputs[0].ndim - alloc_input.ndim
alloc_input = alloc_input.dimshuffle(['x'] * nb_dim_to_add + alloc_input = alloc_input.dimshuffle(
['x'] * nb_dim_to_add +
range(alloc_input.ndim)) range(alloc_input.ndim))
# We need to keep the dimshuffle. It could swap axes or # We need to keep the dimshuffle. It could swap axes or
...@@ -1733,8 +1737,8 @@ def local_elemwise_alloc_op(ElemwiseOP, AllocOP, DimShuffleOP): ...@@ -1733,8 +1737,8 @@ def local_elemwise_alloc_op(ElemwiseOP, AllocOP, DimShuffleOP):
return local_elemwise_alloc return local_elemwise_alloc
#TODO, global optimizer that lift the assert to the beginning of the graph. # TODO, global optimizer that lift the assert to the beginning of the graph.
#TODO, optimize all inputs when possible -- currently when all inputs have # TODO, optimize all inputs when possible -- currently when all inputs have
# an alloc all but one is optimized. # an alloc all but one is optimized.
local_elemwise_alloc = register_specialize( local_elemwise_alloc = register_specialize(
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论