Commit b811b4ab, authored by f0k

More efficient implementation of switch merge optimizer (as suggested by nouiz)

上级 842dcea8
...@@ -3806,27 +3806,30 @@ def local_div_switch_sink(node): ...@@ -3806,27 +3806,30 @@ def local_div_switch_sink(node):
return False return False
# Merge add/sub/mul/div/minimum/maximum/... of switches sharing the same
# condition, to enable further simplification of their branches
# Example: switch(c, a, b) + switch(c, x, y) -> switch(c, a+x, b+y)
@register_canonicalize
@gof.local_optimizer([T.Elemwise])
def local_merge_switch_same_cond(node):
    """Pull a shared switch condition out of an elemwise op.

    Rewrites ``op(switch(c, a, b), switch(c, x, y), ...)`` into
    ``switch(c, op(a, x, ...), op(b, y, ...))`` so that each branch can be
    simplified independently by later optimizations. Handles any binary
    elemwise scalar op, plus the n-ary Add and Mul.
    """
    scal = theano.scalar
    op = node.op
    # Restrict to binary scalar ops or the variadic Add/Mul.
    if not isinstance(op, T.Elemwise):
        return
    if not isinstance(op.scalar_op, (scal.BinaryScalarOp, scal.Add, scal.Mul)):
        return

    def _is_switch(var):
        # True iff `var` is the output of an elemwise Switch apply node.
        owner = var.owner
        return (owner and isinstance(owner.op, T.Elemwise) and
                isinstance(owner.op.scalar_op, scal.Switch))

    # Every input to the op must itself be a switch ...
    if not all(_is_switch(inp) for inp in node.inputs):
        return
    # ... and all of those switches must test the very same condition
    # variable (identity, not mere equivalence).
    cond = node.inputs[0].owner.inputs[0]
    for inp in node.inputs[1:]:
        if inp.owner.inputs[0] is not cond:
            return
    # Apply the op branch-wise and wrap the result in a single switch.
    then_branch = op(*[inp.owner.inputs[1] for inp in node.inputs])
    else_branch = op(*[inp.owner.inputs[2] for inp in node.inputs])
    return [T.switch(cond, then_branch, else_branch)]
############# #############
......
...@@ -4914,7 +4914,7 @@ class test_local_useless_switch(unittest.TestCase): ...@@ -4914,7 +4914,7 @@ class test_local_useless_switch(unittest.TestCase):
isinstance(node.op, theano.tensor.Elemwise)]) == 0 isinstance(node.op, theano.tensor.Elemwise)]) == 0
class test_merge_switch_same_cond(unittest.TestCase): class test_local_merge_switch_same_cond(unittest.TestCase):
def test_elemwise(self): def test_elemwise(self):
# float Ops # float Ops
mats = theano.tensor.matrices('cabxy') mats = theano.tensor.matrices('cabxy')
...@@ -4935,6 +4935,12 @@ class test_merge_switch_same_cond(unittest.TestCase): ...@@ -4935,6 +4935,12 @@ class test_merge_switch_same_cond(unittest.TestCase):
T.bitwise_and, T.bitwise_or, T.bitwise_xor): T.bitwise_and, T.bitwise_or, T.bitwise_xor):
g = optimize(FunctionGraph(mats, [op(s1, s2)])) g = optimize(FunctionGraph(mats, [op(s1, s2)]))
assert str(g).count('Switch') == 1 assert str(g).count('Switch') == 1
# add/mul with more than two inputs
u, v = theano.tensor.matrices('uv')
s3 = T.switch(c, u, v)
for op in (T.add, T.mul):
g = optimize(FunctionGraph(mats + [u, v], [op(s1, s2, s3)]))
assert str(g).count('Switch') == 1
class T_local_sum_prod(unittest.TestCase): class T_local_sum_prod(unittest.TestCase):
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论