Commit 1c974f4f authored by Arnaud Bergeron

- Only run local optimizers on the ops they are registered for.

- Fix existing optimizers in base code to register properly.
Parent ab660ca3
@@ -736,6 +736,14 @@ class LocalOptimizer(object):
         _optimizer_idx[0] += 1
         return self._optimizer_idx
 
+    def tracks(self):
+        """
+        Return the list of op classes that this opt applies to.
+        Return None to apply to all nodes.
+        """
+        return None
+
     def transform(self, node):
         """Transform a subgraph whose output is `node`.
@@ -791,9 +799,15 @@ class FromFunctionLocalOptimizer(LocalOptimizer):
                 id(self))
 
-def local_optimizer(*tracks):
+def local_optimizer(tracks):
     def decorator(f):
         """WRITEME"""
+        if tracks is not None:
+            if len(tracks) == 0:
+                raise ValueError("Use None instead of an empty list to apply to all nodes.", f.__module__, f.__name__)
+            for t in tracks:
+                if not (isinstance(t, type) or isinstance(t, op.Op)):
+                    raise ValueError("Tracks are op classes or instances", f.__module__, f.__name__)
         rval = FromFunctionLocalOptimizer(f, tracks)
         rval.__name__ = f.__name__
         return rval
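
The signature change from `*tracks` to `tracks` means callers now pass a single list (or None), and the new checks reject vague registrations up front. Hedged usage sketch (`MyOp` is a stand-in for a real op class):

    # Offered only on nodes whose op is an instance of MyOp:
    @gof.local_optimizer([MyOp])
    def local_simplify_myop(node):
        return False  # nothing rewritten in this sketch

    # Offered on every node; [] now raises ValueError, so "all nodes"
    # must be spelled None:
    @gof.local_optimizer(None)
    def local_catch_all(node):
        return False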
@@ -870,7 +884,7 @@ class OpSub(LocalOptimizer):
         return self.op1
 
     def tracks(self):
-        return [[self.op1]]
+        return [self.op1]
 
     def transform(self, node):
         if node.op != self.op1:
@@ -901,7 +915,7 @@ class OpRemove(LocalOptimizer):
         return self.op
 
     def tracks(self):
-        return [[self.op]]
+        return [self.op]
 
     def transform(self, node):
         if node.op != self.op:
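
Unwrapping `[[self.op1]]` to `[self.op1]` is not cosmetic: each entry of `tracks()` becomes a dictionary key in the EquilibriumOptimizer changes below, and a list key fails immediately:

    # What the dispatcher effectively does with each tracks() entry:
    local_optimizers_map.setdefault(self.op1, []).append(opt)    # ok, ops are hashable
    local_optimizers_map.setdefault([self.op1], []).append(opt)  # TypeError: unhashable type: 'list'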
@@ -1500,12 +1514,17 @@ class EquilibriumOptimizer(NavigatorOptimizer):
                 None,
                 ignore_newtrees=True,
                 failure_callback=failure_callback)
-        self.local_optimizers = []
+        self.local_optimizers_map = dict()
+        self.local_optimizers_all = []
         self.global_optimizers = []
         for opt in optimizers:
             if isinstance(opt, LocalOptimizer):
-                self.local_optimizers.append(opt)
+                if opt.tracks() is None:
+                    self.local_optimizers_all.append(opt)
+                else:
+                    for c in opt.tracks():
+                        self.local_optimizers_map.setdefault(c, []).append(opt)
             else:
                 self.global_optimizers.append(opt)
         self.max_depth = max_depth
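
Stripped of diff context, the constructor now builds a small dispatch table; a standalone sketch of the same logic:

    local_optimizers_all = []   # optimizers whose tracks() is None: try everywhere
    local_optimizers_map = {}   # op class or op instance -> [optimizers]

    def register(opt):
        tracks = opt.tracks()
        if tracks is None:
            local_optimizers_all.append(opt)
        else:
            for track in tracks:
                local_optimizers_map.setdefault(track, []).append(opt)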
@@ -1513,10 +1532,21 @@ class EquilibriumOptimizer(NavigatorOptimizer):
         assert self.max_use_ratio is not None, (
                 'max_use_ratio has to be a number')
 
+    def get_local_optimizers(self):
+        for opt in self.local_optimizers_all:
+            yield opt
+        # if repeat is not a problem we can drop the set
+        s = set()
+        for lopt in self.local_optimizers_map.values():
+            for opt in lopt:
+                if opt not in s:
+                    yield opt
+                    s.add(opt)
+
     def add_requirements(self, fgraph):
         super(EquilibriumOptimizer, self).add_requirements(fgraph)
         fgraph.attach_feature(ChangeTracker())
-        for opt in self.local_optimizers:
+        for opt in self.get_local_optimizers():
             opt.add_requirements(fgraph)
         for opt in self.global_optimizers:
             opt.add_requirements(fgraph)
@@ -1542,7 +1572,7 @@ class EquilibriumOptimizer(NavigatorOptimizer):
         time_opts = {}
         io_toposort_timing = []
         nb_nodes = []
-        for opt in self.global_optimizers + self.local_optimizers:
+        for opt in self.global_optimizers + list(self.get_local_optimizers()):
             global_process_count.setdefault(opt, 0)
             time_opts.setdefault(opt, 0)
@@ -1595,7 +1625,9 @@ class EquilibriumOptimizer(NavigatorOptimizer):
                 node = q.pop()
                 current_node = node
 
-                for lopt in self.local_optimizers:
+                for lopt in (self.local_optimizers_all +
+                             self.local_optimizers_map.get(type(node.op), []) +
+                             self.local_optimizers_map.get(node.op, [])):
                     t_opt = time.time()
                     lopt_change = self.process_node(fgraph, node, lopt)
                     time_opts[lopt] += time.time() - t_opt
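
This candidate list is where the commit pays off: the inner loop no longer visits every local optimizer, only the always-run ones plus two constant-time dict lookups. Both keys are tried so a track may be an op class (matched through `type(node.op)`) or a specific op instance (matched through `node.op`); a sketch of the lookup as a helper:

    def candidate_optimizers(self, node):
        # Unconditional optimizers first, then those registered for the
        # op's class, then those registered for this exact op instance.
        return (self.local_optimizers_all
                + self.local_optimizers_map.get(type(node.op), [])
                + self.local_optimizers_map.get(node.op, []))

Note one limitation of this scheme: class tracks are matched by exact type, so subclasses of a tracked op class are not offered the optimizer.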
@@ -1634,7 +1666,7 @@ class EquilibriumOptimizer(NavigatorOptimizer):
         print >> stream, "%s%s %s id=%i" % (
                 (' ' * level), self.__class__.__name__, name, id(self))
         if depth != 0:
-            for lopt in self.local_optimizers:
+            for lopt in self.get_local_optimizers():
                 lopt.print_summary(stream, level=(level + 2),
                                    depth=(depth - 1))
@@ -1654,7 +1686,7 @@ class EquilibriumOptimizer(NavigatorOptimizer):
                 start_nb_nodes, end_nb_nodes, max_nb_nodes)
         print >> stream, blanc, " time io_toposort %.3fs" % sum(
             io_toposort_timing)
-        s = sum([time_opts[o] for o in opt.local_optimizers])
+        s = sum([time_opts[o] for o in opt.get_local_optimizers()])
         print >> stream, blanc, " time in local optimizers %.3fs" % s
         s = sum([time_opts[o] for o in opt.global_optimizers])
         print >> stream, blanc, " time in global optimizers %.3fs" % s
@@ -1679,7 +1711,7 @@ class EquilibriumOptimizer(NavigatorOptimizer):
         not_used = 0
         not_used_time = 0
         process_count = {}
-        for o in opt.global_optimizers + opt.local_optimizers:
+        for o in opt.global_optimizers + list(opt.get_local_optimizers()):
             process_count.setdefault(o, 0)
         for count in loop_process_count:
             for o, v in count.iteritems():
@@ -1707,8 +1739,8 @@ class EquilibriumOptimizer(NavigatorOptimizer):
         #(opt, loop_timing, loop_process_count, max_nb_nodes,
         # global_opt_timing, nb_nodes, time_opts, io_toposort_timing) = prof1
-        local_optimizers = set(prof1[0].local_optimizers).union(
-            prof2[0].local_optimizers)
+        local_optimizers = set(prof1[0].get_local_optimizers()).union(
+            prof2[0].get_local_optimizers())
         global_optimizers = set(prof1[0].global_optimizers).union(
             prof2[0].global_optimizers)
         new_opt = EquilibriumOptimizer(
...
@@ -49,7 +49,7 @@ def info(*msg):
     _logger.info('INFO theano.scan: ' + ' '.join(msg))
 
-@gof.local_optimizer([None])
+@gof.local_optimizer([scan_op.Scan])
 def remove_constants_and_unused_inputs_scan(node):
     '''
     Move constants into the inner graph, and remove unused inputs.
@@ -1337,7 +1337,7 @@ def make_equiv(lo, li):
     return left, right
 
-@gof.local_optimizer([None])
+@gof.local_optimizer([scan_op.Scan])
 def scan_merge_inouts(node):
     if not isinstance(node.op, scan_op.Scan):
         return False
...
@@ -1645,7 +1645,7 @@ class Dot22(GemmRelated):
 _dot22 = Dot22()
 
-@local_optimizer([T._dot])
+@local_optimizer([T.Dot])
 def local_dot_to_dot22(node):
     # This works for tensor.outer too because basic.outer is a macro that
     # produces a dot(dimshuffle,dimshuffle) of form 4 below
@@ -2025,7 +2025,7 @@ blas_optdb.register('local_dot22_to_dot22scalar',
 #from opt import register_specialize, register_canonicalize
 #@register_specialize
-@local_optimizer([])
+@local_optimizer([T.sub, T.add])
 def local_print_as_we_go_along(node):
     if node.op in (T.sub, T.add):
         debugprint(node)
@@ -589,7 +589,7 @@ opt.local_mul_canonizer.add_simplifier(softmax_simplifier,
 if 0:
     @opt.register_specialize
-    @gof.local_optimizer([])
+    @gof.local_optimizer([tensor.add])
     def local_softmax_grad(node):
         '''dy*sm - DimShuffle{0,'x'}(sum{1}(dy*sm))*sm -> softmax_grad(dy,sm)'''
         #TODO what if the signs are changed?
@@ -1417,7 +1417,7 @@ def _is_const(z, val, approx=False):
 
 @opt.register_specialize
-@gof.local_optimizer([])
+@gof.local_optimizer([subtensor.AdvancedSubtensor])
 def local_advanced_indexing_crossentropy_onehot(node):
     log = None
     sm = None
...
Diff collapsed.
@@ -816,7 +816,7 @@ def multinomial(random_state, size=None, n=1, pvals=[0.5, 0.5],
     return op(random_state, size, n, pvals)
 
-@gof.local_optimizer([None])
+@gof.local_optimizer([RandomFunction])
 def random_make_inplace(node):
     op = node.op
     if isinstance(op, RandomFunction) and not op.inplace:
...
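
All of the call-site hunks above follow from the validation added to `local_optimizer`: `[]` and `[None]` are no longer accepted, so each registration has to name the ops it really handles. For `random_make_inplace`, the dispatcher now pre-filters nodes, leaving the body's `isinstance` guard as a cheap safety check. A schematic sketch of the before/after:

    # With the new checks in local_optimizer():
    #   @gof.local_optimizer([])     -> ValueError: use None to mean "all nodes"
    #   @gof.local_optimizer([None]) -> ValueError: tracks must be op classes/instances
    #   @gof.local_optimizer(None)   -> OK: offered every node
    @gof.local_optimizer([RandomFunction])  # OK: offered only RandomFunction nodes
    def random_make_inplace(node):
        op = node.op
        if isinstance(op, RandomFunction) and not op.inplace:
            pass  # guard retained for safety; dispatch already filtered by op type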