提交 5005974e authored 作者: Razvan Pascanu's avatar Razvan Pascanu

Added a simple optimization for switch ops, if condition is a constant.

上级 5e51a8f2
...@@ -1445,6 +1445,30 @@ def local_join_1(node): ...@@ -1445,6 +1445,30 @@ def local_join_1(node):
############### ###############
# Switch opts # # Switch opts #
############### ###############
@register_canonicalize
@gof.local_optimizer([])
def local_remove_switch_const_cond(node):
"""
This optimization makes the following changes in the graph:
T.switch(cond,left,right) -->
if cond is constant and cond == 0: right
if cond is constant and cond != 0: left
"""
if ( isinstance(node.op, T.Elemwise) and
isinstance(node.op.scalar_op, scalar.basic.Switch)):
cond = T.extract_constant(node.inputs[0])
if type(cond) is numpy.ndarray and cond.ndim == 0:
if cond == 0:
return [node.inputs[2]]
else:
return [node.inputs[1]]
return False
return False
@register_canonicalize @register_canonicalize
@gof.local_optimizer([T.mul]) @gof.local_optimizer([T.mul])
def local_mul_switch_sink(node): def local_mul_switch_sink(node):
......
...@@ -2252,6 +2252,33 @@ class T_local_erfc(unittest.TestCase): ...@@ -2252,6 +2252,33 @@ class T_local_erfc(unittest.TestCase):
print t1-t0,t2-t1 print t1-t0,t2-t1
class test_local_remove_switch_const_cond(unittest.TestCase):
def test_const0(self):
x = theano.tensor.matrix('x', dtype='int64')
y = theano.tensor.matrix('y', dtype='int64')
z = theano.tensor.switch(0, x , y)
f = theano.function([x,y],z)
assert len([x for x in f.maker.env.toposort() if
isinstance(x,theano.tensor.Elemwise) ]) == 0
vx = numpy.array([[1,2,3],[ 4, 5, 6]], dtype='int64')
vy = numpy.array([[7,8,9],[10,11,12]], dtype='int64')
assert numpy.all(f(vx,vy) == vy)
def test_const1(self):
x = theano.tensor.matrix('x', dtype='int64')
y = theano.tensor.matrix('y', dtype='int64')
z = theano.tensor.switch(1, x , y)
f = theano.function([x,y],z)
assert len([x for x in f.maker.env.toposort() if
isinstance(x,theano.tensor.Elemwise) ]) == 0
vx = numpy.array([[1,2,3],[ 4, 5, 6]], dtype='int64')
vy = numpy.array([[7,8,9],[10,11,12]], dtype='int64')
assert numpy.all(f(vx,vy) == vx)
class T_local_sum(unittest.TestCase): class T_local_sum(unittest.TestCase):
def setUp(self): def setUp(self):
self.mode = theano.compile.get_default_mode().including('canonicalize') self.mode = theano.compile.get_default_mode().including('canonicalize')
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论