提交 595b29c0 authored 作者: Frederic's avatar Frederic

pep8

上级 3f04b1e4
...@@ -1606,53 +1606,55 @@ compile.optdb['specialize'].register('local_remove_all_assert', ...@@ -1606,53 +1606,55 @@ compile.optdb['specialize'].register('local_remove_all_assert',
local_remove_all_assert, local_remove_all_assert,
use_db_name_as_tag=False) use_db_name_as_tag=False)
def local_elemwise_alloc_op(ElemwiseOP, AllocOP, DimShuffleOP): def local_elemwise_alloc_op(ElemwiseOP, AllocOP, DimShuffleOP):
def local_elemwise_alloc(node): def local_elemwise_alloc(node):
""" """
elemwise(alloc(x, shp), ..., y.TensorType(BROADCAST CONDITION)) elemwise(alloc(x, shp), ..., y.TensorType(BROADCAST CONDITION))
-> elemwise(x, y.TensorType(BROADCAST CONDITION)) -> elemwise(x, y.TensorType(BROADCAST CONDITION))
elemwise(dimshuffle(alloc(x, shp)),... ,y.TensorType(BROADCAST CONDITION)) elemwise(dimshuffle(alloc(x, shp)),... ,y.TensorType(BROADCAST CONDITION))
-> elemwise(x.dimshuffle(...), y.TensorType(BROADCAST CONDITION)) -> elemwise(x.dimshuffle(...), y.TensorType(BROADCAST CONDITION))
BROADCAST CONDITION: the condition is that the one input that are BROADCAST CONDITION: the condition is that the one input that are
not to be optimized to have the same broadcast pattern as the not to be optimized to have the same broadcast pattern as the
output output
We can change the alloc by a dimshuffle as the elemwise We can change the alloc by a dimshuffle as the elemwise
already have the shape info. The dimshuffle will be faster already have the shape info. The dimshuffle will be faster
to exec to exec
""" """
if not isinstance(node.op, ElemwiseOP): if not isinstance(node.op, ElemwiseOP):
return False return False
if len(node.outputs) > 1: if len(node.outputs) > 1:
# Ensure all outputs have the same broadcast pattern # Ensure all outputs have the same broadcast pattern
# This is a supposition that I'm not sure is always true. # This is a supposition that I'm not sure is always true.
assert all([o.type.broadcastable == assert all([o.type.broadcastable ==
node.outputs[0].type.broadcastable for o in node.outputs[0].type.broadcastable for o in
node.outputs[1:]]) node.outputs[1:]])
# The broadcast pattern of the ouptut must match the broadcast pattern of # The broadcast pattern of the ouptut must match the broadcast
# at least one of the inputs. # pattern of at least one of the inputs.
if not any([i.type.broadcastable == if not any([i.type.broadcastable ==
node.outputs[0].type.broadcastable for i in node.inputs]): node.outputs[0].type.broadcastable for i in node.inputs]):
return False return False
def dimshuffled_alloc(i): def dimshuffled_alloc(i):
return (isinstance(i.owner.op, DimShuffleOP) and return (isinstance(i.owner.op, DimShuffleOP) and
i.owner.inputs[0].owner and i.owner.inputs[0].owner and
isinstance(i.owner.inputs[0].owner.op, AllocOP)) isinstance(i.owner.inputs[0].owner.op, AllocOP))
# At least one input must have an owner that is either a AllocOP or a # At least one input must have an owner that is either a AllocOP or a
# DimShuffleOP with an owner that is a AllocOP -- otherwise there is # DimShuffleOP with an owner that is a AllocOP -- otherwise there is
# nothing to optimize. # nothing to optimize.
if not any([i.owner if not any([i.owner
and (isinstance(i.owner.op, AllocOP) or dimshuffled_alloc(i)) and (isinstance(i.owner.op, AllocOP) or
dimshuffled_alloc(i))
for i in node.inputs]): for i in node.inputs]):
return False return False
## Search for input that we can use as a baseline for the dimensions. # Search for input that we can use as a baseline for the dimensions.
assert_op_idx = -1 assert_op_idx = -1
for idx, i in enumerate(node.inputs): for idx, i in enumerate(node.inputs):
if i.type.broadcastable == node.outputs[0].type.broadcastable: if i.type.broadcastable == node.outputs[0].type.broadcastable:
...@@ -1663,47 +1665,48 @@ def local_elemwise_alloc_op(ElemwiseOP, AllocOP, DimShuffleOP): ...@@ -1663,47 +1665,48 @@ def local_elemwise_alloc_op(ElemwiseOP, AllocOP, DimShuffleOP):
or dimshuffled_alloc(i))): or dimshuffled_alloc(i))):
assert_op_idx = idx assert_op_idx = idx
break break
# It may be the case that only AllocOP and DimShuffleOP of AllocOP exist. # It may be the case that only AllocOP and DimShuffleOP of AllocOP exist.
if assert_op_idx < 0: if assert_op_idx < 0:
# We want to optimize as many allocs as possible. When there is more # We want to optimize as many allocs as possible. When
# than one then do all but one. # there is more than one then do all but one. number of
# number of inputs with alloc or dimshuffle alloc # inputs with alloc or dimshuffle alloc
l2 = [i for i in node.inputs l2 = [i for i in node.inputs
if (i.owner and (isinstance(i.owner.op, AllocOP) if (i.owner and (isinstance(i.owner.op, AllocOP)
or dimshuffled_alloc(i)))] or dimshuffled_alloc(i)))]
# If only 1 alloc or dimshuffle alloc, it is the one we will use for the shape # If only 1 alloc or dimshuffle alloc, it is the one we
# So no alloc would be removed. # will use for the shape. So no alloc would be removed.
if len(l2) > 1: if len(l2) > 1:
# l containt inputs with alloc or dimshuffle alloc only. # l containt inputs with alloc or dimshuffle alloc
# Its length will always be at least one, as we checked that before # only. Its length will always be at least one, as we
# checked that before
l = [idx for idx, i in enumerate(node.inputs) l = [idx for idx, i in enumerate(node.inputs)
if i.type.broadcastable == node.outputs[0].type.broadcastable] if i.broadcastable == node.outputs[0].broadcastable]
assert_op_idx = l[0] # The first one is as good as any to use. assert_op_idx = l[0] # The first one is as good as any to use.
else: else:
# Nothing would be optimized! # Nothing would be optimized!
return False return False
assert_op = node.inputs[assert_op_idx] assert_op = node.inputs[assert_op_idx]
cmp_op = assert_op cmp_op = assert_op
new_i = [] new_i = []
for i in node.inputs: for i in node.inputs:
# Remove alloc # Remove alloc
if (i.owner and isinstance(i.owner.op, AllocOP) if (i.owner and isinstance(i.owner.op, AllocOP)
and i.owner.inputs[0].type != i.owner.outputs[0].type): and i.owner.inputs[0].type != i.owner.outputs[0].type):
# when i.owner.inputs[0].type == i.owner.outputs[0].type we # when i.owner.inputs[0].type == i.owner.outputs[0].type we
# will remove that alloc later # will remove that alloc later
assert i.type.ndim == cmp_op.ndim assert i.type.ndim == cmp_op.ndim
if (theano.config.experimental.local_alloc_elemwise_assert if (theano.config.experimental.local_alloc_elemwise_assert
and not node.fgraph.shape_feature.same_shape(i, cmp_op)): and not node.fgraph.shape_feature.same_shape(i, cmp_op)):
assert_op = assert_(assert_op, assert_op = assert_(assert_op,
*[T.eq(i.shape[idx], cmp_op.shape[idx]) *[T.eq(i.shape[idx], cmp_op.shape[idx])
for idx in xrange(i.type.ndim) for idx in xrange(i.type.ndim)
if not i.type.broadcastable[idx]]) if not i.type.broadcastable[idx]])
new_i.append(i.owner.inputs[0]) new_i.append(i.owner.inputs[0])
# Remove Alloc in DimShuffle # Remove Alloc in DimShuffle
elif i.owner and dimshuffled_alloc(i): elif i.owner and dimshuffled_alloc(i):
assert i.type.ndim == cmp_op.type.ndim assert i.type.ndim == cmp_op.type.ndim
...@@ -1719,22 +1722,23 @@ def local_elemwise_alloc_op(ElemwiseOP, AllocOP, DimShuffleOP): ...@@ -1719,22 +1722,23 @@ def local_elemwise_alloc_op(ElemwiseOP, AllocOP, DimShuffleOP):
# We add a dimshuffle to add them. # We add a dimshuffle to add them.
# We let later optimization merge the multiple dimshuffle # We let later optimization merge the multiple dimshuffle
nb_dim_to_add = i.owner.inputs[0].ndim - alloc_input.ndim nb_dim_to_add = i.owner.inputs[0].ndim - alloc_input.ndim
alloc_input = alloc_input.dimshuffle(['x'] * nb_dim_to_add + alloc_input = alloc_input.dimshuffle(
range(alloc_input.ndim)) ['x'] * nb_dim_to_add +
range(alloc_input.ndim))
# We need to keep the dimshuffle. It could swap axes or # We need to keep the dimshuffle. It could swap axes or
# add dimensions anywhere. # add dimensions anywhere.
new_i.append(i.owner.op(alloc_input)) new_i.append(i.owner.op(alloc_input))
else: else:
new_i.append(i) new_i.append(i)
new_i[assert_op_idx] = assert_op new_i[assert_op_idx] = assert_op
return node.op(*new_i, return_list=True) return node.op(*new_i, return_list=True)
return local_elemwise_alloc return local_elemwise_alloc
#TODO, global optimizer that lift the assert to the beginning of the graph. # TODO, global optimizer that lift the assert to the beginning of the graph.
#TODO, optimize all inputs when possible -- currently when all inputs have # TODO, optimize all inputs when possible -- currently when all inputs have
# an alloc all but one is optimized. # an alloc all but one is optimized.
local_elemwise_alloc = register_specialize( local_elemwise_alloc = register_specialize(
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论