提交 0b528507 authored 作者: Frederic's avatar Frederic

some pep8

上级 8e80196b
...@@ -514,7 +514,8 @@ class MergeFeature(object): ...@@ -514,7 +514,8 @@ class MergeFeature(object):
continue continue
inputs_match = all(node_in is cand_in inputs_match = all(node_in is cand_in
for node_in, cand_in in zip(node.inputs, candidate.inputs)) for node_in, cand_in in zip(node.inputs,
candidate.inputs))
if inputs_match and node.op == candidate.op: if inputs_match and node.op == candidate.op:
if (node, candidate) in self.blacklist: if (node, candidate) in self.blacklist:
# They were already tried, and there was an error # They were already tried, and there was an error
...@@ -751,7 +752,7 @@ class LocalOptimizer(object): ...@@ -751,7 +752,7 @@ class LocalOptimizer(object):
""" """
raise utils.MethodNotDefined("transform", raise utils.MethodNotDefined("transform",
type(self), self.__class__.__name__) type(self), self.__class__.__name__)
def add_requirements(self, fgraph): def add_requirements(self, fgraph):
""" """
...@@ -781,7 +782,7 @@ class FromFunctionLocalOptimizer(LocalOptimizer): ...@@ -781,7 +782,7 @@ class FromFunctionLocalOptimizer(LocalOptimizer):
def __str__(self): def __str__(self):
return getattr(self, '__name__', return getattr(self, '__name__',
'<FromFunctionLocalOptimizer instance>') '<FromFunctionLocalOptimizer instance>')
def print_summary(self, stream=sys.stdout, level=0, depth=-1): def print_summary(self, stream=sys.stdout, level=0, depth=-1):
print >> stream, "%s%s id=%i" % ( print >> stream, "%s%s id=%i" % (
...@@ -811,8 +812,8 @@ class LocalOptGroup(LocalOptimizer): ...@@ -811,8 +812,8 @@ class LocalOptGroup(LocalOptimizer):
def __str__(self): def __str__(self):
return getattr(self, '__name__', return getattr(self, '__name__',
('<theano.gof.opt.LocalOptGroup instance>' ('<theano.gof.opt.LocalOptGroup instance>' +
+ str([str(o) for o in self.opts]))) str([str(o) for o in self.opts])))
def transform(self, node): def transform(self, node):
for opt in self.opts: for opt in self.opts:
...@@ -968,7 +969,7 @@ class PatternSub(LocalOptimizer): ...@@ -968,7 +969,7 @@ class PatternSub(LocalOptimizer):
""" """
def __init__(self, in_pattern, out_pattern, allow_multiple_clients=False, def __init__(self, in_pattern, out_pattern, allow_multiple_clients=False,
skip_identities_fn=None, name=None, pdb=False): skip_identities_fn=None, name=None, pdb=False):
""" """
Creates a PatternSub that replaces occurrences of Creates a PatternSub that replaces occurrences of
in_pattern by occurrences of out_pattern. in_pattern by occurrences of out_pattern.
...@@ -989,10 +990,10 @@ class PatternSub(LocalOptimizer): ...@@ -989,10 +990,10 @@ class PatternSub(LocalOptimizer):
self.op = self.in_pattern['pattern'][0] self.op = self.in_pattern['pattern'][0]
else: else:
raise TypeError("The pattern to search for must start with " raise TypeError("The pattern to search for must start with "
"a specific Op instance.") "a specific Op instance.")
self.__doc__ = (self.__class__.__doc__ self.__doc__ = (self.__class__.__doc__ +
+ "\n\nThis instance does: " "\n\nThis instance does: " +
+ str(self) + "\n") str(self) + "\n")
self.allow_multiple_clients = allow_multiple_clients self.allow_multiple_clients = allow_multiple_clients
self.skip_identities_fn = skip_identities_fn self.skip_identities_fn = skip_identities_fn
if name: if name:
...@@ -1035,7 +1036,7 @@ class PatternSub(LocalOptimizer): ...@@ -1035,7 +1036,7 @@ class PatternSub(LocalOptimizer):
#TODO: Not sure how to handle multiple_clients flag #TODO: Not sure how to handle multiple_clients flag
###print 'retrying match', pattern, expr_equiv ###print 'retrying match', pattern, expr_equiv
return match(pattern, expr_equiv, u, return match(pattern, expr_equiv, u,
allow_multiple_clients=allow_multiple_clients) allow_multiple_clients=allow_multiple_clients)
if isinstance(pattern, (list, tuple)): if isinstance(pattern, (list, tuple)):
if expr.owner is None: if expr.owner is None:
...@@ -1055,8 +1056,8 @@ class PatternSub(LocalOptimizer): ...@@ -1055,8 +1056,8 @@ class PatternSub(LocalOptimizer):
real_pattern = pattern['pattern'] real_pattern = pattern['pattern']
except KeyError: except KeyError:
raise KeyError( raise KeyError(
"Malformed pattern: %s (expected key 'pattern')" "Malformed pattern: %s (expected key 'pattern')"
% pattern) % pattern)
constraint = pattern.get('constraint', lambda expr: True) constraint = pattern.get('constraint', lambda expr: True)
if constraint(expr): if constraint(expr):
return match(real_pattern, expr, u, return match(real_pattern, expr, u,
...@@ -1286,7 +1287,8 @@ class NavigatorOptimizer(Optimizer): ...@@ -1286,7 +1287,8 @@ class NavigatorOptimizer(Optimizer):
except Exception, e: except Exception, e:
if self.failure_callback is not None: if self.failure_callback is not None:
self.failure_callback(e, self, self.failure_callback(e, self,
[(x, None) for x in node.outputs], lopt) [(x, None) for x in node.outputs],
lopt)
return False return False
else: else:
raise raise
...@@ -1294,10 +1296,10 @@ class NavigatorOptimizer(Optimizer): ...@@ -1294,10 +1296,10 @@ class NavigatorOptimizer(Optimizer):
return False return False
if not isinstance(replacements, (tuple, list)): if not isinstance(replacements, (tuple, list)):
raise TypeError('Optimizer %s gave wrong type of replacement. ' raise TypeError('Optimizer %s gave wrong type of replacement. '
'Expected list or tuple.' % lopt) 'Expected list or tuple.' % lopt)
if len(node.outputs) != len(replacements): if len(node.outputs) != len(replacements):
raise ValueError('Optimizer %s gave wrong number of replacements' raise ValueError('Optimizer %s gave wrong number of replacements'
% lopt) % lopt)
# None in the replacement mean that this variable isn't used # None in the replacement mean that this variable isn't used
# and we want to remove it # and we want to remove it
for r, rnew in zip(node.outputs, replacements): for r, rnew in zip(node.outputs, replacements):
...@@ -1336,19 +1338,19 @@ class NavigatorOptimizer(Optimizer): ...@@ -1336,19 +1338,19 @@ class NavigatorOptimizer(Optimizer):
(' ' * level), self.__class__.__name__, id(self)) (' ' * level), self.__class__.__name__, id(self))
if depth != 0: if depth != 0:
self.local_opt.print_summary(stream, level=(level + 2), self.local_opt.print_summary(stream, level=(level + 2),
depth=(depth - 1)) depth=(depth - 1))
class TopoOptimizer(NavigatorOptimizer): class TopoOptimizer(NavigatorOptimizer):
"""WRITEME""" """WRITEME"""
def __init__(self, local_opt, order='in_to_out', ignore_newtrees=False, def __init__(self, local_opt, order='in_to_out', ignore_newtrees=False,
failure_callback=None): failure_callback=None):
if order not in ['out_to_in', 'in_to_out']: if order not in ['out_to_in', 'in_to_out']:
raise ValueError("order must be 'out_to_in' or 'in_to_out'") raise ValueError("order must be 'out_to_in' or 'in_to_out'")
self.order = order self.order = order
NavigatorOptimizer.__init__(self, local_opt, ignore_newtrees, NavigatorOptimizer.__init__(self, local_opt, ignore_newtrees,
failure_callback) failure_callback)
def apply(self, fgraph, start_from=None): def apply(self, fgraph, start_from=None):
if start_from is None: if start_from is None:
...@@ -1414,12 +1416,12 @@ class OpKeyOptimizer(NavigatorOptimizer): ...@@ -1414,12 +1416,12 @@ class OpKeyOptimizer(NavigatorOptimizer):
"""WRITEME""" """WRITEME"""
def __init__(self, local_opt, ignore_newtrees=False, def __init__(self, local_opt, ignore_newtrees=False,
failure_callback=None): failure_callback=None):
if not hasattr(local_opt, 'op_key'): if not hasattr(local_opt, 'op_key'):
raise TypeError("LocalOptimizer for OpKeyOptimizer must have " raise TypeError("LocalOptimizer for OpKeyOptimizer must have "
"an 'op_key' method.") "an 'op_key' method.")
NavigatorOptimizer.__init__(self, local_opt, ignore_newtrees, NavigatorOptimizer.__init__(self, local_opt, ignore_newtrees,
failure_callback) failure_callback)
def apply(self, fgraph): def apply(self, fgraph):
op = self.local_opt.op_key() op = self.local_opt.op_key()
...@@ -1623,7 +1625,8 @@ class EquilibriumOptimizer(NavigatorOptimizer): ...@@ -1623,7 +1625,8 @@ class EquilibriumOptimizer(NavigatorOptimizer):
+ "%f with the theano flag 'optdb.max_use_ratio'." % + "%f with the theano flag 'optdb.max_use_ratio'." %
config.optdb.max_use_ratio) config.optdb.max_use_ratio)
return (self, loop_timing, loop_process_count, (start_nb_nodes, end_nb_nodes, max_nb_nodes), return (self, loop_timing, loop_process_count,
(start_nb_nodes, end_nb_nodes, max_nb_nodes),
global_opt_timing, nb_nodes, time_opts, io_toposort_timing) global_opt_timing, nb_nodes, time_opts, io_toposort_timing)
def print_summary(self, stream=sys.stdout, level=0, depth=-1): def print_summary(self, stream=sys.stdout, level=0, depth=-1):
...@@ -1633,11 +1636,12 @@ class EquilibriumOptimizer(NavigatorOptimizer): ...@@ -1633,11 +1636,12 @@ class EquilibriumOptimizer(NavigatorOptimizer):
if depth != 0: if depth != 0:
for lopt in self.local_optimizers: for lopt in self.local_optimizers:
lopt.print_summary(stream, level=(level + 2), lopt.print_summary(stream, level=(level + 2),
depth=(depth - 1)) depth=(depth - 1))
@staticmethod @staticmethod
def print_profile(stream, prof, level=0): def print_profile(stream, prof, level=0):
(opt, loop_timing, loop_process_count, (start_nb_nodes, end_nb_nodes, max_nb_nodes), (opt, loop_timing, loop_process_count,
(start_nb_nodes, end_nb_nodes, max_nb_nodes),
global_opt_timing, nb_nodes, time_opts, io_toposort_timing) = prof global_opt_timing, nb_nodes, time_opts, io_toposort_timing) = prof
blanc = (' ' * level) blanc = (' ' * level)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论