提交 f251552e 作者: abergeron

Merge pull request #2371 from daemonmaker/local_alloc_elemwise2

Parameterized local_elemwise_alloc_opt for GPU support
...@@ -1606,9 +1606,8 @@ compile.optdb['specialize'].register('local_remove_all_assert', ...@@ -1606,9 +1606,8 @@ compile.optdb['specialize'].register('local_remove_all_assert',
local_remove_all_assert, local_remove_all_assert,
use_db_name_as_tag=False) use_db_name_as_tag=False)
@register_specialize("local_alloc_elemwise") def local_elemwise_alloc_op(ElemwiseOP, AllocOP, DimShuffleOP):
@gof.local_optimizer([T.Elemwise]) def local_elemwise_alloc(node):
def local_elemwise_alloc(node):
""" """
elemwise(alloc(x, shp), ..., y.TensorType(BROADCAST CONDITION)) elemwise(alloc(x, shp), ..., y.TensorType(BROADCAST CONDITION))
-> elemwise(x, y.TensorType(BROADCAST CONDITION)) -> elemwise(x, y.TensorType(BROADCAST CONDITION))
...@@ -1624,7 +1623,7 @@ def local_elemwise_alloc(node): ...@@ -1624,7 +1623,7 @@ def local_elemwise_alloc(node):
already have the shape info. The dimshuffle will be faster already have the shape info. The dimshuffle will be faster
to exec to exec
""" """
if not isinstance(node.op, T.Elemwise): if not isinstance(node.op, ElemwiseOP):
return False return False
if len(node.outputs) > 1: if len(node.outputs) > 1:
...@@ -1641,15 +1640,15 @@ def local_elemwise_alloc(node): ...@@ -1641,15 +1640,15 @@ def local_elemwise_alloc(node):
return False return False
def dimshuffled_alloc(i): def dimshuffled_alloc(i):
return (isinstance(i.owner.op, T.DimShuffle) and return (isinstance(i.owner.op, DimShuffleOP) and
i.owner.inputs[0].owner and \ i.owner.inputs[0].owner and
isinstance(i.owner.inputs[0].owner.op, T.Alloc)) isinstance(i.owner.inputs[0].owner.op, AllocOP))
# At least one input must have an owner that is either a T.Alloc or a # At least one input must have an owner that is either a AllocOP or a
# T.DimShuffle with an owner that is a T.Alloc -- otherwise there is # DimShuffleOP with an owner that is a AllocOP -- otherwise there is
# nothing to optimize. # nothing to optimize.
if not any([i.owner if not any([i.owner
and (isinstance(i.owner.op, T.Alloc) or dimshuffled_alloc(i)) and (isinstance(i.owner.op, AllocOP) or dimshuffled_alloc(i))
for i in node.inputs]): for i in node.inputs]):
return False return False
...@@ -1657,21 +1656,21 @@ def local_elemwise_alloc(node): ...@@ -1657,21 +1656,21 @@ def local_elemwise_alloc(node):
assert_op_idx = -1 assert_op_idx = -1
for idx, i in enumerate(node.inputs): for idx, i in enumerate(node.inputs):
if i.type.broadcastable == node.outputs[0].type.broadcastable: if i.type.broadcastable == node.outputs[0].type.broadcastable:
# Prefer an input that is not a T.Alloc nor a T.DimShuffle of a # Prefer an input that is not a AllocOP nor a DimShuffleOP of a
# T.Alloc so that all allocs can be optimized. # AllocOP so that all allocs can be optimized.
if not (i.owner if not (i.owner
and (isinstance(i.owner.op, T.Alloc) and (isinstance(i.owner.op, AllocOP)
or dimshuffled_alloc(i))): or dimshuffled_alloc(i))):
assert_op_idx = idx assert_op_idx = idx
break break
# It may be the case that only T.Allocs and T.DimShuffle of T.Allocs exist. # It may be the case that only AllocOP and DimShuffleOP of AllocOP exist.
if assert_op_idx < 0: if assert_op_idx < 0:
# We want to optimize as many allocs as possible. When there is more # We want to optimize as many allocs as possible. When there is more
# than one then do all but one. # than one then do all but one.
# number of inputs with alloc or dimshuffle alloc # number of inputs with alloc or dimshuffle alloc
l2 = [i for i in node.inputs l2 = [i for i in node.inputs
if (i.owner and (isinstance(i.owner.op, T.Alloc) if (i.owner and (isinstance(i.owner.op, AllocOP)
or dimshuffled_alloc(i)))] or dimshuffled_alloc(i)))]
# If only 1 alloc or dimshuffle alloc, it is the one we will use for the shape # If only 1 alloc or dimshuffle alloc, it is the one we will use for the shape
# So no alloc would be removed. # So no alloc would be removed.
...@@ -1691,7 +1690,7 @@ def local_elemwise_alloc(node): ...@@ -1691,7 +1690,7 @@ def local_elemwise_alloc(node):
for i in node.inputs: for i in node.inputs:
# Remove alloc # Remove alloc
if (i.owner and isinstance(i.owner.op, T.Alloc) if (i.owner and isinstance(i.owner.op, AllocOP)
and i.owner.inputs[0].type != i.owner.outputs[0].type): and i.owner.inputs[0].type != i.owner.outputs[0].type):
# when i.owner.inputs[0].type == i.owner.outputs[0].type we # when i.owner.inputs[0].type == i.owner.outputs[0].type we
# will remove that alloc later # will remove that alloc later
...@@ -1700,8 +1699,8 @@ def local_elemwise_alloc(node): ...@@ -1700,8 +1699,8 @@ def local_elemwise_alloc(node):
if (theano.config.experimental.local_alloc_elemwise_assert if (theano.config.experimental.local_alloc_elemwise_assert
and not node.fgraph.shape_feature.same_shape(i, cmp_op)): and not node.fgraph.shape_feature.same_shape(i, cmp_op)):
assert_op = assert_(assert_op, assert_op = assert_(assert_op,
*[T.eq(i.shape[idx], cmp_op.shape[idx])\ *[T.eq(i.shape[idx], cmp_op.shape[idx])
for idx in xrange(i.type.ndim) \ for idx in xrange(i.type.ndim)
if not i.type.broadcastable[idx]]) if not i.type.broadcastable[idx]])
new_i.append(i.owner.inputs[0]) new_i.append(i.owner.inputs[0])
...@@ -1732,10 +1731,16 @@ def local_elemwise_alloc(node): ...@@ -1732,10 +1731,16 @@ def local_elemwise_alloc(node):
return node.op(*new_i, return_list=True) return node.op(*new_i, return_list=True)
return local_elemwise_alloc
#TODO, global optimizer that lift the assert to the beginning of the graph. #TODO, global optimizer that lift the assert to the beginning of the graph.
#TODO, optimize all inputs when possible -- currently when all inputs have #TODO, optimize all inputs when possible -- currently when all inputs have
# an alloc all but one is optimized. # an alloc all but one is optimized.
# Call the local_elemwise_alloc_op factory with the T.Elemwise, T.Alloc and
# T.DimShuffle op classes, wrap the function it returns with
# gof.local_optimizer([T.Elemwise]), and pass the result through
# register_specialize. The outcome is bound to the module-level name
# local_elemwise_alloc.
local_elemwise_alloc = register_specialize(gof.local_optimizer([T.Elemwise])(
local_elemwise_alloc_op(T.Elemwise, T.Alloc, T.DimShuffle)
))
theano.configparser.AddConfigVar('experimental.local_alloc_elemwise', theano.configparser.AddConfigVar('experimental.local_alloc_elemwise',
"DEPRECATED: If True, enable the experimental" "DEPRECATED: If True, enable the experimental"
" optimization local_alloc_elemwise." " optimization local_alloc_elemwise."
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论