提交 b811b4ab authored 作者: f0k's avatar f0k

More efficient implementation of switch merge optimizer (as suggested by nouiz)

上级 842dcea8
......@@ -3806,27 +3806,30 @@ def local_div_switch_sink(node):
return False
# Merge add/sub/mul/div/minimum/maximum of two switches sharing the same
# Merge add/sub/mul/div/minimum/maximum/... of switches sharing the same
# condition, to enable further simplification of their branches
# Example: switch(c, a, b) + switch(c, x, y) -> switch(c, a+x, b+y)
# TODO: generalize to all elemwise Ops? generalize to Ops with 3+ inputs?
# For each binary tensor op named below, build a PatternSub that rewrites
#     op(switch(c, a1, b1), switch(c, a2, b2))
# into
#     switch(c, op(a1, a2), op(b1, b2))
# and hand it to register_canonicalize under the 'fast_run' tag.
for _op_name in (
        'add', 'sub', 'mul', 'true_div', 'int_div', 'floor_div',
        'minimum', 'maximum',
        'gt', 'lt', 'ge', 'le', 'eq', 'neq',
        'and_', 'or_', 'xor',
        'bitwise_and', 'bitwise_or', 'bitwise_xor',
        'pow'):
    # Look the op up by name on the tensor module.
    _op = getattr(T, _op_name)
    # The same string names both the PatternSub and its registration.
    _opt_name = 'Merge %s of switch with same condition' % _op_name
    _opt = gof.PatternSub(
        name=_opt_name,
        allow_multiple_clients=True,
        in_pattern=(
            _op,
            (T.switch, 'c', 'a1', 'b1'),
            (T.switch, 'c', 'a2', 'b2'),
        ),
        out_pattern=(
            T.switch,
            'c',
            (_op, 'a1', 'a2'),
            (_op, 'b1', 'b2'),
        ))
    register_canonicalize(_opt, 'fast_run', name=_opt_name)
# The loop names are scratch variables only; remove them from the module
# namespace once every rewrite has been registered.
del _op_name, _op, _opt_name, _opt
@register_canonicalize
@gof.local_optimizer([T.Elemwise])
def local_merge_switch_same_cond(node):
    """
    Pull a shared switch out of an elemwise op whose inputs are all switches.

    Example: switch(c, a, b) + switch(c, x, y) -> switch(c, a+x, b+y)

    The rewrite applies only when all of the following hold:

    - ``node.op`` is an Elemwise whose scalar op is a BinaryScalarOp,
      an Add or a Mul;
    - every input of ``node`` is produced by an Elemwise Switch;
    - all those switches use the very same condition variable
      (compared by identity, not by value).

    Returns a one-element list holding the replacement variable, or None
    when the pattern does not match.
    """
    scal = theano.scalar
    outer_op = node.op
    mergeable_scalar_ops = (scal.BinaryScalarOp, scal.Add, scal.Mul)

    # Guard 1: the node itself must be a binary elemwise op, or add/mul
    # (which may take more than two inputs).
    if not (isinstance(outer_op, T.Elemwise) and
            isinstance(outer_op.scalar_op, mergeable_scalar_ops)):
        return None

    # Guard 2: every input must be the output of an elemwise switch.
    # Collect the producing apply nodes while checking.
    switch_nodes = []
    for inp in node.inputs:
        producer = inp.owner
        if not (producer and
                isinstance(producer.op, T.Elemwise) and
                isinstance(producer.op.scalar_op, scal.Switch)):
            return None
        switch_nodes.append(producer)

    # Guard 3: all switches must share the condition of the first one.
    cond = switch_nodes[0].inputs[0]
    for producer in switch_nodes[1:]:
        if producer.inputs[0] is not cond:
            return None

    # Apply the outer op separately to the "true" branches and to the
    # "false" branches, then select between the two results with a single
    # switch on the shared condition.
    merged_true = outer_op(*[producer.inputs[1] for producer in switch_nodes])
    merged_false = outer_op(*[producer.inputs[2] for producer in switch_nodes])
    return [T.switch(cond, merged_true, merged_false)]
#############
......
......@@ -4914,7 +4914,7 @@ class test_local_useless_switch(unittest.TestCase):
isinstance(node.op, theano.tensor.Elemwise)]) == 0
class test_merge_switch_same_cond(unittest.TestCase):
class test_local_merge_switch_same_cond(unittest.TestCase):
def test_elemwise(self):
# float Ops
mats = theano.tensor.matrices('cabxy')
......@@ -4935,6 +4935,12 @@ class test_merge_switch_same_cond(unittest.TestCase):
T.bitwise_and, T.bitwise_or, T.bitwise_xor):
g = optimize(FunctionGraph(mats, [op(s1, s2)]))
assert str(g).count('Switch') == 1
# add/mul with more than two inputs
u, v = theano.tensor.matrices('uv')
s3 = T.switch(c, u, v)
for op in (T.add, T.mul):
g = optimize(FunctionGraph(mats + [u, v], [op(s1, s2, s3)]))
assert str(g).count('Switch') == 1
class T_local_sum_prod(unittest.TestCase):
......
Markdown 格式
0%
您即将添加 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论