提交 1c974f4f 作者: Arnaud Bergeron

- Only run local optimizers on the ops they are registered for.

- Fix existing optimizers in base code to register properly.
上级 ab660ca3
...@@ -736,6 +736,14 @@ class LocalOptimizer(object): ...@@ -736,6 +736,14 @@ class LocalOptimizer(object):
_optimizer_idx[0] += 1 _optimizer_idx[0] += 1
return self._optimizer_idx return self._optimizer_idx
def tracks(self):
    """
    Return the list of ops that this opt applies to.

    Entries may be op classes or op instances. Return None to apply
    to all nodes; this base implementation always returns None.
    """
    return None
def transform(self, node): def transform(self, node):
"""Transform a subgraph whose output is `node`. """Transform a subgraph whose output is `node`.
...@@ -791,9 +799,15 @@ class FromFunctionLocalOptimizer(LocalOptimizer): ...@@ -791,9 +799,15 @@ class FromFunctionLocalOptimizer(LocalOptimizer):
id(self)) id(self))
def local_optimizer(*tracks): def local_optimizer(tracks):
def decorator(f): def decorator(f):
"""WRITEME""" """WRITEME"""
if tracks is not None:
if len(tracks) is 0:
raise ValueError, ("Use None instead of an empty list to apply to all nodes.", f.__module__, f.__name__)
for t in tracks:
if not (isinstance(t, type) or isinstance(t, op.Op)):
raise ValueError, ("Tracks are op classes or instances", f.__module__, f.__name__)
rval = FromFunctionLocalOptimizer(f, tracks) rval = FromFunctionLocalOptimizer(f, tracks)
rval.__name__ = f.__name__ rval.__name__ = f.__name__
return rval return rval
...@@ -870,7 +884,7 @@ class OpSub(LocalOptimizer): ...@@ -870,7 +884,7 @@ class OpSub(LocalOptimizer):
return self.op1 return self.op1
def tracks(self): def tracks(self):
return [[self.op1]] return [self.op1]
def transform(self, node): def transform(self, node):
if node.op != self.op1: if node.op != self.op1:
...@@ -901,7 +915,7 @@ class OpRemove(LocalOptimizer): ...@@ -901,7 +915,7 @@ class OpRemove(LocalOptimizer):
return self.op return self.op
def tracks(self): def tracks(self):
return [[self.op]] return [self.op]
def transform(self, node): def transform(self, node):
if node.op != self.op: if node.op != self.op:
...@@ -1500,12 +1514,17 @@ class EquilibriumOptimizer(NavigatorOptimizer): ...@@ -1500,12 +1514,17 @@ class EquilibriumOptimizer(NavigatorOptimizer):
None, None,
ignore_newtrees=True, ignore_newtrees=True,
failure_callback=failure_callback) failure_callback=failure_callback)
self.local_optimizers = [] self.local_optimizers_map = dict()
self.local_optimizers_all = []
self.global_optimizers = [] self.global_optimizers = []
for opt in optimizers: for opt in optimizers:
if isinstance(opt, LocalOptimizer): if isinstance(opt, LocalOptimizer):
self.local_optimizers.append(opt) if opt.tracks is None:
self.local_optimizers_all.append(opt)
else:
for c in opt.tracks():
self.local_optimizers_map.setdefault(c, []).append(opt)
else: else:
self.global_optimizers.append(opt) self.global_optimizers.append(opt)
self.max_depth = max_depth self.max_depth = max_depth
...@@ -1513,10 +1532,21 @@ class EquilibriumOptimizer(NavigatorOptimizer): ...@@ -1513,10 +1532,21 @@ class EquilibriumOptimizer(NavigatorOptimizer):
assert self.max_use_ratio is not None, ( assert self.max_use_ratio is not None, (
'max_use_ratio has to be a number') 'max_use_ratio has to be a number')
def get_local_optimizers(self):
    """
    Iterate over every local optimizer held by this object, once each.

    Optimizers in `self.local_optimizers_all` are yielded first, in
    list order.  Those stored in `self.local_optimizers_map` follow;
    an optimizer listed under several keys of the map is yielded
    only the first time it is encountered.
    """
    for opt in self.local_optimizers_all:
        yield opt
    # An optimizer that tracks several ops sits in several of the
    # map's lists, so remember what was already produced.
    # If repeats are not a problem the set can be dropped.
    seen = set()
    for lopts in self.local_optimizers_map.values():
        for opt in lopts:
            if opt not in seen:
                seen.add(opt)
                yield opt
def add_requirements(self, fgraph): def add_requirements(self, fgraph):
super(EquilibriumOptimizer, self).add_requirements(fgraph) super(EquilibriumOptimizer, self).add_requirements(fgraph)
fgraph.attach_feature(ChangeTracker()) fgraph.attach_feature(ChangeTracker())
for opt in self.local_optimizers: for opt in self.get_local_optimizers():
opt.add_requirements(fgraph) opt.add_requirements(fgraph)
for opt in self.global_optimizers: for opt in self.global_optimizers:
opt.add_requirements(fgraph) opt.add_requirements(fgraph)
...@@ -1542,7 +1572,7 @@ class EquilibriumOptimizer(NavigatorOptimizer): ...@@ -1542,7 +1572,7 @@ class EquilibriumOptimizer(NavigatorOptimizer):
time_opts = {} time_opts = {}
io_toposort_timing = [] io_toposort_timing = []
nb_nodes = [] nb_nodes = []
for opt in self.global_optimizers + self.local_optimizers: for opt in self.global_optimizers + list(self.get_local_optimizers()):
global_process_count.setdefault(opt, 0) global_process_count.setdefault(opt, 0)
time_opts.setdefault(opt, 0) time_opts.setdefault(opt, 0)
...@@ -1595,7 +1625,9 @@ class EquilibriumOptimizer(NavigatorOptimizer): ...@@ -1595,7 +1625,9 @@ class EquilibriumOptimizer(NavigatorOptimizer):
node = q.pop() node = q.pop()
current_node = node current_node = node
for lopt in self.local_optimizers: for lopt in (self.local_optimizers_all +
self.local_optimizers_map.get(type(node.op), []) +
self.local_optimizers_map.get(node.op, [])):
t_opt = time.time() t_opt = time.time()
lopt_change = self.process_node(fgraph, node, lopt) lopt_change = self.process_node(fgraph, node, lopt)
time_opts[lopt] += time.time() - t_opt time_opts[lopt] += time.time() - t_opt
...@@ -1634,7 +1666,7 @@ class EquilibriumOptimizer(NavigatorOptimizer): ...@@ -1634,7 +1666,7 @@ class EquilibriumOptimizer(NavigatorOptimizer):
print >> stream, "%s%s %s id=%i" % ( print >> stream, "%s%s %s id=%i" % (
(' ' * level), self.__class__.__name__, name, id(self)) (' ' * level), self.__class__.__name__, name, id(self))
if depth != 0: if depth != 0:
for lopt in self.local_optimizers: for lopt in self.get_local_optimizers():
lopt.print_summary(stream, level=(level + 2), lopt.print_summary(stream, level=(level + 2),
depth=(depth - 1)) depth=(depth - 1))
...@@ -1654,7 +1686,7 @@ class EquilibriumOptimizer(NavigatorOptimizer): ...@@ -1654,7 +1686,7 @@ class EquilibriumOptimizer(NavigatorOptimizer):
start_nb_nodes, end_nb_nodes, max_nb_nodes) start_nb_nodes, end_nb_nodes, max_nb_nodes)
print >> stream, blanc, " time io_toposort %.3fs" % sum( print >> stream, blanc, " time io_toposort %.3fs" % sum(
io_toposort_timing) io_toposort_timing)
s = sum([time_opts[o] for o in opt.local_optimizers]) s = sum([time_opts[o] for o in opt.get_local_optimizers()])
print >> stream, blanc, " time in local optimizers %.3fs" % s print >> stream, blanc, " time in local optimizers %.3fs" % s
s = sum([time_opts[o] for o in opt.global_optimizers]) s = sum([time_opts[o] for o in opt.global_optimizers])
print >> stream, blanc, " time in global optimizers %.3fs" % s print >> stream, blanc, " time in global optimizers %.3fs" % s
...@@ -1679,7 +1711,7 @@ class EquilibriumOptimizer(NavigatorOptimizer): ...@@ -1679,7 +1711,7 @@ class EquilibriumOptimizer(NavigatorOptimizer):
not_used = 0 not_used = 0
not_used_time = 0 not_used_time = 0
process_count = {} process_count = {}
for o in opt.global_optimizers + opt.local_optimizers: for o in opt.global_optimizers + opt.get_local_optimizers():
process_count.setdefault(o, 0) process_count.setdefault(o, 0)
for count in loop_process_count: for count in loop_process_count:
for o, v in count.iteritems(): for o, v in count.iteritems():
...@@ -1707,8 +1739,8 @@ class EquilibriumOptimizer(NavigatorOptimizer): ...@@ -1707,8 +1739,8 @@ class EquilibriumOptimizer(NavigatorOptimizer):
#(opt, loop_timing, loop_process_count, max_nb_nodes, #(opt, loop_timing, loop_process_count, max_nb_nodes,
# global_opt_timing, nb_nodes, time_opts, io_toposort_timing) = prof1 # global_opt_timing, nb_nodes, time_opts, io_toposort_timing) = prof1
local_optimizers = set(prof1[0].local_optimizers).union( local_optimizers = set(prof1[0].get_local_optimizers()).union(
prof2[0].local_optimizers) prof2[0].get_local_optimizers())
global_optimizers = set(prof1[0].global_optimizers).union( global_optimizers = set(prof1[0].global_optimizers).union(
prof2[0].global_optimizers) prof2[0].global_optimizers)
new_opt = EquilibriumOptimizer( new_opt = EquilibriumOptimizer(
......
...@@ -49,7 +49,7 @@ def info(*msg): ...@@ -49,7 +49,7 @@ def info(*msg):
_logger.info('INFO theano.scan: ' + ' '.join(msg)) _logger.info('INFO theano.scan: ' + ' '.join(msg))
@gof.local_optimizer([None]) @gof.local_optimizer([scan_op.Scan])
def remove_constants_and_unused_inputs_scan(node): def remove_constants_and_unused_inputs_scan(node):
''' '''
Move constants into the inner graph, and remove unused inputs. Move constants into the inner graph, and remove unused inputs.
...@@ -1337,7 +1337,7 @@ def make_equiv(lo, li): ...@@ -1337,7 +1337,7 @@ def make_equiv(lo, li):
return left, right return left, right
@gof.local_optimizer([None]) @gof.local_optimizer([scan_op.Scan])
def scan_merge_inouts(node): def scan_merge_inouts(node):
if not isinstance(node.op, scan_op.Scan): if not isinstance(node.op, scan_op.Scan):
return False return False
......
...@@ -1645,7 +1645,7 @@ class Dot22(GemmRelated): ...@@ -1645,7 +1645,7 @@ class Dot22(GemmRelated):
_dot22 = Dot22() _dot22 = Dot22()
@local_optimizer([T._dot]) @local_optimizer([T.Dot])
def local_dot_to_dot22(node): def local_dot_to_dot22(node):
# This works for tensor.outer too because basic.outer is a macro that # This works for tensor.outer too because basic.outer is a macro that
# produces a dot(dimshuffle,dimshuffle) of form 4 below # produces a dot(dimshuffle,dimshuffle) of form 4 below
...@@ -2025,7 +2025,7 @@ blas_optdb.register('local_dot22_to_dot22scalar', ...@@ -2025,7 +2025,7 @@ blas_optdb.register('local_dot22_to_dot22scalar',
#from opt import register_specialize, register_canonicalize #from opt import register_specialize, register_canonicalize
#@register_specialize #@register_specialize
@local_optimizer([]) @local_optimizer([T.sub, T.add])
def local_print_as_we_go_along(node): def local_print_as_we_go_along(node):
if node.op in (T.sub, T.add): if node.op in (T.sub, T.add):
debugprint(node) debugprint(node)
...@@ -589,7 +589,7 @@ opt.local_mul_canonizer.add_simplifier(softmax_simplifier, ...@@ -589,7 +589,7 @@ opt.local_mul_canonizer.add_simplifier(softmax_simplifier,
if 0: if 0:
@opt.register_specialize @opt.register_specialize
@gof.local_optimizer([]) @gof.local_optimizer([tensor.add])
def local_softmax_grad(node): def local_softmax_grad(node):
'''dy*sm - DimShuffle{0,'x'}(sum{1}(dy*sm))*sm -> softmax_grad(dy,sm)''' '''dy*sm - DimShuffle{0,'x'}(sum{1}(dy*sm))*sm -> softmax_grad(dy,sm)'''
#TODO what if the signs are changed? #TODO what if the signs are changed?
...@@ -1417,7 +1417,7 @@ def _is_const(z, val, approx=False): ...@@ -1417,7 +1417,7 @@ def _is_const(z, val, approx=False):
@opt.register_specialize @opt.register_specialize
@gof.local_optimizer([]) @gof.local_optimizer([subtensor.AdvancedSubtensor])
def local_advanced_indexing_crossentropy_onehot(node): def local_advanced_indexing_crossentropy_onehot(node):
log = None log = None
sm = None sm = None
......
...@@ -347,7 +347,7 @@ compile.optdb['canonicalize'].register( ...@@ -347,7 +347,7 @@ compile.optdb['canonicalize'].register(
@register_canonicalize @register_canonicalize
@register_stabilize @register_stabilize
@gof.local_optimizer([None]) @gof.local_optimizer([T.Dot])
def local_0_dot_x(node): def local_0_dot_x(node):
if not isinstance(node.op, T.Dot): if not isinstance(node.op, T.Dot):
return False return False
...@@ -390,7 +390,7 @@ def local_0_dot_x(node): ...@@ -390,7 +390,7 @@ def local_0_dot_x(node):
###################### ######################
@gof.local_optimizer([None, None]) @gof.local_optimizer([DimShuffle])
def local_dimshuffle_lift(node): def local_dimshuffle_lift(node):
""" """
"Lifts" DimShuffle through Elemwise operations and merges "Lifts" DimShuffle through Elemwise operations and merges
...@@ -431,7 +431,7 @@ def local_dimshuffle_lift(node): ...@@ -431,7 +431,7 @@ def local_dimshuffle_lift(node):
@register_canonicalize @register_canonicalize
@gof.local_optimizer([]) @gof.local_optimizer([T.DimShuffle])
def local_lift_transpose_through_dot(node): def local_lift_transpose_through_dot(node):
""" """
dot(x,y).T -> dot(y.T, x.T) dot(x,y).T -> dot(y.T, x.T)
...@@ -456,7 +456,7 @@ def local_lift_transpose_through_dot(node): ...@@ -456,7 +456,7 @@ def local_lift_transpose_through_dot(node):
return [T.dot(y.T, x.T)] return [T.dot(y.T, x.T)]
@gof.local_optimizer([]) @gof.local_optimizer([DimShuffle])
def dimshuffle_as_view(node): def dimshuffle_as_view(node):
op = node.op op = node.op
if not isinstance(op, DimShuffle) or op.inplace: if not isinstance(op, DimShuffle) or op.inplace:
...@@ -476,7 +476,7 @@ register_specialize(local_dimshuffle_lift) ...@@ -476,7 +476,7 @@ register_specialize(local_dimshuffle_lift)
@register_canonicalize @register_canonicalize
@gof.local_optimizer([]) @gof.local_optimizer([T.DimShuffle])
def local_dimshuffle_no_inplace_at_canonicalize(node): def local_dimshuffle_no_inplace_at_canonicalize(node):
if isinstance(node.op, T.DimShuffle) and node.op.inplace: if isinstance(node.op, T.DimShuffle) and node.op.inplace:
return [T.DimShuffle(node.op.input_broadcastable, return [T.DimShuffle(node.op.input_broadcastable,
...@@ -1213,9 +1213,10 @@ def local_shape_to_shape_i(node): ...@@ -1213,9 +1213,10 @@ def local_shape_to_shape_i(node):
return [shape_feature.make_vector_shape(node.inputs[0])] return [shape_feature.make_vector_shape(node.inputs[0])]
# TODO: Not sure what type of node we are expecting here
@register_specialize @register_specialize
@register_canonicalize @register_canonicalize
@gof.local_optimizer([T._shape]) @gof.local_optimizer(None)
def local_track_shape_i(node): def local_track_shape_i(node):
try: try:
shape_feature = node.fgraph.shape_feature shape_feature = node.fgraph.shape_feature
...@@ -1415,7 +1416,7 @@ def local_remove_useless_assert(node): ...@@ -1415,7 +1416,7 @@ def local_remove_useless_assert(node):
return [assert_(node.inputs[0], *cond)] return [assert_(node.inputs[0], *cond)]
@gof.local_optimizer([T.Alloc]) @gof.local_optimizer([T.Elemwise])
def local_alloc_elemwise(node): def local_alloc_elemwise(node):
""" """
elemwise(alloc(x, shp), ..., y.TensorType(BROADCAST CONDITION)) elemwise(alloc(x, shp), ..., y.TensorType(BROADCAST CONDITION))
...@@ -1534,7 +1535,7 @@ else: ...@@ -1534,7 +1535,7 @@ else:
@register_canonicalize @register_canonicalize
@gof.local_optimizer([]) @gof.local_optimizer([T.Elemwise])
def local_upcast_elemwise_constant_inputs(node): def local_upcast_elemwise_constant_inputs(node):
"""This explicitly upcasts constant inputs to elemwise Ops, when """This explicitly upcasts constant inputs to elemwise Ops, when
those Ops do implicit upcasting anyway. those Ops do implicit upcasting anyway.
...@@ -1682,7 +1683,7 @@ def local_useless_subtensor(node): ...@@ -1682,7 +1683,7 @@ def local_useless_subtensor(node):
@register_canonicalize @register_canonicalize
@gof.local_optimizer([]) @gof.local_optimizer([Subtensor])
def local_subtensor_lift(node): def local_subtensor_lift(node):
""" """
unary(x)[idx] -> unary(x[idx])#any broadcast pattern. unary(x)[idx] -> unary(x[idx])#any broadcast pattern.
...@@ -1892,7 +1893,7 @@ def merge_two_slices(slice1, len1, slice2, len2): ...@@ -1892,7 +1893,7 @@ def merge_two_slices(slice1, len1, slice2, len2):
@register_canonicalize @register_canonicalize
@register_specialize @register_specialize
@gof.local_optimizer([]) @gof.local_optimizer([Subtensor])
def local_subtensor_merge(node): def local_subtensor_merge(node):
""" """
Refactored optimization to deal with all cases of tensor merging. Refactored optimization to deal with all cases of tensor merging.
...@@ -1954,7 +1955,7 @@ def local_subtensor_merge(node): ...@@ -1954,7 +1955,7 @@ def local_subtensor_merge(node):
@register_canonicalize @register_canonicalize
@register_specialize @register_specialize
@gof.local_optimizer([]) @gof.local_optimizer([Subtensor])
def local_subtensor_of_alloc(node): def local_subtensor_of_alloc(node):
"""alloc[x:y] -> alloc""" """alloc[x:y] -> alloc"""
if not isinstance(node.op, Subtensor): if not isinstance(node.op, Subtensor):
...@@ -2007,7 +2008,7 @@ def local_subtensor_of_alloc(node): ...@@ -2007,7 +2008,7 @@ def local_subtensor_of_alloc(node):
@register_canonicalize @register_canonicalize
@gof.local_optimizer([None]) @gof.local_optimizer([T.add])
def local_IncSubtensor_serialize(node): def local_IncSubtensor_serialize(node):
""" """
When using Subtensor, gradient graphs can be ugly. When using Subtensor, gradient graphs can be ugly.
...@@ -2079,7 +2080,7 @@ compile.optdb.register('pre_local_IncSubtensor_serialize', ...@@ -2079,7 +2080,7 @@ compile.optdb.register('pre_local_IncSubtensor_serialize',
#after priority 50 Destructive inplace operations #after priority 50 Destructive inplace operations
#gemm is the first one now, at priority 70 #gemm is the first one now, at priority 70
@gof.local_optimizer([None]) @gof.local_optimizer([IncSubtensor]) # XXX: GPU
def local_inplace_setsubtensor(node): def local_inplace_setsubtensor(node):
""" """
Also work for GpuIncSubtensor Also work for GpuIncSubtensor
...@@ -2098,7 +2099,7 @@ compile.optdb.register('local_inplace_setsubtensor', ...@@ -2098,7 +2099,7 @@ compile.optdb.register('local_inplace_setsubtensor',
'fast_run', 'inplace') # DEBUG 'fast_run', 'inplace') # DEBUG
@gof.local_optimizer([None]) @gof.local_optimizer([AdvancedIncSubtensor1]) # XXX: GPU
def local_inplace_incsubtensor1(node): def local_inplace_incsubtensor1(node):
""" also work for GpuAdvancedIncSubtensor1 """ """ also work for GpuAdvancedIncSubtensor1 """
if isinstance(node.op, AdvancedIncSubtensor1) and not node.op.inplace: if isinstance(node.op, AdvancedIncSubtensor1) and not node.op.inplace:
...@@ -2116,7 +2117,7 @@ compile.optdb.register('local_inplace_incsubtensor1', ...@@ -2116,7 +2117,7 @@ compile.optdb.register('local_inplace_incsubtensor1',
@register_canonicalize @register_canonicalize
@register_stabilize @register_stabilize
@gof.local_optimizer([None]) @gof.local_optimizer([IncSubtensor])
def local_incsubtensor_of_allocs(node): def local_incsubtensor_of_allocs(node):
""" """
IncSubtensor(x, zeros, idx) -> x IncSubtensor(x, zeros, idx) -> x
...@@ -2139,7 +2140,7 @@ def local_incsubtensor_of_allocs(node): ...@@ -2139,7 +2140,7 @@ def local_incsubtensor_of_allocs(node):
@register_canonicalize @register_canonicalize
@register_stabilize @register_stabilize
@gof.local_optimizer([None]) @gof.local_optimizer([IncSubtensor])
def local_setsubtensor_of_allocs(node): def local_setsubtensor_of_allocs(node):
""" """
SetSubtensor(x, x[idx], idx) -> x SetSubtensor(x, x[idx], idx) -> x
...@@ -2286,7 +2287,7 @@ def local_join_1(node): ...@@ -2286,7 +2287,7 @@ def local_join_1(node):
############### ###############
@register_canonicalize @register_canonicalize
@gof.local_optimizer([]) @gof.local_optimizer([T.Elemwise])
def local_remove_switch_const_cond(node): def local_remove_switch_const_cond(node):
""" """
This optimization makes the following changes in the graph: This optimization makes the following changes in the graph:
...@@ -2369,7 +2370,7 @@ def local_mul_switch_sink(node): ...@@ -2369,7 +2370,7 @@ def local_mul_switch_sink(node):
@register_canonicalize @register_canonicalize
@gof.local_optimizer([T.true_div]) @gof.local_optimizer([T.true_div, T.int_div, T.floor_div])
def local_div_switch_sink(node): def local_div_switch_sink(node):
""" """
This optimization makes the following changes in the graph: This optimization makes the following changes in the graph:
...@@ -2413,7 +2414,7 @@ def local_div_switch_sink(node): ...@@ -2413,7 +2414,7 @@ def local_div_switch_sink(node):
################ ################
@register_canonicalize @register_canonicalize
@register_stabilize @register_stabilize
@gof.local_optimizer([]) @gof.local_optimizer([T.Flatten])
def local_flatten_lift(node): def local_flatten_lift(node):
""" """
Flatten(UnaryElemwise(x)) -> UnaryElemwise(Flatten(x)) Flatten(UnaryElemwise(x)) -> UnaryElemwise(Flatten(x))
...@@ -2434,7 +2435,7 @@ def local_flatten_lift(node): ...@@ -2434,7 +2435,7 @@ def local_flatten_lift(node):
################## ##################
@gof.local_optimizer([None, None]) @gof.local_optimizer([T.Reshape])
def local_reshape_chain(node): def local_reshape_chain(node):
""" """
Reshape(Reshape(shape1),shape2) -> Reshape(shape2) Reshape(Reshape(shape1),shape2) -> Reshape(shape2)
...@@ -2462,7 +2463,7 @@ register_canonicalize(local_reshape_chain) ...@@ -2462,7 +2463,7 @@ register_canonicalize(local_reshape_chain)
@register_canonicalize @register_canonicalize
@register_stabilize @register_stabilize
@gof.local_optimizer([]) @gof.local_optimizer([T.Reshape])
def local_reshape_lift(node): def local_reshape_lift(node):
""" """
Reshape(UnaryElemwise(x)) -> UnaryElemwise(Reshape(x)) Reshape(UnaryElemwise(x)) -> UnaryElemwise(Reshape(x))
...@@ -2482,7 +2483,7 @@ def local_reshape_lift(node): ...@@ -2482,7 +2483,7 @@ def local_reshape_lift(node):
if 0: if 0:
# TODO: Test that this optimization works. # TODO: Test that this optimization works.
@register_canonicalize @register_canonicalize
@gof.local_optimizer([]) @gof.local_optimizer([T.Reshape])
def local_scalar_reshape(node): def local_scalar_reshape(node):
"""Eliminate reshape Ops whose inputs and outputs are scalars """ """Eliminate reshape Ops whose inputs and outputs are scalars """
if isinstance(node.op, T.Reshape): if isinstance(node.op, T.Reshape):
...@@ -2498,7 +2499,7 @@ if 0: ...@@ -2498,7 +2499,7 @@ if 0:
# TODO: Remember to take into account the new sum dtype argument if this # TODO: Remember to take into account the new sum dtype argument if this
# optimization is enabled. # optimization is enabled.
@register_canonicalize @register_canonicalize
@gof.local_optimizer([]) @gof.local_optimizer([T.Sum])
def local_sum_over_empty(node): def local_sum_over_empty(node):
if isinstance(node.op, T.Sum): if isinstance(node.op, T.Sum):
# This optimization needs ShapeOpt and fgraph.shape_feature # This optimization needs ShapeOpt and fgraph.shape_feature
...@@ -2520,7 +2521,7 @@ if 0: ...@@ -2520,7 +2521,7 @@ if 0:
################## ##################
@gof.local_optimizer([None, T.fill]) @gof.local_optimizer([T.Elemwise])
def local_fill_cut(node): def local_fill_cut(node):
""" """
f(fill(a,b), c) -> f(b, c) f(fill(a,b), c) -> f(b, c)
...@@ -2574,7 +2575,7 @@ register_canonicalize(local_fill_cut) ...@@ -2574,7 +2575,7 @@ register_canonicalize(local_fill_cut)
register_canonicalize(gof.OpRemove(T.tensor_copy), name='remove_tensor_copy') register_canonicalize(gof.OpRemove(T.tensor_copy), name='remove_tensor_copy')
@gof.local_optimizer([None, T.fill]) @gof.local_optimizer([T.Elemwise])
def local_fill_sink(node): def local_fill_sink(node):
""" """
f(fill(a, b), fill(c, d), e) -> fill(a, fill(c, f(b, d, e))) f(fill(a, b), fill(c, d), e) -> fill(a, fill(c, f(b, d, e)))
...@@ -2662,8 +2663,7 @@ class Canonizer(gof.LocalOptimizer): ...@@ -2662,8 +2663,7 @@ class Canonizer(gof.LocalOptimizer):
self.external_simplifiers.append((reason, simplifier)) self.external_simplifiers.append((reason, simplifier))
def tracks(self): def tracks(self):
return [[self.main, None], [self.inverse, None], return [self.main, self.inverse, self.reciprocal]
[self.reciprocal, None]]
def get_num_denum(self, input): def get_num_denum(self, input):
""" """
...@@ -3051,7 +3051,7 @@ register_canonicalize(local_neg_to_mul) ...@@ -3051,7 +3051,7 @@ register_canonicalize(local_neg_to_mul)
@register_specialize @register_specialize
@gof.local_optimizer([]) @gof.local_optimizer([T.Sum])
def local_sum_mul_by_scalar(node): def local_sum_mul_by_scalar(node):
"""sum(scalar * smth) -> scalar * sum(smth) """sum(scalar * smth) -> scalar * sum(smth)
sum(-smth) -> -sum(smth) sum(-smth) -> -sum(smth)
...@@ -3088,7 +3088,7 @@ def local_sum_mul_by_scalar(node): ...@@ -3088,7 +3088,7 @@ def local_sum_mul_by_scalar(node):
@register_specialize @register_specialize
@gof.local_optimizer([]) @gof.local_optimizer([T.Elemwise])
def local_elemwise_sub_zeros(node): def local_elemwise_sub_zeros(node):
""" """
Elemwise{sub}(X,X) -> zeros_like(X) Elemwise{sub}(X,X) -> zeros_like(X)
...@@ -3102,7 +3102,7 @@ def local_elemwise_sub_zeros(node): ...@@ -3102,7 +3102,7 @@ def local_elemwise_sub_zeros(node):
@register_canonicalize @register_canonicalize
@register_specialize @register_specialize
@gof.local_optimizer([]) @gof.local_optimizer([T.Sum])
def local_sum_div_dimshuffle(node): def local_sum_div_dimshuffle(node):
'''sum(a / dimshuffle{...}(b), axis=l) -> sum(a, axis={...}) / b, '''sum(a / dimshuffle{...}(b), axis=l) -> sum(a, axis={...}) / b,
if dimension l of the DimShuffle is 'x'.''' if dimension l of the DimShuffle is 'x'.'''
...@@ -3191,7 +3191,7 @@ def local_sum_div_dimshuffle(node): ...@@ -3191,7 +3191,7 @@ def local_sum_div_dimshuffle(node):
@register_canonicalize @register_canonicalize
@gof.local_optimizer([]) @gof.local_optimizer([T.Sum])
def local_sum_all_to_none(node): def local_sum_all_to_none(node):
"""Sum{0,1,...N} -> Sum{}""" """Sum{0,1,...N} -> Sum{}"""
if isinstance(node.op, T.Sum): if isinstance(node.op, T.Sum):
...@@ -3204,7 +3204,7 @@ def local_sum_all_to_none(node): ...@@ -3204,7 +3204,7 @@ def local_sum_all_to_none(node):
@register_canonicalize @register_canonicalize
@gof.local_optimizer([]) @gof.local_optimizer([T.Sum])
def local_sum_sum(node): def local_sum_sum(node):
""" """
Sum(Sum()) -> Sum Sum(Sum()) -> Sum
...@@ -3272,7 +3272,7 @@ def local_sum_sum(node): ...@@ -3272,7 +3272,7 @@ def local_sum_sum(node):
@register_canonicalize @register_canonicalize
@gof.local_optimizer([]) @gof.local_optimizer([T.CAReduce])
def local_cut_useless_reduce(node): def local_cut_useless_reduce(node):
"""Sum(a, axis=[]) -> a """ """Sum(a, axis=[]) -> a """
if isinstance(node.op, T.CAReduce): if isinstance(node.op, T.CAReduce):
...@@ -3288,7 +3288,7 @@ def local_cut_useless_reduce(node): ...@@ -3288,7 +3288,7 @@ def local_cut_useless_reduce(node):
# #
#@register_canonicalize #@register_canonicalize
@register_specialize @register_specialize
@gof.local_optimizer([]) @gof.local_optimizer([T.CAReduce])
def local_reduce_broadcastable(node): def local_reduce_broadcastable(node):
"""Remove reduction over broadcastable dimensions""" """Remove reduction over broadcastable dimensions"""
if isinstance(node.op, T.CAReduce): if isinstance(node.op, T.CAReduce):
...@@ -3327,7 +3327,7 @@ def local_reduce_broadcastable(node): ...@@ -3327,7 +3327,7 @@ def local_reduce_broadcastable(node):
@register_specialize @register_specialize
@gof.local_optimizer([]) @gof.local_optimizer([T.Sum])
def local_sum_alloc(node): def local_sum_alloc(node):
""" sum(alloc(constant,shapes...)) => constant*prod(shapes)""" """ sum(alloc(constant,shapes...)) => constant*prod(shapes)"""
if isinstance(node.op, T.Sum): if isinstance(node.op, T.Sum):
...@@ -3734,7 +3734,7 @@ def local_abs_lift(node): ...@@ -3734,7 +3734,7 @@ def local_abs_lift(node):
@register_specialize @register_specialize
@gof.local_optimizer([]) @gof.local_optimizer([T.mul])
def local_abs_merge(node): def local_abs_merge(node):
""" """
merge abs generated by local_abs_lift when the canonizer don't merge abs generated by local_abs_lift when the canonizer don't
...@@ -3909,8 +3909,7 @@ def attempt_distribution(factor, num, denum): ...@@ -3909,8 +3909,7 @@ def attempt_distribution(factor, num, denum):
neg_pairs))), num, denum neg_pairs))), num, denum
@gof.local_optimizer([T.mul, T.add, T.mul], [T.mul, T.sub, T.mul], @gof.local_optimizer([T.mul])
[T.mul, T.add, T.true_div], [T.mul, T.sub, T.true_div])
def local_greedy_distributor(node): def local_greedy_distributor(node):
""" """
This optimization tries to apply distributivity of multiplication This optimization tries to apply distributivity of multiplication
...@@ -3976,7 +3975,7 @@ register_canonicalize(local_greedy_distributor) ...@@ -3976,7 +3975,7 @@ register_canonicalize(local_greedy_distributor)
register_stabilize(local_greedy_distributor) register_stabilize(local_greedy_distributor)
@gof.local_optimizer([None]) @gof.local_optimizer(None)
def constant_folding(node): def constant_folding(node):
for input in node.inputs: for input in node.inputs:
if not isinstance(input, Constant): if not isinstance(input, Constant):
......
...@@ -816,7 +816,7 @@ def multinomial(random_state, size=None, n=1, pvals=[0.5, 0.5], ...@@ -816,7 +816,7 @@ def multinomial(random_state, size=None, n=1, pvals=[0.5, 0.5],
return op(random_state, size, n, pvals) return op(random_state, size, n, pvals)
@gof.local_optimizer([None]) @gof.local_optimizer([RandomFunction])
def random_make_inplace(node): def random_make_inplace(node):
op = node.op op = node.op
if isinstance(op, RandomFunction) and not op.inplace: if isinstance(op, RandomFunction) and not op.inplace:
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论