提交 2fe4b0b8 authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Merge pull request #1594 from nouiz/faster_opt

[MRG] Fix test python 2.4 and faster opt
...@@ -283,7 +283,9 @@ The local version of the above code would be the following: ...@@ -283,7 +283,9 @@ The local version of the above code would be the following:
The definition of transform is the inner loop of the global optimizer, The definition of transform is the inner loop of the global optimizer,
where the node is given as argument. If no changes are to be made, where the node is given as argument. If no changes are to be made,
``False`` must be returned. Else, a list of what to replace the node's ``False`` must be returned. Else, a list of what to replace the node's
outputs with must be returned. outputs with must be returned. This list must have the same length as
node.ouputs. If one of node.outputs don't have clients(it is not used
in the graph), you can put None in the returned list to remove it.
In order to apply the local optimizer we must use it in conjunction In order to apply the local optimizer we must use it in conjunction
with a :ref:`navigator`. Basically, a :ref:`navigator` is a global with a :ref:`navigator`. Basically, a :ref:`navigator` is a global
......
...@@ -93,6 +93,7 @@ class FunctionGraph(utils.object2): ...@@ -93,6 +93,7 @@ class FunctionGraph(utils.object2):
inputs, outputs = graph.clone(inputs, outputs) inputs, outputs = graph.clone(inputs, outputs)
self.execute_callbacks_time = 0 self.execute_callbacks_time = 0
self.execute_callbacks_times = {}
if features is None: if features is None:
features = [] features = []
...@@ -507,7 +508,7 @@ class FunctionGraph(utils.object2): ...@@ -507,7 +508,7 @@ class FunctionGraph(utils.object2):
attach(self) attach(self)
except toolbox.AlreadyThere: except toolbox.AlreadyThere:
return return
self.execute_callbacks_times.setdefault(feature, 0)
#it would be nice if we could require a specific class instead of #it would be nice if we could require a specific class instead of
#a "workalike" so we could do actual error checking #a "workalike" so we could do actual error checking
#if not isinstance(feature, toolbox.Feature): #if not isinstance(feature, toolbox.Feature):
...@@ -549,8 +550,9 @@ class FunctionGraph(utils.object2): ...@@ -549,8 +550,9 @@ class FunctionGraph(utils.object2):
# try; the AttributeError reall must come from feature.${name} # try; the AttributeError reall must come from feature.${name}
# not existing # not existing
continue continue
tf0 = time.time()
fn(self, *args, **kwargs) fn(self, *args, **kwargs)
self.execute_callbacks_times[feature] += time.time() - tf0
self.execute_callbacks_time += time.time() - t0 self.execute_callbacks_time += time.time() - t0
def collect_callbacks(self, name, *args): def collect_callbacks(self, name, *args):
......
...@@ -495,14 +495,14 @@ def stack_search(start, expand, mode='bfs', build_inv=False): ...@@ -495,14 +495,14 @@ def stack_search(start, expand, mode='bfs', build_inv=False):
:param start: search from these nodes :param start: search from these nodes
:type expand: callable :type expand: callable
:param expand: :param expand:
when we get to a node, add expand(node) to the list of nodes to visit. This function when we get to a node, add expand(node) to the list of nodes to visit.
should return a list, or None This function should return a list, or None
:rtype: list of `Variable` or `Apply` instances (depends on `expend`) :rtype: list of `Variable` or `Apply` instances (depends on `expend`)
:return: the list of nodes in order of traversal. :return: the list of nodes in order of traversal.
:note: :note:
a node will appear at most once in the return value, even if it appears multiple times a node will appear at most once in the return value, even if it
in the start parameter. appears multiple times in the start parameter.
:postcondition: every element of start is transferred to the returned list. :postcondition: every element of start is transferred to the returned list.
:postcondition: start is empty. :postcondition: start is empty.
...@@ -549,9 +549,7 @@ def ancestors(variable_list, blockers=None): ...@@ -549,9 +549,7 @@ def ancestors(variable_list, blockers=None):
""" """
def expand(r): def expand(r):
if r.owner and (not blockers or r not in blockers): if r.owner and (not blockers or r not in blockers):
l = list(r.owner.inputs) return reversed(r.owner.inputs)
l.reverse()
return l
dfs_variables = stack_search(deque(variable_list), expand, 'dfs') dfs_variables = stack_search(deque(variable_list), expand, 'dfs')
return dfs_variables return dfs_variables
...@@ -801,7 +799,7 @@ def io_toposort(inputs, outputs, orderings=None): ...@@ -801,7 +799,7 @@ def io_toposort(inputs, outputs, orderings=None):
if isinstance(obj, Variable): if isinstance(obj, Variable):
if obj.owner: if obj.owner:
rval = [obj.owner] rval = [obj.owner]
if isinstance(obj, Apply): elif isinstance(obj, Apply):
rval = list(obj.inputs) rval = list(obj.inputs)
rval.extend(orderings.get(obj, [])) rval.extend(orderings.get(obj, []))
else: else:
......
...@@ -514,7 +514,8 @@ class MergeFeature(object): ...@@ -514,7 +514,8 @@ class MergeFeature(object):
continue continue
inputs_match = all(node_in is cand_in inputs_match = all(node_in is cand_in
for node_in, cand_in in zip(node.inputs, candidate.inputs)) for node_in, cand_in in zip(node.inputs,
candidate.inputs))
if inputs_match and node.op == candidate.op: if inputs_match and node.op == candidate.op:
if (node, candidate) in self.blacklist: if (node, candidate) in self.blacklist:
# They were already tried, and there was an error # They were already tried, and there was an error
...@@ -566,6 +567,8 @@ class MergeOptimizer(Optimizer): ...@@ -566,6 +567,8 @@ class MergeOptimizer(Optimizer):
if fgraph.profile: if fgraph.profile:
validate_before = fgraph.profile.validate_time validate_before = fgraph.profile.validate_time
callback_before = fgraph.execute_callbacks_time callback_before = fgraph.execute_callbacks_time
callbacks_before = fgraph.execute_callbacks_times.copy()
nb_merged = 0 nb_merged = 0
nb_constant = 0 nb_constant = 0
while sched: while sched:
...@@ -589,20 +592,28 @@ class MergeOptimizer(Optimizer): ...@@ -589,20 +592,28 @@ class MergeOptimizer(Optimizer):
if fgraph.profile: if fgraph.profile:
validate_time = fgraph.profile.validate_time - validate_before validate_time = fgraph.profile.validate_time - validate_before
callback_time = fgraph.execute_callbacks_time - callback_before callback_time = fgraph.execute_callbacks_time - callback_before
callbacks_time = {}
for k, v in fgraph.execute_callbacks_times.iteritems():
if k in callbacks_before:
callbacks_time[k] = v - callbacks_before[k]
else:
callbacks_time[k] = v
else: else:
validate_time = None validate_time = None
callback_time = None callback_time = None
callbacks_time = {}
# clear blacklist # clear blacklist
fgraph.merge_feature.blacklist = [] fgraph.merge_feature.blacklist = []
return (nb_fail, time.time() - t0, validate_time, return (nb_fail, time.time() - t0, validate_time,
callback_time, nb_merged, nb_constant) callback_time, callbacks_time, nb_merged, nb_constant)
def __str__(self): def __str__(self):
return self.__class__.__name__ return self.__class__.__name__
@staticmethod @staticmethod
def print_profile(stream, prof, level=0): def print_profile(stream, prof, level=0):
nb_fail, replace_time, validate_time, callback_time, nb_merged, nb_constant = prof (nb_fail, replace_time, validate_time,
callback_time, callbacks_time, nb_merged, nb_constant) = prof
blanc = (' ' * level) blanc = (' ' * level)
print >> stream, blanc, "MergeOptimizer" print >> stream, blanc, "MergeOptimizer"
...@@ -610,6 +621,7 @@ class MergeOptimizer(Optimizer): ...@@ -610,6 +621,7 @@ class MergeOptimizer(Optimizer):
print >> stream, blanc, " replace_time", replace_time print >> stream, blanc, " replace_time", replace_time
print >> stream, blanc, " validate_time", validate_time print >> stream, blanc, " validate_time", validate_time
print >> stream, blanc, " callback_time", callback_time print >> stream, blanc, " callback_time", callback_time
print >> stream, blanc, " callback_times", callbacks_time
print >> stream, blanc, " nb_merged", nb_merged print >> stream, blanc, " nb_merged", nb_merged
print >> stream, blanc, " nb_constant", nb_constant print >> stream, blanc, " nb_constant", nb_constant
...@@ -800,8 +812,8 @@ class LocalOptGroup(LocalOptimizer): ...@@ -800,8 +812,8 @@ class LocalOptGroup(LocalOptimizer):
def __str__(self): def __str__(self):
return getattr(self, '__name__', return getattr(self, '__name__',
('<theano.gof.opt.LocalOptGroup instance>' ('<theano.gof.opt.LocalOptGroup instance>' +
+ str([str(o) for o in self.opts]))) str([str(o) for o in self.opts])))
def transform(self, node): def transform(self, node):
for opt in self.opts: for opt in self.opts:
...@@ -979,9 +991,9 @@ class PatternSub(LocalOptimizer): ...@@ -979,9 +991,9 @@ class PatternSub(LocalOptimizer):
else: else:
raise TypeError("The pattern to search for must start with " raise TypeError("The pattern to search for must start with "
"a specific Op instance.") "a specific Op instance.")
self.__doc__ = (self.__class__.__doc__ self.__doc__ = (self.__class__.__doc__ +
+ "\n\nThis instance does: " "\n\nThis instance does: " +
+ str(self) + "\n") str(self) + "\n")
self.allow_multiple_clients = allow_multiple_clients self.allow_multiple_clients = allow_multiple_clients
self.skip_identities_fn = skip_identities_fn self.skip_identities_fn = skip_identities_fn
if name: if name:
...@@ -1275,7 +1287,8 @@ class NavigatorOptimizer(Optimizer): ...@@ -1275,7 +1287,8 @@ class NavigatorOptimizer(Optimizer):
except Exception, e: except Exception, e:
if self.failure_callback is not None: if self.failure_callback is not None:
self.failure_callback(e, self, self.failure_callback(e, self,
[(x, None) for x in node.outputs], lopt) [(x, None) for x in node.outputs],
lopt)
return False return False
else: else:
raise raise
...@@ -1287,10 +1300,16 @@ class NavigatorOptimizer(Optimizer): ...@@ -1287,10 +1300,16 @@ class NavigatorOptimizer(Optimizer):
if len(node.outputs) != len(replacements): if len(node.outputs) != len(replacements):
raise ValueError('Optimizer %s gave wrong number of replacements' raise ValueError('Optimizer %s gave wrong number of replacements'
% lopt) % lopt)
# None in the replacement mean that this variable isn't used
# and we want to remove it
for r, rnew in zip(node.outputs, replacements):
if rnew is None and len(r.clients) > 0:
raise ValueError("A local optimizer tried to remove a Variable that is used")
# If an output would be replaced by itself, no need to perform # If an output would be replaced by itself, no need to perform
# the replacement # the replacement
repl_pairs = [(r, rnew) for r, rnew in zip(node.outputs, replacements) repl_pairs = [(r, rnew) for r, rnew in zip(node.outputs, replacements)
if rnew is not r] if rnew is not r and rnew is not None]
if len(repl_pairs) == 0: if len(repl_pairs) == 0:
return False return False
try: try:
...@@ -1513,6 +1532,7 @@ class EquilibriumOptimizer(NavigatorOptimizer): ...@@ -1513,6 +1532,7 @@ class EquilibriumOptimizer(NavigatorOptimizer):
max_use_abort = False max_use_abort = False
opt_name = None opt_name = None
global_process_count = {} global_process_count = {}
start_nb_nodes = len(fgraph.apply_nodes)
max_nb_nodes = len(fgraph.apply_nodes) max_nb_nodes = len(fgraph.apply_nodes)
max_use = max_nb_nodes * self.max_use_ratio max_use = max_nb_nodes * self.max_use_ratio
...@@ -1597,13 +1617,16 @@ class EquilibriumOptimizer(NavigatorOptimizer): ...@@ -1597,13 +1617,16 @@ class EquilibriumOptimizer(NavigatorOptimizer):
loop_process_count.append(process_count) loop_process_count.append(process_count)
loop_timing.append(float(time.time() - t0)) loop_timing.append(float(time.time() - t0))
end_nb_nodes = len(fgraph.apply_nodes)
if max_use_abort: if max_use_abort:
_logger.error("EquilibriumOptimizer max'ed out by '%s'" % opt_name _logger.error("EquilibriumOptimizer max'ed out by '%s'" % opt_name
+ ". You can safely raise the current threshold of " + ". You can safely raise the current threshold of "
+ "%f with the theano flag 'optdb.max_use_ratio'." % + "%f with the theano flag 'optdb.max_use_ratio'." %
config.optdb.max_use_ratio) config.optdb.max_use_ratio)
return (self, loop_timing, loop_process_count, max_nb_nodes, return (self, loop_timing, loop_process_count,
(start_nb_nodes, end_nb_nodes, max_nb_nodes),
global_opt_timing, nb_nodes, time_opts, io_toposort_timing) global_opt_timing, nb_nodes, time_opts, io_toposort_timing)
def print_summary(self, stream=sys.stdout, level=0, depth=-1): def print_summary(self, stream=sys.stdout, level=0, depth=-1):
...@@ -1617,15 +1640,18 @@ class EquilibriumOptimizer(NavigatorOptimizer): ...@@ -1617,15 +1640,18 @@ class EquilibriumOptimizer(NavigatorOptimizer):
@staticmethod @staticmethod
def print_profile(stream, prof, level=0): def print_profile(stream, prof, level=0):
(opt, loop_timing, loop_process_count, max_nb_nodes, (opt, loop_timing, loop_process_count,
(start_nb_nodes, end_nb_nodes, max_nb_nodes),
global_opt_timing, nb_nodes, time_opts, io_toposort_timing) = prof global_opt_timing, nb_nodes, time_opts, io_toposort_timing) = prof
blanc = (' ' * level) blanc = (' ' * level)
print >> stream, blanc, "EquilibriumOptimizer", print >> stream, blanc, "EquilibriumOptimizer",
print >> stream, blanc, getattr(opt, "name", print >> stream, blanc, getattr(opt, "name",
getattr(opt, "__name__", "")) getattr(opt, "__name__", ""))
print >> stream, blanc, " time %.3fs for %d passes, %d nodes max" % ( print >> stream, blanc, " time %.3fs for %d passes" % (
sum(loop_timing), len(loop_timing), max_nb_nodes) sum(loop_timing), len(loop_timing))
print >> stream, blanc, " nb nodes (start, end, max) %d %d %d" % (
start_nb_nodes, end_nb_nodes, max_nb_nodes)
print >> stream, blanc, " time io_toposort %.3fs" % sum( print >> stream, blanc, " time io_toposort %.3fs" % sum(
io_toposort_timing) io_toposort_timing)
s = sum([time_opts[o] for o in opt.local_optimizers]) s = sum([time_opts[o] for o in opt.local_optimizers])
......
...@@ -248,10 +248,11 @@ class ReplaceValidate(History, Validator): ...@@ -248,10 +248,11 @@ class ReplaceValidate(History, Validator):
raise ReplacementDidntRemovedError() raise ReplacementDidntRemovedError()
class NodeFinder(dict, Bookkeeper): class NodeFinder(Bookkeeper):
def __init__(self): def __init__(self):
self.fgraph = None self.fgraph = None
self.d = {}
def on_attach(self, fgraph): def on_attach(self, fgraph):
if self.fgraph is not None: if self.fgraph is not None:
...@@ -273,7 +274,7 @@ class NodeFinder(dict, Bookkeeper): ...@@ -273,7 +274,7 @@ class NodeFinder(dict, Bookkeeper):
def on_import(self, fgraph, node, reason): def on_import(self, fgraph, node, reason):
try: try:
self.setdefault(node.op, []).append(node) self.d.setdefault(node.op, []).append(node)
except TypeError: # node.op is unhashable except TypeError: # node.op is unhashable
return return
except Exception, e: except Exception, e:
...@@ -286,16 +287,16 @@ class NodeFinder(dict, Bookkeeper): ...@@ -286,16 +287,16 @@ class NodeFinder(dict, Bookkeeper):
def on_prune(self, fgraph, node, reason): def on_prune(self, fgraph, node, reason):
try: try:
nodes = self[node.op] nodes = self.d[node.op]
except TypeError: # node.op is unhashable except TypeError: # node.op is unhashable
return return
nodes.remove(node) nodes.remove(node)
if not nodes: if not nodes:
del self[node.op] del self.d[node.op]
def query(self, fgraph, op): def query(self, fgraph, op):
try: try:
all = self.get(op, []) all = self.d.get(op, [])
except TypeError: except TypeError:
raise TypeError("%s in unhashable and cannot be queried by the" raise TypeError("%s in unhashable and cannot be queried by the"
" optimizer" % op) " optimizer" % op)
......
...@@ -353,7 +353,7 @@ class Softmax(gof.Op): ...@@ -353,7 +353,7 @@ class Softmax(gof.Op):
x = tensor.as_tensor_variable(x) x = tensor.as_tensor_variable(x)
if x.type.ndim not in (1, 2) \ if x.type.ndim not in (1, 2) \
or x.type.dtype not in tensor.float_dtypes: or x.type.dtype not in tensor.float_dtypes:
raise ValueError('x must be 1-d or 2-d tensor of floats') raise ValueError('x must be 1-d or 2-d tensor of floats. Got ', x.type)
if x.ndim == 1: if x.ndim == 1:
x = tensor.shape_padleft(x, n_ones=1) x = tensor.shape_padleft(x, n_ones=1)
return Apply(self, [x], [x.type()]) return Apply(self, [x], [x.type()])
......
...@@ -915,6 +915,13 @@ class ShapeFeature(object): ...@@ -915,6 +915,13 @@ class ShapeFeature(object):
# If no info is known on r's shape, use other_shape # If no info is known on r's shape, use other_shape
self.set_shape(r, other_shape) self.set_shape(r, other_shape)
return return
if (other_r.owner and r.owner and
other_r.owner.inputs == r.owner.inputs and
other_r.owner.op == r.owner.op):
# We are doing a merge. So the 2 shapes graph will be the
# same. This is only a speed optimization to call
# ancestors() less frequently.
return
# Merge other_shape with r_shape, giving the priority to other_shape # Merge other_shape with r_shape, giving the priority to other_shape
merged_shape = [] merged_shape = []
...@@ -928,6 +935,18 @@ class ShapeFeature(object): ...@@ -928,6 +935,18 @@ class ShapeFeature(object):
# - Shape_i(i)(other_r); # - Shape_i(i)(other_r);
# - Shape_i(i)(r). # - Shape_i(i)(r).
merged_shape.append(r_shape[i]) merged_shape.append(r_shape[i])
elif isinstance(r_shape[i], (Constant, int)):
# We do this to call less often ancestors and make
# sure we have the simplest shape possible.
merged_shape.append(r_shape[i])
elif isinstance(other_shape[i], (Constant, int)):
# We do this to call less often ancestors and make
# sure we have the simplest shape possible.
merged_shape.append(other_shape[i])
elif other_shape[i] == r_shape[i]:
# This mean the shape is equivalent
# We do not want to do the ancestor check in those cases
merged_shape.append(r_shape[i])
elif r_shape[i] in theano.gof.graph.ancestors([other_shape[i]]): elif r_shape[i] in theano.gof.graph.ancestors([other_shape[i]]):
# Another case where we want to use r_shape[i] is when # Another case where we want to use r_shape[i] is when
# other_shape[i] actually depends on r_shape[i]. In that case, # other_shape[i] actually depends on r_shape[i]. In that case,
......
...@@ -26,54 +26,33 @@ import logging ...@@ -26,54 +26,33 @@ import logging
_logger = logging.getLogger('theano.tensor.opt') _logger = logging.getLogger('theano.tensor.opt')
from theano import gof from theano import gof
from theano.compat.python2x import deque
from theano.tensor.elemwise import CAReduce from theano.tensor.elemwise import CAReduce
from theano.tensor import basic as T from theano.tensor import basic as T
from theano.gof.opt import Optimizer
from theano.gof import InconsistencyError, toolbox
from theano.tensor.basic import (get_scalar_constant_value, from theano.tensor.basic import (get_scalar_constant_value,
NotScalarConstantError) NotScalarConstantError)
from theano.tensor.opt import register_uncanonicalize from theano.tensor.opt import register_uncanonicalize
from theano import scalar as scal from theano import scalar as scal
class MaxAndArgmaxOptimizer(Optimizer): @register_uncanonicalize
"""Replace MaxAndArgmax by CAReduce when the argmax is not used @gof.local_optimizer([T._max_and_argmax])
def local_max_and_argmax(node):
This is faster as MaxAndArgmax don't have c code and execute it """
in two pass. If we don't use the argmax, change it to a max only.
""" """
def add_requirements(self, fgraph):
fgraph.attach_feature(toolbox.ReplaceValidate())
def apply(self, fgraph):
did_something = True
while did_something:
nodelist = fgraph.toposort()
did_something = False
for node in nodelist:
if node.op == T._max_and_argmax: if node.op == T._max_and_argmax:
if len(node.outputs[1].clients) == 0: if len(node.outputs[1].clients) == 0:
#MaxAndArgmax support variable axis,
#but CAReduce support only constant axis.
try: try:
axis = get_scalar_constant_value(node.inputs[1]) axis = get_scalar_constant_value(node.inputs[1])
except NotScalarConstantError: except NotScalarConstantError:
return False return False
new = CAReduce(scal.maximum, axis)(node.inputs[0]) new = CAReduce(scal.maximum, axis)(node.inputs[0])
try: return [new, None]
fgraph.replace_all_validate(
((node.outputs[0], new),),
reason=self.__class__.__name__)
did_something = True
break
except InconsistencyError, e:
pass
register_uncanonicalize(MaxAndArgmaxOptimizer(),
name='MaxAndArgmaxOptimizer')
@register_uncanonicalize @register_uncanonicalize
@gof.local_optimizer([T._shape]) @gof.local_optimizer([T._shape])
......
...@@ -3,7 +3,7 @@ import copy ...@@ -3,7 +3,7 @@ import copy
import numpy import numpy
import theano import theano
from theano.compat import PY3 from theano.compat import all, PY3
from theano.scalar import ComplexError, IntegerDivisionError from theano.scalar import ComplexError, IntegerDivisionError
from theano.gof import Constant, Variable from theano.gof import Constant, Variable
from theano.gof.utils import hashtype from theano.gof.utils import hashtype
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论