提交 2fe4b0b8 authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Merge pull request #1594 from nouiz/faster_opt

[MRG] Fix test python 2.4 and faster opt
......@@ -283,7 +283,9 @@ The local version of the above code would be the following:
The definition of transform is the inner loop of the global optimizer,
where the node is given as argument. If no changes are to be made,
``False`` must be returned. Else, a list of what to replace the node's
outputs with must be returned.
outputs with must be returned. This list must have the same length as
node.outputs. If one of node.outputs has no clients (it is not used
in the graph), you can put None in the returned list to remove it.
In order to apply the local optimizer we must use it in conjunction
with a :ref:`navigator`. Basically, a :ref:`navigator` is a global
......
......@@ -93,6 +93,7 @@ class FunctionGraph(utils.object2):
inputs, outputs = graph.clone(inputs, outputs)
self.execute_callbacks_time = 0
self.execute_callbacks_times = {}
if features is None:
features = []
......@@ -507,7 +508,7 @@ class FunctionGraph(utils.object2):
attach(self)
except toolbox.AlreadyThere:
return
self.execute_callbacks_times.setdefault(feature, 0)
#it would be nice if we could require a specific class instead of
#a "workalike" so we could do actual error checking
#if not isinstance(feature, toolbox.Feature):
......@@ -549,8 +550,9 @@ class FunctionGraph(utils.object2):
# try; the AttributeError really must come from feature.${name}
# not existing
continue
tf0 = time.time()
fn(self, *args, **kwargs)
self.execute_callbacks_times[feature] += time.time() - tf0
self.execute_callbacks_time += time.time() - t0
def collect_callbacks(self, name, *args):
......
......@@ -495,14 +495,14 @@ def stack_search(start, expand, mode='bfs', build_inv=False):
:param start: search from these nodes
:type expand: callable
:param expand:
when we get to a node, add expand(node) to the list of nodes to visit. This function
should return a list, or None
when we get to a node, add expand(node) to the list of nodes to visit.
This function should return a list, or None
:rtype: list of `Variable` or `Apply` instances (depends on `expand`)
:return: the list of nodes in order of traversal.
:note:
a node will appear at most once in the return value, even if it appears multiple times
in the start parameter.
a node will appear at most once in the return value, even if it
appears multiple times in the start parameter.
:postcondition: every element of start is transferred to the returned list.
:postcondition: start is empty.
......@@ -549,9 +549,7 @@ def ancestors(variable_list, blockers=None):
"""
def expand(r):
if r.owner and (not blockers or r not in blockers):
l = list(r.owner.inputs)
l.reverse()
return l
return reversed(r.owner.inputs)
dfs_variables = stack_search(deque(variable_list), expand, 'dfs')
return dfs_variables
......@@ -801,7 +799,7 @@ def io_toposort(inputs, outputs, orderings=None):
if isinstance(obj, Variable):
if obj.owner:
rval = [obj.owner]
if isinstance(obj, Apply):
elif isinstance(obj, Apply):
rval = list(obj.inputs)
rval.extend(orderings.get(obj, []))
else:
......
差异被折叠。
......@@ -248,10 +248,11 @@ class ReplaceValidate(History, Validator):
raise ReplacementDidntRemovedError()
class NodeFinder(dict, Bookkeeper):
class NodeFinder(Bookkeeper):
def __init__(self):
self.fgraph = None
self.d = {}
def on_attach(self, fgraph):
if self.fgraph is not None:
......@@ -273,7 +274,7 @@ class NodeFinder(dict, Bookkeeper):
def on_import(self, fgraph, node, reason):
try:
self.setdefault(node.op, []).append(node)
self.d.setdefault(node.op, []).append(node)
except TypeError: # node.op is unhashable
return
except Exception, e:
......@@ -286,16 +287,16 @@ class NodeFinder(dict, Bookkeeper):
def on_prune(self, fgraph, node, reason):
try:
nodes = self[node.op]
nodes = self.d[node.op]
except TypeError: # node.op is unhashable
return
nodes.remove(node)
if not nodes:
del self[node.op]
del self.d[node.op]
def query(self, fgraph, op):
try:
all = self.get(op, [])
all = self.d.get(op, [])
except TypeError:
raise TypeError("%s in unhashable and cannot be queried by the"
" optimizer" % op)
......
......@@ -353,7 +353,7 @@ class Softmax(gof.Op):
x = tensor.as_tensor_variable(x)
if x.type.ndim not in (1, 2) \
or x.type.dtype not in tensor.float_dtypes:
raise ValueError('x must be 1-d or 2-d tensor of floats')
raise ValueError('x must be 1-d or 2-d tensor of floats. Got ', x.type)
if x.ndim == 1:
x = tensor.shape_padleft(x, n_ones=1)
return Apply(self, [x], [x.type()])
......
......@@ -915,6 +915,13 @@ class ShapeFeature(object):
# If no info is known on r's shape, use other_shape
self.set_shape(r, other_shape)
return
if (other_r.owner and r.owner and
other_r.owner.inputs == r.owner.inputs and
other_r.owner.op == r.owner.op):
# We are doing a merge. So the 2 shapes graph will be the
# same. This is only a speed optimization to call
# ancestors() less frequently.
return
# Merge other_shape with r_shape, giving the priority to other_shape
merged_shape = []
......@@ -928,6 +935,18 @@ class ShapeFeature(object):
# - Shape_i(i)(other_r);
# - Shape_i(i)(r).
merged_shape.append(r_shape[i])
elif isinstance(r_shape[i], (Constant, int)):
# We do this to call less often ancestors and make
# sure we have the simplest shape possible.
merged_shape.append(r_shape[i])
elif isinstance(other_shape[i], (Constant, int)):
# We do this to call less often ancestors and make
# sure we have the simplest shape possible.
merged_shape.append(other_shape[i])
elif other_shape[i] == r_shape[i]:
# This means the shapes are equivalent.
# We do not want to do the ancestor check in those cases
merged_shape.append(r_shape[i])
elif r_shape[i] in theano.gof.graph.ancestors([other_shape[i]]):
# Another case where we want to use r_shape[i] is when
# other_shape[i] actually depends on r_shape[i]. In that case,
......
......@@ -26,54 +26,33 @@ import logging
_logger = logging.getLogger('theano.tensor.opt')
from theano import gof
from theano.compat.python2x import deque
from theano.tensor.elemwise import CAReduce
from theano.tensor import basic as T
from theano.gof.opt import Optimizer
from theano.gof import InconsistencyError, toolbox
from theano.tensor.basic import (get_scalar_constant_value,
NotScalarConstantError)
from theano.tensor.opt import register_uncanonicalize
from theano import scalar as scal
class MaxAndArgmaxOptimizer(Optimizer):
"""Replace MaxAndArgmax by CAReduce when the argmax is not used
This is faster as MaxAndArgmax don't have c code and execute it
in two pass.
@register_uncanonicalize
@gof.local_optimizer([T._max_and_argmax])
def local_max_and_argmax(node):
"""
def add_requirements(self, fgraph):
fgraph.attach_feature(toolbox.ReplaceValidate())
def apply(self, fgraph):
did_something = True
while did_something:
nodelist = fgraph.toposort()
did_something = False
for node in nodelist:
if node.op == T._max_and_argmax:
if len(node.outputs[1].clients) == 0:
try:
axis = get_scalar_constant_value(node.inputs[1])
except NotScalarConstantError:
return False
new = CAReduce(scal.maximum, axis)(node.inputs[0])
try:
fgraph.replace_all_validate(
((node.outputs[0], new),),
reason=self.__class__.__name__)
did_something = True
break
except InconsistencyError, e:
pass
register_uncanonicalize(MaxAndArgmaxOptimizer(),
name='MaxAndArgmaxOptimizer')
If we don't use the argmax, change it to a max only.
"""
if node.op == T._max_and_argmax:
if len(node.outputs[1].clients) == 0:
#MaxAndArgmax supports a variable axis,
#but CAReduce supports only a constant axis.
try:
axis = get_scalar_constant_value(node.inputs[1])
except NotScalarConstantError:
return False
new = CAReduce(scal.maximum, axis)(node.inputs[0])
return [new, None]
@register_uncanonicalize
@gof.local_optimizer([T._shape])
......
......@@ -3,7 +3,7 @@ import copy
import numpy
import theano
from theano.compat import PY3
from theano.compat import all, PY3
from theano.scalar import ComplexError, IntegerDivisionError
from theano.gof import Constant, Variable
from theano.gof.utils import hashtype
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论