Commit 054f21e6 authored by Frederic

Speed up MaxAndArgmaxOptimizer.

This is done by making it a local optimizer. The change allows a local optimizer to remove variables from the graph; this is permitted only for variables that are not used. In a small test case the optimization took 0.6s and now takes 0.1s. In a bigger test case it took 40s; I did not rerun that one with the new optimizer, as it takes too long. The speedup comes from doing fewer iterations over the graph: the EquilibriumOptimizer adds newly created nodes to the current iteration, while the old optimizer did not, and I did not want to duplicate all of that code.
Parent 36694a6d
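The speedup described above comes from the traversal strategy rather than the rewrite itself. Here is a toy sketch in plain Python (not Theano code; `global_style`, `worklist_style`, and `rewrite` are made-up names) of why an equilibrium-style worklist, which also processes nodes created during the current pass, visits far fewer nodes than a global optimizer that restarts a full pass after every change:

```python
# Toy model: a "graph" is a list of ints, and a "rewrite" turns an odd
# number into the even number below it. Both drivers apply the rewrite
# to a fixed point; only the number of node visits differs.

def global_style(graph, rewrite):
    """Restart a full pass over the graph after every single change."""
    visits = 0
    changed = True
    while changed:
        changed = False
        for node in list(graph):          # fresh full pass each time
            visits += 1
            new = rewrite(node)
            if new is not None:
                graph[graph.index(node)] = new
                changed = True
                break                     # restart from the beginning
    return visits

def worklist_style(graph, rewrite):
    """Equilibrium-style: newly created nodes join the current pass."""
    visits = 0
    todo = list(graph)
    while todo:
        node = todo.pop()
        visits += 1
        new = rewrite(node)
        if new is not None:
            graph[graph.index(node)] = new
            todo.append(new)              # processed in this same pass
    return visits

rewrite = lambda n: n - 1 if n % 2 else None
g1, g2 = list(range(1, 101)), list(range(1, 101))
print(global_style(g1, rewrite) > worklist_style(g2, rewrite))  # True
```

On 100 nodes the worklist driver makes 150 visits (each node once, plus each of the 50 new nodes once), while the restart-every-change driver makes thousands; the gap is the same effect the commit message reports on the real graphs.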
@@ -283,7 +283,9 @@ The local version of the above code would be the following:
 The definition of transform is the inner loop of the global optimizer,
 where the node is given as argument. If no changes are to be made,
 ``False`` must be returned. Else, a list of what to replace the node's
-outputs with must be returned.
+outputs with must be returned. This list must have the same length as
+node.outputs. If one of node.outputs has no clients (i.e. it is not used
+in the graph), you can put None in the returned list to remove it.

 In order to apply the local optimizer we must use it in conjunction
 with a :ref:`navigator`. Basically, a :ref:`navigator` is a global
...
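The ``transform`` contract documented above can be illustrated with a small stand-alone sketch. The classes and the op name below are toy stand-ins, not Theano's real ``Variable``/``Apply`` types; only the return-value convention (``False``, or a list matching ``node.outputs`` with ``None`` marking a removed output) mirrors the documentation:

```python
# Toy illustration of the transform contract: return False for "no
# change", or a list the same length as node.outputs, where None marks
# an unused output that should be removed from the graph.

class Var:
    def __init__(self, name, clients=()):
        self.name = name
        self.clients = list(clients)  # consumers of this variable

class Node:
    def __init__(self, op, outputs):
        self.op = op
        self.outputs = outputs

def transform(node):
    """Rewrite 'max_and_argmax' into a max-only op when argmax is unused."""
    if node.op != 'max_and_argmax':
        return False              # nothing to do for other ops
    max_out, argmax_out = node.outputs
    if argmax_out.clients:        # argmax still used: cannot drop it
        return False
    # Same length as node.outputs; None removes the unused argmax output.
    return [Var('max_only', clients=max_out.clients), None]

node = Node('max_and_argmax', [Var('max', clients=['op1']), Var('argmax')])
replacements = transform(node)
print([getattr(r, 'name', None) for r in replacements])  # ['max_only', None]
```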
@@ -1287,10 +1287,16 @@ class NavigatorOptimizer(Optimizer):
         if len(node.outputs) != len(replacements):
             raise ValueError('Optimizer %s gave wrong number of replacements'
                              % lopt)
+        # None in the replacements means that this variable isn't used
+        # and we want to remove it.
+        for r, rnew in zip(node.outputs, replacements):
+            if rnew is None and len(r.clients) > 0:
+                raise ValueError(
+                    "A local optimizer tried to remove a Variable that is used")
         # If an output would be replaced by itself, no need to perform
         # the replacement
         repl_pairs = [(r, rnew) for r, rnew in zip(node.outputs, replacements)
-                      if rnew is not r]
+                      if rnew is not r and rnew is not None]
         if len(repl_pairs) == 0:
             return False
         try:
...
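The validation and filtering logic added in this hunk can be exercised in isolation. The following is a minimal re-implementation under stated assumptions (a plain ``FakeVar`` class and a free function ``build_repl_pairs`` stand in for Theano's variables and the method body above); it shows that ``None`` is only legal for an output with no clients, and that both identity and ``None`` pairs are dropped before the graph is rewritten:

```python
# Sketch of the replacement filtering: validate None entries, then keep
# only the (old, new) pairs that actually change something.

class FakeVar:
    def __init__(self, name, clients=()):
        self.name = name
        self.clients = list(clients)  # consumers of this variable

def build_repl_pairs(outputs, replacements):
    if len(outputs) != len(replacements):
        raise ValueError('Optimizer gave wrong number of replacements')
    # None in the replacements means the variable is unused and removed;
    # removing a variable that still has clients is an error.
    for r, rnew in zip(outputs, replacements):
        if rnew is None and len(r.clients) > 0:
            raise ValueError(
                "A local optimizer tried to remove a Variable that is used")
    # Skip identity replacements and removed (None) outputs.
    return [(r, rnew) for r, rnew in zip(outputs, replacements)
            if rnew is not r and rnew is not None]

mx, am = FakeVar('max', clients=['sum_op']), FakeVar('argmax')
new_max = FakeVar('careduce_max', clients=['sum_op'])
pairs = build_repl_pairs([mx, am], [new_max, None])
print([(a.name, b.name) for a, b in pairs])  # [('max', 'careduce_max')]
```

This is exactly the shape of result the new ``local_max_and_argmax`` below relies on: its ``[new, None]`` return value turns into a single replacement pair for the max output, and the argmax output is silently dropped.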
@@ -26,54 +26,31 @@ import logging
 _logger = logging.getLogger('theano.tensor.opt')

 from theano import gof
-from theano.compat.python2x import deque
 from theano.tensor.elemwise import CAReduce
 from theano.tensor import basic as T
-from theano.gof.opt import Optimizer
-from theano.gof import InconsistencyError, toolbox
 from theano.tensor.basic import (get_scalar_constant_value,
                                  NotScalarConstantError)
 from theano.tensor.opt import register_uncanonicalize
 from theano import scalar as scal

-class MaxAndArgmaxOptimizer(Optimizer):
-    """Replace MaxAndArgmax by CAReduce when the argmax is not used
-
-    This is faster as MaxAndArgmax don't have c code and execute it
-    in two pass.
-    """
-    def add_requirements(self, fgraph):
-        fgraph.attach_feature(toolbox.ReplaceValidate())
-
-    def apply(self, fgraph):
-        did_something = True
-        while did_something:
-            nodelist = fgraph.toposort()
-            did_something = False
-            for node in nodelist:
-                if node.op == T._max_and_argmax:
-                    if len(node.outputs[1].clients) == 0:
-                        try:
-                            axis = get_scalar_constant_value(node.inputs[1])
-                        except NotScalarConstantError:
-                            return False
-                        new = CAReduce(scal.maximum, axis)(node.inputs[0])
-                        try:
-                            fgraph.replace_all_validate(
-                                ((node.outputs[0], new),),
-                                reason=self.__class__.__name__)
-                            did_something = True
-                            break
-                        except InconsistencyError, e:
-                            pass
-
-register_uncanonicalize(MaxAndArgmaxOptimizer(),
-                        name='MaxAndArgmaxOptimizer')
+@register_uncanonicalize
+@gof.local_optimizer([T._max_and_argmax])
+def local_max_and_argmax(node):
+    """
+    If we don't use the argmax, change it to a max only.
+    """
+    if node.op == T._max_and_argmax:
+        if len(node.outputs[1].clients) == 0:
+            try:
+                axis = get_scalar_constant_value(node.inputs[1])
+            except NotScalarConstantError:
+                return False
+
+            new = CAReduce(scal.maximum, axis)(node.inputs[0])
+            return [new, None]

 @register_uncanonicalize
 @gof.local_optimizer([T._shape])
...