提交 9a694d42 authored 作者: Frederic Bastien's avatar Frederic Bastien

Added optimization of useless elemwise node. Refactored those optimization and…

Added optimization of useless elemwise node. Refactored those optimization and test to put them together.
上级 8bdd7392
...@@ -642,29 +642,33 @@ def local_subtensor_make_vector(node): ...@@ -642,29 +642,33 @@ def local_subtensor_make_vector(node):
_logger.error('failed to index with "%s"' % str(idx)) _logger.error('failed to index with "%s"' % str(idx))
raise raise
#TODO: the other optimization for and, or, xor, le and ge see ticket #496.
@register_canonicalize @register_canonicalize
@register_specialize @register_specialize
@gof.local_optimizer([T.Elemwise]) @gof.local_optimizer([T.Elemwise])
def local_useless_eq(node): def local_useless_elemwise(node):
"""eq(x,x) -> 1
""" """
if isinstance(node.op, T.Elemwise) and node.op.scalar_op == theano.scalar.eq and len(node.inputs)==2: eq(x,x) -> 1
if node.inputs[0]==node.inputs[1]: neq(x,x) -> 0
#it is the same var in the graph. That will always be true mul(x) -> x
return [T.fill(node.inputs[0], T.constant(1.0, dtype=node.outputs[0].type.dtype))] add(x) -> x
@register_canonicalize
@register_specialize
@gof.local_optimizer([T.Elemwise])
def local_useless_neq(node):
"""neq(x,x) -> 0
""" """
if isinstance(node.op, T.Elemwise) and node.op.scalar_op == theano.scalar.neq and len(node.inputs)==2: if isinstance(node.op, T.Elemwise):
if node.inputs[0]==node.inputs[1]: if node.op.scalar_op == theano.scalar.eq and len(node.inputs)==2:
if node.inputs[0]==node.inputs[1]:
#it is the same var in the graph. That will always be true #it is the same var in the graph. That will always be true
return [T.fill(node.inputs[0], T.constant(0.0, dtype=node.outputs[0].type.dtype))] return [T.fill(node.inputs[0], T.constant(1.0, dtype=node.outputs[0].type.dtype))]
if node.op.scalar_op == theano.scalar.neq and len(node.inputs)==2:
if node.inputs[0]==node.inputs[1]:
#it is the same var in the graph. That will always be false
return [T.fill(node.inputs[0], T.constant(0.0, dtype=node.outputs[0].type.dtype))]
if node.op.scalar_op == theano.scalar.mul and len(node.inputs)==1:
return [node.inputs[0]]
if node.op.scalar_op == theano.scalar.add and len(node.inputs)==1:
return [node.inputs[0]]
#TODO: the other optimization for and, or, xor, le and ge see ticket #496.
@register_specialize @register_specialize
@gof.local_optimizer([T.Elemwise]) @gof.local_optimizer([T.Elemwise])
......
...@@ -1480,44 +1480,80 @@ class T_Rebroadcast(unittest.TestCase): ...@@ -1480,44 +1480,80 @@ class T_Rebroadcast(unittest.TestCase):
assert len(rebroadcast_nodes) == 1 assert len(rebroadcast_nodes) == 1
assert rebroadcast_nodes[0].op.axis == {0: True} assert rebroadcast_nodes[0].op.axis == {0: True}
def test_local_useless_eq(): class T_useless_elemwise(unittest.TestCase):
mode = theano.compile.get_default_mode().including('canonicalize') def setUp(self):
x=T.dmatrix() self.mode = theano.compile.get_default_mode().including('canonicalize')
y=T.dmatrix()
f=theano.function([x,y],T.eq(x,y), mode=mode) def test_eq(self):
vx=numpy.random.rand(5,4) x=T.dmatrix()
vy=numpy.random.rand(5,4) y=T.dmatrix()
f(vx,vy) f=theano.function([x,y],T.eq(x,y), mode=self.mode)
topo = f.maker.env.toposort() vx=numpy.random.rand(5,4)
assert len(topo)==1 vy=numpy.random.rand(5,4)
assert isinstance(topo[0].op,T.Elemwise) f(vx,vy)
assert isinstance(topo[0].op.scalar_op,theano.scalar.EQ) topo = f.maker.env.toposort()
f2=theano.function([x],T.eq(x,x), mode=mode) assert len(topo)==1
assert numpy.all(f2(vx)==numpy.ones((5,4))) assert isinstance(topo[0].op,T.Elemwise)
topo2 = f2.maker.env.toposort() assert isinstance(topo[0].op.scalar_op,theano.scalar.EQ)
print topo2 f2=theano.function([x],T.eq(x,x), mode=self.mode)
#Shape_i{1}(<TensorType(float64, matrix)>), Shape_i{0}(<TensorType(float64, matrix)>), Alloc([[1]], Shape_i{0}.0, Shape_i{1}.0 assert numpy.all(f2(vx)==numpy.ones((5,4)))
assert len(topo2)==3 topo2 = f2.maker.env.toposort()
assert isinstance(topo2[-1].op,T.Alloc) print topo2
#Shape_i{1}(<TensorType(float64, matrix)>), Shape_i{0}(<TensorType(float64, matrix)>), Alloc([[1]], Shape_i{0}.0, Shape_i{1}.0
def test_local_useless_neq(): assert len(topo2)==3
mode = theano.compile.get_default_mode().including('canonicalize') assert isinstance(topo2[-1].op,T.Alloc)
x=T.dmatrix()
y=T.dmatrix() def test_neq(self):
f=theano.function([x,y],T.neq(x,y), mode=mode) x=T.dmatrix()
vx=numpy.random.rand(5,4) y=T.dmatrix()
vy=numpy.random.rand(5,4) f=theano.function([x,y],T.neq(x,y), mode=self.mode)
f(vx,vy) vx=numpy.random.rand(5,4)
topo = f.maker.env.toposort() vy=numpy.random.rand(5,4)
assert len(topo)==1 f(vx,vy)
assert isinstance(topo[0].op,T.Elemwise) topo = f.maker.env.toposort()
assert isinstance(topo[0].op.scalar_op,theano.scalar.NEQ) assert len(topo)==1
f2=theano.function([x],T.neq(x,x), mode=mode) assert isinstance(topo[0].op,T.Elemwise)
assert numpy.all(f2(vx)==numpy.zeros((5,4))) assert isinstance(topo[0].op.scalar_op,theano.scalar.NEQ)
topo2 = f2.maker.env.toposort() f2=theano.function([x],T.neq(x,x), mode=self.mode)
print topo2 assert numpy.all(f2(vx)==numpy.zeros((5,4)))
assert len(topo2)==3 topo2 = f2.maker.env.toposort()
assert isinstance(topo2[-1].op,T.Alloc) print topo2
assert len(topo2)==3
assert isinstance(topo2[-1].op,T.Alloc)
def test_mul(self):
x=T.dmatrix()
y=T.dmatrix()
f=theano.function([x],T.mul(x), mode=self.mode)
vx=numpy.random.rand(5,4)
vy=numpy.random.rand(5,4)
f(vx)
topo = f.maker.env.toposort()
assert len(topo)==0
f2=theano.function([x,y],T.mul(x,y), mode=self.mode)
assert numpy.all(f2(vx,vy)==vx*vy)
topo2 = f2.maker.env.toposort()
print topo2
assert len(topo2)==1
assert isinstance(topo2[0].op,T.Elemwise)
assert isinstance(topo2[0].op.scalar_op,theano.scalar.Mul)
def test_add(self):
x=T.dmatrix()
y=T.dmatrix()
f=theano.function([x],T.add(x), mode=self.mode)
vx=numpy.random.rand(5,4)
vy=numpy.random.rand(5,4)
f(vx)
topo = f.maker.env.toposort()
assert len(topo)==0
f2=theano.function([x,y],T.add(x,y), mode=self.mode)
assert numpy.all(f2(vx,vy)==vx+vy)
topo2 = f2.maker.env.toposort()
print topo2
assert len(topo2)==1
assert isinstance(topo2[0].op,T.Elemwise)
assert isinstance(topo2[0].op.scalar_op,theano.scalar.Add)
def test_constant_get_stabilized(): def test_constant_get_stabilized():
""" """
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论