提交 78c2c35a authored 作者: Frederic Bastien's avatar Frederic Bastien

add a MaxAndArgmax optimization when argmax is not used.

上级 59654ff9
......@@ -4,6 +4,7 @@ __docformat__ = "restructuredtext en"
from basic import *
import opt
import opt_uncanonicalize
import blas
import xlogx
......
......@@ -212,21 +212,26 @@ def register_canonicalize(lopt, *tags, **kwargs):
compile.optdb['canonicalize'].register(name, lopt, 'fast_run', *tags)
return lopt
def register_stabilize(lopt, *tags, **kwargs):
    """Register `lopt` in the 'stabilize' optimization database.

    The optimizer is tagged 'fast_run'; an explicit name can be given
    with the keyword argument ``name`` (defaults to ``lopt.__name__``).
    """
    # pop with a default: the original `(kwargs and kwargs.pop('name'))`
    # raised KeyError whenever kwargs held other keys but no 'name'.
    name = kwargs.pop('name', None) or lopt.__name__
    compile.optdb['stabilize'].register(name, lopt, 'fast_run', *tags)
    return lopt
def register_specialize(lopt, *tags, **kwargs):
    """Register `lopt` in the 'specialize' optimization database.

    The optimizer is tagged 'fast_run'; an explicit name can be given
    with the keyword argument ``name`` (defaults to ``lopt.__name__``).
    """
    # pop with a default: the original `(kwargs and kwargs.pop('name'))`
    # raised KeyError whenever kwargs held other keys but no 'name'.
    name = kwargs.pop('name', None) or lopt.__name__
    compile.optdb['specialize'].register(name, lopt, 'fast_run', *tags)
    return lopt
def register_uncanonicalize(lopt, *tags, **kwargs):
    """Register `lopt` in the 'uncanonicalize' optimization database.

    The optimizer is tagged 'fast_run'; an explicit name can be given
    with the keyword argument ``name`` (defaults to ``lopt.__name__``).
    """
    # pop with a default: the original `(kwargs and kwargs.pop('name'))`
    # raised KeyError whenever kwargs held other keys but no 'name'.
    name = kwargs.pop('name', None) or lopt.__name__
    compile.optdb['uncanonicalize'].register(name, lopt, 'fast_run', *tags)
    return lopt
def register_specialize_device(lopt, *tags, **kwargs):
    """Register `lopt` in the 'specialize_device' optimization database.

    The optimizer is tagged 'fast_run'; an explicit name can be given
    with the keyword argument ``name`` (defaults to ``lopt.__name__``).
    """
    # pop with a default: the original `(kwargs and kwargs.pop('name'))`
    # raised KeyError whenever kwargs held other keys but no 'name'.
    name = kwargs.pop('name', None) or lopt.__name__
    compile.optdb['specialize_device'].register(name, lopt, 'fast_run', *tags)
    return lopt
######################
# DimShuffle lifters #
######################
......
"""
This file implements specialization optimizations that break the canonical form.
"""
# TODO: intelligent merge for mul/add
# TODO: 0*x -> 0
import logging
_logger = logging.getLogger('theano.tensor.opt')
import operator
import itertools
import sys
import theano
from theano import gof
from elemwise import CAReduce
import basic as T
from theano.gof.python25 import any, all
from theano.gof.opt import Optimizer
from theano.gof import InconsistencyError, toolbox
from basic import get_constant_value
from theano.tensor.opt import register_uncanonicalize
from theano import scalar as scal
@register_uncanonicalize
# Track MaxAndArgmax nodes, not _shape: the original tracker
# `[T._shape]` was a copy-paste error — the body only matches
# `T._max_and_argmax`, so the optimizer was never applied.
@gof.local_optimizer([T._max_and_argmax])
def local_max_and_argmax_specialize(node):
    """Replace MaxAndArgmax with a cheaper CAReduce(maximum) when the
    argmax output is unused.

    The unused argmax output is substituted by a dummy constant 0,
    which is safe because it has no clients.
    """
    if node.op == T._max_and_argmax:
        # Only rewrite when nothing consumes the argmax output.
        if len(node.outputs[1].clients) == 0:
            try:
                axis = get_constant_value(node.inputs[1])
            except ValueError:
                # Axis is not a compile-time constant: we cannot build
                # the CAReduce, so leave the node untouched.
                return False
            # (Removed leftover `import pdb; pdb.set_trace()` debugging
            # statement that froze every compilation hitting this path.)
            return [CAReduce(scal.maximum, axis)(node.inputs[0]),
                    T.as_tensor_variable(0)]
    return False
class MaxAndArgmaxOptimizer(Optimizer):
    """Global optimizer replacing MaxAndArgmax nodes whose argmax
    output is unused by a cheaper CAReduce(maximum).

    Repeats over the graph until no more replacements succeed.
    """

    def add_requirements(self, env):
        # Replacements must be validated before being committed.
        env.extend(toolbox.ReplaceValidate())

    def apply(self, env):
        did_something = True
        while did_something:
            nodelist = list(env.nodes)
            did_something = False
            for node in nodelist:
                if node.op != T._max_and_argmax:
                    continue
                # Only rewrite when the argmax output is unused.
                if len(node.outputs[1].clients) != 0:
                    continue
                try:
                    axis = get_constant_value(node.inputs[1])
                except ValueError:
                    # Axis not constant for THIS node: skip it.  The
                    # original `return False` here aborted the whole
                    # pass, silently leaving later nodes unoptimized.
                    continue
                new = CAReduce(scal.maximum, axis)(node.inputs[0])
                try:
                    env.replace_all_validate(
                        ((node.outputs[0], new),),
                        reason=self.__class__.__name__)
                    did_something = True
                    # Restart from a fresh node list after mutation.
                    break
                except InconsistencyError:
                    # Replacement rejected by validation; keep the node
                    # and try the next candidate.
                    pass


register_uncanonicalize(MaxAndArgmaxOptimizer(),
                        name='MaxAndArgmaxOptimizer')
......@@ -846,6 +846,16 @@ class T_max_and_argmax(unittest.TestCase):
v = eval_outputs(max_and_argmax(n,2)[0].shape)
assert tuple(v)==(2,3)
def test_optimization(self):
    """Using only the max output must trigger the CAReduce rewrite:
    the compiled graph should contain exactly one node, a CAReduce,
    instead of a MaxAndArgmax."""
    inp = numpy.asarray(numpy.random.rand(2, 3), dtype=config.floatX)
    var = matrix()
    fn = function([var], max_and_argmax(var, 0)[0])
    nodes = fn.maker.env.toposort()
    assert len(nodes) == 1
    assert isinstance(nodes[0].op, CAReduce)
def test_grad(self):
data = numpy.random.rand(2,3)
n = as_tensor_variable(data)
......@@ -996,6 +1006,16 @@ class T_max(unittest.TestCase):
v = eval_outputs(max(n,[0,1,2]).shape)
self.failUnless(v.size == 0)
def test_optimization(self):
    """max() alone (argmax never requested) must compile down to a
    single CAReduce node, and the compiled function must run."""
    inp = numpy.asarray(numpy.random.rand(2, 3), dtype=config.floatX)
    var = matrix()
    fn = function([var], max(var, 0))
    nodes = fn.maker.env.toposort()
    assert len(nodes) == 1
    assert isinstance(nodes[0].op, CAReduce)
    fn(inp)
def _test_grad(self):
data = numpy.random.rand(2,3)
n = as_tensor_variable(data)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论