提交 aa49be05 authored 作者: Frederic's avatar Frederic 提交者: Amjad Almahairi

remove many useless elemwise.

They were discovered in the outer graph of scan.
上级 5ecbbde2
...@@ -3202,15 +3202,19 @@ def local_join_make_vector(node): ...@@ -3202,15 +3202,19 @@ def local_join_make_vector(node):
# Switch opts # # Switch opts #
############### ###############
@register_canonicalize @register_canonicalize('fast_compile')
@register_specialize
@gof.local_optimizer([T.Elemwise]) @gof.local_optimizer([T.Elemwise])
def local_remove_switch_const_cond(node): def local_useless_switch(node):
""" """
This optimization makes the following changes in the graph: This optimization makes the following changes in the graph:
T.switch(cond,left,right) --> T.switch(cond,left,right) -->
if cond is constant and cond == 0: right if cond is constant and cond == 0: right
if cond is constant and cond != 0: left if cond is constant and cond != 0: left
if left is right -> left
if left equal right -> left
T.switch(le(shape_i{id}(X), 0), 0, shape_i{id}(X)) -> shape_i{id}(X)
""" """
if (isinstance(node.op, T.Elemwise) and if (isinstance(node.op, T.Elemwise) and
isinstance(node.op.scalar_op, scalar.basic.Switch)): isinstance(node.op.scalar_op, scalar.basic.Switch)):
...@@ -3231,9 +3235,52 @@ def local_remove_switch_const_cond(node): ...@@ -3231,9 +3235,52 @@ def local_remove_switch_const_cond(node):
out = T.alloc(out, *[node.outputs[0].shape[i] for i out = T.alloc(out, *[node.outputs[0].shape[i] for i
in xrange(out.ndim)]) in xrange(out.ndim)])
return [out] return [out]
# if left is right -> left
if node.inputs[1] is node.inputs[2]:
return [node.inputs[1]]
# if left equal right -> left
if (T.extract_constant(node.inputs[1]) ==
T.extract_constant(node.inputs[2])):
if node.inputs[1].type == node.outputs[0].type:
return [node.inputs[1]]
if node.inputs[2].type == node.outputs[0].type:
return [node.inputs[2]]
# This case happen with scan.
# Elemwise{switch}(le(shape_i{id}(X), 0), 0, shape_i{id}(X)) -> shape_i{id}(X)
left = node.inputs[1]
right = node.inputs[2]
if (cond.owner and
isinstance(cond.owner.op, T.Elemwise) and
isinstance(cond.owner.op.scalar_op, scalar.LE) and
cond.owner.inputs[0].owner and
isinstance(cond.owner.inputs[0].owner.op, Shape_i) and
T.extract_constant(cond.owner.inputs[1]) == 0 and
T.extract_constant(left) == 0 and
right is cond.owner.inputs[0]):
assert right.type == node.outputs[0].type
return [right]
return False return False
return False return False
local_remove_switch_const_cond = local_useless_switch
#@register_canonicalize
#@register_specialize
@gof.local_optimizer([Shape_i])
def local_shape_i_infered(node):
    """
    Replace a Shape_i node by a constant when its value is already known.

    The shape feature attached to the node's fgraph is asked for the
    symbolic shape of the node's input along dimension ``node.op.i``.
    If that shape is a scalar constant, a constant of the same dtype as
    the node's output is returned as the replacement.

    Returns None (no replacement) when the node is not a Shape_i, when
    the node has no ``fgraph`` attribute, when the fgraph has no
    ``shape_feature`` attribute, or when the shape is not a scalar
    constant.
    """
    if not isinstance(node.op, Shape_i):
        return
    if not hasattr(node, 'fgraph'):
        return
    if not hasattr(node.fgraph, 'shape_feature'):
        return
    shp = node.fgraph.shape_feature.shape_of[node.inputs[0]][node.op.i]
    # Only the constant extraction is expected to raise
    # NotScalarConstantError, so keep the try body minimal.
    try:
        c = get_scalar_constant_value(shp)
    except NotScalarConstantError:
        return
    return [T.constant(c, dtype=node.outputs[0].dtype)]
@register_canonicalize @register_canonicalize
...@@ -4132,6 +4179,110 @@ def local_elemwise_sub_zeros(node): ...@@ -4132,6 +4179,110 @@ def local_elemwise_sub_zeros(node):
return [T.zeros_like(node.inputs[0])] return [T.zeros_like(node.inputs[0])]
def _is_shape_i(var):
    """Return True if `var` is the output of a Shape_i op."""
    return isinstance(var.owner and var.owner.op, Shape_i)


def _is_add_of_shapes(var):
    """Return True if `var` is an Elemwise add whose inputs are all Shape_i."""
    return bool(var.owner and
                isinstance(var.owner.op, T.Elemwise) and
                isinstance(var.owner.op.scalar_op, scalar.Add) and
                all(_is_shape_i(inp) for inp in var.owner.inputs))


@register_specialize
@register_stabilize
@register_canonicalize
@gof.local_optimizer([T.Elemwise])
def local_useless_elemwise_comparison(node):
    """
    Remove binary elemwise comparisons whose result is known statically.

    :note: These cases appear in the graph generated around scan. This
        doesn't remove much computation, but makes the graph easier to
        read.

    # Comparing to itself is constant
    Elemwise[{LT,GT}](X, X) -> Elemwise[zeros](X)
    Elemwise[{LE,GE}](X, X) -> Elemwise[ones](X)
    Elemwise[{minimum,maximum}](X, X) -> X

    # Comparing shape to 0 can be constant
    Elemwise[LT](X.shape[i], 0) -> Elemwise[zeros](X)
    Elemwise[GE](X.shape[i], 0) -> Elemwise[ones](X)
    Elemwise[maximum](X.shape[i], 0) -> X.shape[i]
    Elemwise[maximum](0, X.shape[i]) -> X.shape[i]
    Elemwise[minimum](X.shape[i], 0) -> 0
    Elemwise[minimum](0, X.shape[i]) -> 0

    # The shape can be replaced with sum of shapes
    Elemwise[LT](add([anything that is shapes]), 0) -> Elemwise[zeros](X)
    Elemwise[GE](add([anything that is shapes]), 0) -> Elemwise[ones](X)

    Returns a one-element list holding the replacement variable, or
    None when no rewrite applies.
    """
    if not isinstance(node.op, T.Elemwise):
        return
    scalar_op = node.op.scalar_op
    if scalar_op.nin != 2:
        return

    lhs = node.inputs[0]
    rhs = node.inputs[1]
    out = node.outputs[0]

    # --- Both operands are the very same variable -----------------------
    if lhs is rhs:
        # Elemwise[{LT,GT}](X, X) -> Elemwise[zeros](X)
        if isinstance(scalar_op, (scalar.LT, scalar.GT)):
            return [T.zeros_like(out)]
        # Elemwise[{LE,GE}](X, X) -> Elemwise[ones](X)
        # The result must carry the dtype of the comparison output, not
        # the dtype of the compared operand.
        if isinstance(scalar_op, (scalar.LE, scalar.GE)):
            return [T.ones_like(lhs, dtype=out.dtype)]
        # Elemwise[{minimum,maximum}](X, X) -> X
        if isinstance(scalar_op, (scalar.Minimum, scalar.Maximum)):
            return [lhs]

    # --- A single shape compared with the constant 0 --------------------
    # Elemwise[LT](X.shape[i], 0) -> Elemwise[zeros](X)
    if (isinstance(scalar_op, scalar.LT) and
            _is_shape_i(lhs) and
            T.extract_constant(rhs) == 0):
        return [T.zeros_like(out)]
    # Elemwise[GE](X.shape[i], 0) -> Elemwise[ones](X)
    if (isinstance(scalar_op, scalar.GE) and
            _is_shape_i(lhs) and
            T.extract_constant(rhs) == 0):
        return [T.ones_like(out)]
    # Elemwise[maximum](X.shape[i], 0) -> X.shape[i]
    if (isinstance(scalar_op, scalar.Maximum) and
            _is_shape_i(lhs) and
            T.extract_constant(rhs) == 0):
        return [lhs]
    # Elemwise[maximum](0, X.shape[i]) -> X.shape[i]
    if (isinstance(scalar_op, scalar.Maximum) and
            T.extract_constant(lhs) == 0 and
            _is_shape_i(rhs)):
        return [rhs]
    # Elemwise[minimum](X.shape[i], 0) -> 0
    if (isinstance(scalar_op, scalar.Minimum) and
            _is_shape_i(lhs) and
            T.extract_constant(rhs) == 0):
        return [T.zeros_like(out)]
    # Elemwise[minimum](0, X.shape[i]) -> 0
    if (isinstance(scalar_op, scalar.Minimum) and
            T.extract_constant(lhs) == 0 and
            _is_shape_i(rhs)):
        return [T.zeros_like(out)]

    # --- A sum of shapes compared with the constant 0 -------------------
    # The owner op is an Elemwise; the Add is its scalar_op. Testing the
    # Elemwise op itself against scalar.Add can never succeed.
    # Elemwise[LT](add([anything that is shapes]), 0) -> Elemwise[zeros](X)
    if (isinstance(scalar_op, scalar.LT) and
            _is_add_of_shapes(lhs) and
            T.extract_constant(rhs) == 0):
        return [T.zeros_like(lhs, dtype=out.dtype)]
    # Elemwise[GE](add([anything that is shapes]), 0) -> Elemwise[ones](X)
    if (isinstance(scalar_op, scalar.GE) and
            _is_add_of_shapes(lhs) and
            T.extract_constant(rhs) == 0):
        return [T.ones_like(lhs, dtype=out.dtype)]
    return
@register_canonicalize @register_canonicalize
@register_specialize @register_specialize
@gof.local_optimizer([T.Sum, T.elemwise.Prod]) @gof.local_optimizer([T.Sum, T.elemwise.Prod])
......
...@@ -3136,6 +3136,22 @@ def test_local_fill_useless(): ...@@ -3136,6 +3136,22 @@ def test_local_fill_useless():
f(m_, x_) f(m_, x_)
def test_local_useless_elemwise_comparison():
    # TODO: test each case individually.
    # The graph built here (a scan over the rows of a matrix, added to a
    # vector) is the one in which the useless comparisons were noticed.
    mat = T.matrix('X')
    vec = T.vector('Y')
    row_sums, _updates = theano.scan(
        fn=lambda x: x.sum(),
        outputs_info=None,
        sequences=[mat],
        non_sequences=None)
    total = row_sums + vec
    # Print the graph before compilation.
    theano.printing.debugprint(total)
    # Compile with the default mode minus 'fusion', then print the
    # compiled graph with types.
    no_fusion_mode = theano.compile.get_default_mode().excluding('fusion')
    compiled = theano.function([mat, vec], total, mode=no_fusion_mode)
    theano.printing.debugprint(compiled, print_type=True)
class Test_local_useless_alloc(unittest.TestCase): class Test_local_useless_alloc(unittest.TestCase):
def setUp(self): def setUp(self):
self.rng = numpy.random.RandomState(utt.fetch_seed()) self.rng = numpy.random.RandomState(utt.fetch_seed())
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论