提交 2b350631 authored 作者: Frederic Bastien's avatar Frederic Bastien

add optimization -max(-x) -> min(x). Add test for min() and the new optimization.

上级 e032bb8f
...@@ -27,23 +27,12 @@ from basic import get_constant_value ...@@ -27,23 +27,12 @@ from basic import get_constant_value
from theano.tensor.opt import register_uncanonicalize from theano.tensor.opt import register_uncanonicalize
from theano import scalar as scal from theano import scalar as scal
@register_uncanonicalize
@gof.local_optimizer([T._shape])
def local_max_and_argmax_specialize(node):
if node.op == T._max_and_argmax:
if len(node.outputs[1].clients)==0:
import pdb;pdb.set_trace()
try:
axis=get_constant_value(node.inputs[1])
except ValueError:
return False
return [CAReduce(scal.maximum,axis)(node.inputs[0]), T.as_tensor_variable(0)]
return False
class MaxAndArgmaxOptimizer(Optimizer): class MaxAndArgmaxOptimizer(Optimizer):
"""Graph optimizer for Fusion of elemwise operations""" """Replace MaxAndArgmax by CAReduce when the argmax is not used
This is faster as MaxAndArgmax don't have c code and execute it
in two pass.
"""
def add_requirements(self, env): def add_requirements(self, env):
env.extend(toolbox.ReplaceValidate()) env.extend(toolbox.ReplaceValidate())
...@@ -73,3 +62,16 @@ class MaxAndArgmaxOptimizer(Optimizer): ...@@ -73,3 +62,16 @@ class MaxAndArgmaxOptimizer(Optimizer):
register_uncanonicalize(MaxAndArgmaxOptimizer(),name='MaxAndArgmaxOptimizer') register_uncanonicalize(MaxAndArgmaxOptimizer(),name='MaxAndArgmaxOptimizer')
@register_uncanonicalize
@gof.local_optimizer([T._shape])
def local_max_to_min(node):
if node.op == T.neg and node.inputs[0].owner:
max = node.inputs[0]
if max.owner and isinstance(max.owner.op, CAReduce) and max.owner.op.scalar_op==scal.maximum:
neg = max.owner.inputs[0]
if neg.owner and neg.owner.op == T.neg:
return [CAReduce(scal.minimum,max.owner.op.axis)(neg.owner.inputs[0])]
return False
import unittest
import numpy
from theano import function,config
import theano.tensor as tensor
#from theano.tensor import matrix,max_and_argmax,MaaxAndArgmax,neg
from theano.tensor.elemwise import CAReduce
from theano.tests import unittest_tools as utt
class T_max_and_argmax(unittest.TestCase):
def test_optimization(self):
#If we use only the max output, we should replace this op with a faster one.
data = numpy.asarray(numpy.random.rand(2,3),dtype=config.floatX)
n = tensor.matrix()
f = function([n], tensor.max_and_argmax(n,0)[0])
topo = f.maker.env.toposort()
assert len(topo)==1
assert isinstance(topo[0].op, CAReduce)
f = function([n], tensor.max_and_argmax(n,0))
topo = f.maker.env.toposort()
assert len(topo)==1
assert isinstance(topo[0].op, tensor.MaxAndArgmax)
class T_min_max(unittest.TestCase):
def setUp(self):
utt.seed_rng()
def test_optimization_max(self):
data = numpy.asarray(numpy.random.rand(2,3),dtype=config.floatX)
n = tensor.matrix()
f = function([n],tensor.max(n,0))
topo = f.maker.env.toposort()
assert len(topo)==1
assert isinstance(topo[0].op,CAReduce)
f(data)
f = function([n],tensor.max(-n,0))
topo = f.maker.env.toposort()
assert len(topo)==2
assert topo[0].op==tensor.neg
assert isinstance(topo[1].op,CAReduce)
f(data)
f = function([n],-tensor.max(n,0))
topo = f.maker.env.toposort()
assert len(topo)==2
assert isinstance(topo[0].op,CAReduce)
assert topo[1].op==tensor.neg
f(data)
f = function([n],-tensor.max(-n,0))
topo = f.maker.env.toposort()
assert len(topo)==1
assert isinstance(topo[0].op,CAReduce)#min
f(data)
def test_optimization_min(self):
data = numpy.asarray(numpy.random.rand(2,3),dtype=config.floatX)
n = tensor.matrix()
f = function([n],tensor.min(n,0))
topo = f.maker.env.toposort()
assert len(topo)==1
assert isinstance(topo[0].op,CAReduce)
f(data)
#test variant with neg to make sure we optimize correctly
f = function([n],tensor.min(-n,0))
topo = f.maker.env.toposort()
assert len(topo)==2
assert isinstance(topo[0].op,CAReduce)#max
assert topo[1].op==tensor.neg
f(data)
f = function([n],-tensor.min(n,0))
topo = f.maker.env.toposort()
assert len(topo)==2
assert topo[0].op==tensor.neg
assert isinstance(topo[1].op,CAReduce)#max
f(data)
f = function([n],-tensor.min(-n,0))
topo = f.maker.env.toposort()
assert len(topo)==1
assert isinstance(topo[0].op,CAReduce)#max
f(data)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论