提交 842dcea8 authored 作者: f0k's avatar f0k

Add optimizer merging two switches of same condition

上级 e521b20e
...@@ -3806,6 +3806,29 @@ def local_div_switch_sink(node): ...@@ -3806,6 +3806,29 @@ def local_div_switch_sink(node):
return False return False
# Merge add/sub/mul/div/minimum/maximum of two switches sharing the same
# condition, to enable further simplification of their branches
# Example: switch(c, a, b) + switch(c, x, y) -> switch(c, a+x, b+y)
# TODO: generalize to all elemwise Ops? generalize to Ops with 3+ inputs?
for _op_name in ('add', 'sub', 'mul', 'true_div', 'int_div', 'floor_div',
'minimum', 'maximum', 'gt', 'lt', 'ge', 'le', 'eq', 'neq',
'and_', 'or_', 'xor',
'bitwise_and', 'bitwise_or', 'bitwise_xor', 'pow'):
_op = getattr(T, _op_name)
_opt_name = 'Merge %s of switch with same condition' % _op_name
_opt = gof.PatternSub(
in_pattern=(_op,
(T.switch, 'c', 'a1', 'b1'),
(T.switch, 'c', 'a2', 'b2')),
out_pattern=(T.switch, 'c',
(_op, 'a1', 'a2'),
(_op, 'b1', 'b2')),
name=_opt_name,
allow_multiple_clients=True)
register_canonicalize(_opt, 'fast_run', name=_opt_name)
del _op_name, _op, _opt_name, _opt
############# #############
# Tile Opts # # Tile Opts #
############# #############
......
...@@ -4914,6 +4914,29 @@ class test_local_useless_switch(unittest.TestCase): ...@@ -4914,6 +4914,29 @@ class test_local_useless_switch(unittest.TestCase):
isinstance(node.op, theano.tensor.Elemwise)]) == 0 isinstance(node.op, theano.tensor.Elemwise)]) == 0
class test_merge_switch_same_cond(unittest.TestCase):
def test_elemwise(self):
# float Ops
mats = theano.tensor.matrices('cabxy')
c, a, b, x, y = mats
s1 = T.switch(c, a, b)
s2 = T.switch(c, x, y)
for op in (T.add, T.sub, T.mul, T.true_div, T.int_div, T.floor_div,
T.minimum, T.maximum, T.gt, T.lt, T.ge, T.le, T.eq, T.neq,
T.pow):
g = optimize(FunctionGraph(mats, [op(s1, s2)]))
assert str(g).count('Switch') == 1
# integer Ops
mats = theano.tensor.imatrices('cabxy')
c, a, b, x, y = mats
s1 = T.switch(c, a, b)
s2 = T.switch(c, x, y)
for op in (T.and_, T.or_, T.xor,
T.bitwise_and, T.bitwise_or, T.bitwise_xor):
g = optimize(FunctionGraph(mats, [op(s1, s2)]))
assert str(g).count('Switch') == 1
class T_local_sum_prod(unittest.TestCase): class T_local_sum_prod(unittest.TestCase):
""" """
Test sum/prod opts in opt.py Test sum/prod opts in opt.py
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论