提交 677cbcb8 authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Merge pull request #3758 from f0k/merge-switch-opt

Add optimizer merging two switches of same condition
...@@ -3828,6 +3828,32 @@ def local_div_switch_sink(node): ...@@ -3828,6 +3828,32 @@ def local_div_switch_sink(node):
return False return False
# Merge add/sub/mul/div/minimum/maximum/... of switches sharing the same
# condition, to enable further simplification of their branches
# Example: switch(c, a, b) + switch(c, x, y) -> switch(c, a+x, b+y)
@register_canonicalize
@gof.local_optimizer([T.Elemwise])
def local_merge_switch_same_cond(node):
scal = theano.scalar
# node must be binary elemwise or add or mul
if not isinstance(node.op, T.Elemwise) or not isinstance(
node.op.scalar_op, (scal.BinaryScalarOp, scal.Add, scal.Mul)):
return
# all inputs must be switch
if not all(s.owner and isinstance(s.owner.op, T.Elemwise) and
isinstance(s.owner.op.scalar_op, scal.Switch)
for s in node.inputs):
return
# all switch conditions must be the same
cond = node.inputs[0].owner.inputs[0]
if not all(s.owner.inputs[0] is cond for s in node.inputs[1:]):
return
# pull out switch
return [T.switch(cond,
node.op(*[s.owner.inputs[1] for s in node.inputs]),
node.op(*[s.owner.inputs[2] for s in node.inputs]))]
############# #############
# Tile Opts # # Tile Opts #
############# #############
......
...@@ -4912,6 +4912,35 @@ class test_local_useless_switch(unittest.TestCase): ...@@ -4912,6 +4912,35 @@ class test_local_useless_switch(unittest.TestCase):
isinstance(node.op, theano.tensor.Elemwise)]) == 0 isinstance(node.op, theano.tensor.Elemwise)]) == 0
class test_local_merge_switch_same_cond(unittest.TestCase):
def test_elemwise(self):
# float Ops
mats = theano.tensor.matrices('cabxy')
c, a, b, x, y = mats
s1 = T.switch(c, a, b)
s2 = T.switch(c, x, y)
for op in (T.add, T.sub, T.mul, T.true_div, T.int_div, T.floor_div,
T.minimum, T.maximum, T.gt, T.lt, T.ge, T.le, T.eq, T.neq,
T.pow):
g = optimize(FunctionGraph(mats, [op(s1, s2)]))
assert str(g).count('Switch') == 1
# integer Ops
mats = theano.tensor.imatrices('cabxy')
c, a, b, x, y = mats
s1 = T.switch(c, a, b)
s2 = T.switch(c, x, y)
for op in (T.and_, T.or_, T.xor,
T.bitwise_and, T.bitwise_or, T.bitwise_xor):
g = optimize(FunctionGraph(mats, [op(s1, s2)]))
assert str(g).count('Switch') == 1
# add/mul with more than two inputs
u, v = theano.tensor.matrices('uv')
s3 = T.switch(c, u, v)
for op in (T.add, T.mul):
g = optimize(FunctionGraph(mats + [u, v], [op(s1, s2, s3)]))
assert str(g).count('Switch') == 1
class T_local_sum_prod(unittest.TestCase): class T_local_sum_prod(unittest.TestCase):
""" """
Test sum/prod opts in opt.py Test sum/prod opts in opt.py
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论