提交 b129fb77 authored 作者: Iban Harlouchet's avatar Iban Harlouchet

flake8 theano/gof/opt.py; 2 E left

上级 de08b763
...@@ -29,6 +29,7 @@ from . import destroyhandler as dh ...@@ -29,6 +29,7 @@ from . import destroyhandler as dh
_logger = logging.getLogger('theano.gof.opt') _logger = logging.getLogger('theano.gof.opt')
_optimizer_idx = [0] _optimizer_idx = [0]
def _list_of_nodes(fgraph): def _list_of_nodes(fgraph):
return list(graph.io_toposort(fgraph.inputs, fgraph.outputs)) return list(graph.io_toposort(fgraph.inputs, fgraph.outputs))
...@@ -99,7 +100,7 @@ class Optimizer(object): ...@@ -99,7 +100,7 @@ class Optimizer(object):
def print_summary(self, stream=sys.stdout, level=0, depth=-1): def print_summary(self, stream=sys.stdout, level=0, depth=-1):
name = getattr(self, 'name', None) name = getattr(self, 'name', None)
print("%s%s %s id=%i" % ( print("%s%s %s id=%i" % (
(' ' * level), self.__class__.__name__, name, id(self)), file=stream) (' ' * level), self.__class__.__name__, name, id(self)), file=stream)
@staticmethod @staticmethod
def print_profile(stream, prof, level=0): def print_profile(stream, prof, level=0):
...@@ -121,9 +122,9 @@ class FromFunctionOptimizer(Optimizer): ...@@ -121,9 +122,9 @@ class FromFunctionOptimizer(Optimizer):
def print_summary(self, stream=sys.stdout, level=0, depth=-1): def print_summary(self, stream=sys.stdout, level=0, depth=-1):
print("%s%s id=%i" % ( print("%s%s id=%i" % (
' ' * level, ' ' * level,
str(self.apply), str(self.apply),
id(self)), file=stream) id(self)), file=stream)
def __call__(self, *args, **kwargs): def __call__(self, *args, **kwargs):
return self.fn(*args, **kwargs) return self.fn(*args, **kwargs)
...@@ -222,7 +223,7 @@ class SeqOptimizer(Optimizer, list): ...@@ -222,7 +223,7 @@ class SeqOptimizer(Optimizer, list):
def print_summary(self, stream=sys.stdout, level=0, depth=-1): def print_summary(self, stream=sys.stdout, level=0, depth=-1):
name = getattr(self, 'name', None) name = getattr(self, 'name', None)
print("%s%s %s id=%i" % ( print("%s%s %s id=%i" % (
(' ' * level), self.__class__.__name__, name, id(self)), file=stream) (' ' * level), self.__class__.__name__, name, id(self)), file=stream)
# This way, -1 will do all depth # This way, -1 will do all depth
if depth != 0: if depth != 0:
depth -= 1 depth -= 1
...@@ -241,8 +242,8 @@ class SeqOptimizer(Optimizer, list): ...@@ -241,8 +242,8 @@ class SeqOptimizer(Optimizer, list):
elif hasattr(opts, "__name__"): elif hasattr(opts, "__name__"):
print(blanc, opts.__name__, end=' ', file=stream) print(blanc, opts.__name__, end=' ', file=stream)
print((" time %.3fs for %d/%d nodes" print((" time %.3fs for %d/%d nodes"
" before/after optimization" % ( " before/after optimization" % (
sum(prof), nb_node_before, nb_node_after)), file=stream) sum(prof), nb_node_before, nb_node_after)), file=stream)
print(blanc, " %.3fs for fgraph.validate()" % (validate_time), file=stream) print(blanc, " %.3fs for fgraph.validate()" % (validate_time), file=stream)
print(blanc, " %.3fs for callback" % (callback_time), file=stream) print(blanc, " %.3fs for callback" % (callback_time), file=stream)
if level == 0: if level == 0:
...@@ -324,7 +325,7 @@ class SeqOptimizer(Optimizer, list): ...@@ -324,7 +325,7 @@ class SeqOptimizer(Optimizer, list):
new_t[idx] += p[1][p[0].index(l)] new_t[idx] += p[1][p[0].index(l)]
if hasattr(l, 'merge_profile'): if hasattr(l, 'merge_profile'):
assert len(p[6][p[0].index(l)]) == \ assert len(p[6][p[0].index(l)]) == \
len(new_sub_profile[idx]) len(new_sub_profile[idx])
new_sub_profile[idx] = l.merge_profile( new_sub_profile[idx] = l.merge_profile(
new_sub_profile[idx], p[6][p[0].index(l)]) new_sub_profile[idx], p[6][p[0].index(l)])
else: else:
...@@ -729,6 +730,7 @@ def pre_constant_merge(vars): ...@@ -729,6 +730,7 @@ def pre_constant_merge(vars):
const_sig_inv = {} const_sig_inv = {}
if isinstance(vars, graph.Variable): if isinstance(vars, graph.Variable):
vars = [vars] vars = [vars]
def recursive_merge(var): def recursive_merge(var):
if var in seen_var: if var in seen_var:
return var return var
...@@ -761,7 +763,7 @@ def pre_constant_merge(vars): ...@@ -761,7 +763,7 @@ def pre_constant_merge(vars):
######################## ########################
### Local Optimizers ### # Local Optimizers #
######################## ########################
class LocalOptimizer(object): class LocalOptimizer(object):
...@@ -817,12 +819,14 @@ class LocalOptimizer(object): ...@@ -817,12 +819,14 @@ class LocalOptimizer(object):
def print_summary(self, stream=sys.stdout, level=0, depth=-1): def print_summary(self, stream=sys.stdout, level=0, depth=-1):
print("%s%s id=%i" % ( print("%s%s id=%i" % (
(' ' * level), self.__class__.__name__, id(self)), file=stream) (' ' * level), self.__class__.__name__, id(self)), file=stream)
theano.configparser.AddConfigVar('metaopt.verbose', theano.configparser.AddConfigVar(
"Enable verbose output for meta optimizers", 'metaopt.verbose',
theano.configparser.BoolParam(False), in_c_key=False) "Enable verbose output for meta optimizers",
theano.configparser.BoolParam(False),
in_c_key=False)
class LocalMetaOptimizer(LocalOptimizer): class LocalMetaOptimizer(LocalOptimizer):
...@@ -933,9 +937,9 @@ class FromFunctionLocalOptimizer(LocalOptimizer): ...@@ -933,9 +937,9 @@ class FromFunctionLocalOptimizer(LocalOptimizer):
def print_summary(self, stream=sys.stdout, level=0, depth=-1): def print_summary(self, stream=sys.stdout, level=0, depth=-1):
print("%s%s id=%i" % ( print("%s%s id=%i" % (
' ' * level, ' ' * level,
str(self.transform), str(self.transform),
id(self)), file=stream) id(self)), file=stream)
def local_optimizer(tracks, inplace=False): def local_optimizer(tracks, inplace=False):
...@@ -992,7 +996,7 @@ class LocalOptGroup(LocalOptimizer): ...@@ -992,7 +996,7 @@ class LocalOptGroup(LocalOptimizer):
def print_summary(self, stream=sys.stdout, level=0, depth=-1): def print_summary(self, stream=sys.stdout, level=0, depth=-1):
print("%s%s id=%i" % ( print("%s%s id=%i" % (
(' ' * level), self.__class__.__name__, id(self)), file=stream) (' ' * level), self.__class__.__name__, id(self)), file=stream)
if depth != 0: if depth != 0:
depth -= 1 depth -= 1
for lopt in self.opts: for lopt in self.opts:
...@@ -1086,10 +1090,10 @@ class OpRemove(LocalOptimizer): ...@@ -1086,10 +1090,10 @@ class OpRemove(LocalOptimizer):
def print_summary(self, stream=sys.stdout, level=0, depth=-1): def print_summary(self, stream=sys.stdout, level=0, depth=-1):
print("%s%s(%s) id=%i" % ( print("%s%s(%s) id=%i" % (
' ' * level, ' ' * level,
self.__class__.__name__, self.__class__.__name__,
str(self.op), str(self.op),
id(self)), file=stream) id(self)), file=stream)
class PatternSub(LocalOptimizer): class PatternSub(LocalOptimizer):
...@@ -1217,6 +1221,7 @@ class PatternSub(LocalOptimizer): ...@@ -1217,6 +1221,7 @@ class PatternSub(LocalOptimizer):
if node.op != self.op: if node.op != self.op:
return False return False
# TODO: if we remove pdb, do this speed things up? # TODO: if we remove pdb, do this speed things up?
def match(pattern, expr, u, allow_multiple_clients=False, pdb=False): def match(pattern, expr, u, allow_multiple_clients=False, pdb=False):
# TODO move outside match # TODO move outside match
def retry_with_equiv(): def retry_with_equiv():
...@@ -1233,9 +1238,8 @@ class PatternSub(LocalOptimizer): ...@@ -1233,9 +1238,8 @@ class PatternSub(LocalOptimizer):
if isinstance(pattern, (list, tuple)): if isinstance(pattern, (list, tuple)):
if expr.owner is None: if expr.owner is None:
return False return False
if (not (expr.owner.op == pattern[0]) if (not (expr.owner.op == pattern[0]) or
or (not allow_multiple_clients (not allow_multiple_clients and len(expr.clients) > 1)):
and len(expr.clients) > 1)):
return retry_with_equiv() return retry_with_equiv()
if len(pattern) - 1 != len(expr.owner.inputs): if len(pattern) - 1 != len(expr.owner.inputs):
return retry_with_equiv() return retry_with_equiv()
...@@ -1263,16 +1267,16 @@ class PatternSub(LocalOptimizer): ...@@ -1263,16 +1267,16 @@ class PatternSub(LocalOptimizer):
return retry_with_equiv() return retry_with_equiv()
else: else:
u = u.merge(expr, v) u = u.merge(expr, v)
elif (isinstance(pattern, (int, float)) elif (isinstance(pattern, (int, float)) and
and isinstance(expr, graph.Constant)): isinstance(expr, graph.Constant)):
if numpy.all( if numpy.all(
theano.tensor.constant(pattern).value == expr.value): theano.tensor.constant(pattern).value == expr.value):
return u return u
else: else:
return retry_with_equiv() return retry_with_equiv()
elif (isinstance(pattern, graph.Constant) elif (isinstance(pattern, graph.Constant) and
and isinstance(expr, graph.Constant) isinstance(expr, graph.Constant) and
and pattern.equals(expr)): pattern.equals(expr)):
return u return u
else: else:
return retry_with_equiv() return retry_with_equiv()
...@@ -1308,17 +1312,17 @@ class PatternSub(LocalOptimizer): ...@@ -1308,17 +1312,17 @@ class PatternSub(LocalOptimizer):
def pattern_to_str(pattern): def pattern_to_str(pattern):
if isinstance(pattern, (list, tuple)): if isinstance(pattern, (list, tuple)):
return "%s(%s)" % ( return "%s(%s)" % (
str(pattern[0]), str(pattern[0]),
", ".join([pattern_to_str(p) for p in pattern[1:]])) ", ".join([pattern_to_str(p) for p in pattern[1:]]))
elif isinstance(pattern, dict): elif isinstance(pattern, dict):
return "%s subject to %s" % ( return "%s subject to %s" % (
pattern_to_str(pattern['pattern']), pattern_to_str(pattern['pattern']),
str(pattern.get('constraint', 'no conditions'))) str(pattern.get('constraint', 'no conditions')))
else: else:
return str(pattern) return str(pattern)
return "%s -> %s" % ( return "%s -> %s" % (
pattern_to_str(self.in_pattern), pattern_to_str(self.in_pattern),
pattern_to_str(self.out_pattern)) pattern_to_str(self.out_pattern))
def __repr__(self): def __repr__(self):
return str(self) return str(self)
...@@ -1326,16 +1330,16 @@ class PatternSub(LocalOptimizer): ...@@ -1326,16 +1330,16 @@ class PatternSub(LocalOptimizer):
def print_summary(self, stream=sys.stdout, level=0, depth=-1): def print_summary(self, stream=sys.stdout, level=0, depth=-1):
name = getattr(self, '__name__', getattr(self, 'name', None)) name = getattr(self, '__name__', getattr(self, 'name', None))
print("%s%s %s(%s, %s) id=%i" % ( print("%s%s %s(%s, %s) id=%i" % (
' ' * level, ' ' * level,
self.__class__.__name__, self.__class__.__name__,
name, name,
str(self.in_pattern), str(self.in_pattern),
str(self.out_pattern), str(self.out_pattern),
id(self)), file=stream) id(self)), file=stream)
################## ##################
### Navigators ### # Navigators #
################## ##################
# Use the following classes to apply LocalOptimizers # Use the following classes to apply LocalOptimizers
...@@ -1545,7 +1549,7 @@ class NavigatorOptimizer(Optimizer): ...@@ -1545,7 +1549,7 @@ class NavigatorOptimizer(Optimizer):
def print_summary(self, stream=sys.stdout, level=0, depth=-1): def print_summary(self, stream=sys.stdout, level=0, depth=-1):
print("%s%s (%i)" % ( print("%s%s (%i)" % (
(' ' * level), self.__class__.__name__, id(self)), file=stream) (' ' * level), self.__class__.__name__, id(self)), file=stream)
if depth != 0: if depth != 0:
self.local_opt.print_summary(stream, level=(level + 2), self.local_opt.print_summary(stream, level=(level + 2),
depth=(depth - 1)) depth=(depth - 1))
...@@ -1734,7 +1738,7 @@ class EquilibriumOptimizer(NavigatorOptimizer): ...@@ -1734,7 +1738,7 @@ class EquilibriumOptimizer(NavigatorOptimizer):
self.final_optimizers = final_optimizers self.final_optimizers = final_optimizers
self.max_use_ratio = max_use_ratio self.max_use_ratio = max_use_ratio
assert self.max_use_ratio is not None, ( assert self.max_use_ratio is not None, (
'max_use_ratio has to be a number') 'max_use_ratio has to be a number')
def get_local_optimizers(self): def get_local_optimizers(self):
for opt in self.local_optimizers_all: for opt in self.local_optimizers_all:
...@@ -1811,8 +1815,8 @@ class EquilibriumOptimizer(NavigatorOptimizer): ...@@ -1811,8 +1815,8 @@ class EquilibriumOptimizer(NavigatorOptimizer):
node_created[gopt] += change_tracker.nb_imported - nb node_created[gopt] += change_tracker.nb_imported - nb
if global_process_count[gopt] > max_use: if global_process_count[gopt] > max_use:
max_use_abort = True max_use_abort = True
opt_name = (getattr(gopt, "name", None) opt_name = (getattr(gopt, "name", None) or
or getattr(gopt, "__name__", "")) getattr(gopt, "__name__", ""))
global_sub_profs.append(sub_profs) global_sub_profs.append(sub_profs)
global_opt_timing.append(float(time.time() - t0)) global_opt_timing.append(float(time.time() - t0))
...@@ -1858,8 +1862,8 @@ class EquilibriumOptimizer(NavigatorOptimizer): ...@@ -1858,8 +1862,8 @@ class EquilibriumOptimizer(NavigatorOptimizer):
node_created[lopt] += change_tracker.nb_imported - nb node_created[lopt] += change_tracker.nb_imported - nb
if global_process_count[lopt] > max_use: if global_process_count[lopt] > max_use:
max_use_abort = True max_use_abort = True
opt_name = (getattr(lopt, "name", None) opt_name = (getattr(lopt, "name", None) or
or getattr(lopt, "__name__", "")) getattr(lopt, "__name__", ""))
if node not in fgraph.apply_nodes: if node not in fgraph.apply_nodes:
# go to next node # go to next node
break break
...@@ -1884,8 +1888,8 @@ class EquilibriumOptimizer(NavigatorOptimizer): ...@@ -1884,8 +1888,8 @@ class EquilibriumOptimizer(NavigatorOptimizer):
node_created[gopt] += change_tracker.nb_imported - nb node_created[gopt] += change_tracker.nb_imported - nb
if global_process_count[gopt] > max_use: if global_process_count[gopt] > max_use:
max_use_abort = True max_use_abort = True
opt_name = (getattr(gopt, "name", None) opt_name = (getattr(gopt, "name", None) or
or getattr(gopt, "__name__", "")) getattr(gopt, "__name__", ""))
final_sub_profs.append(sub_profs) final_sub_profs.append(sub_profs)
global_opt_timing[-1] += time.time() - t_before_final_opt global_opt_timing[-1] += time.time() - t_before_final_opt
...@@ -1896,9 +1900,9 @@ class EquilibriumOptimizer(NavigatorOptimizer): ...@@ -1896,9 +1900,9 @@ class EquilibriumOptimizer(NavigatorOptimizer):
end_nb_nodes = len(fgraph.apply_nodes) end_nb_nodes = len(fgraph.apply_nodes)
if max_use_abort: if max_use_abort:
_logger.error("EquilibriumOptimizer max'ed out by '%s'" % opt_name _logger.error("EquilibriumOptimizer max'ed out by '%s'" % opt_name +
+ ". You can safely raise the current threshold of " ". You can safely raise the current threshold of " +
+ "%f with the theano flag 'optdb.max_use_ratio'." % "%f with the theano flag 'optdb.max_use_ratio'." %
config.optdb.max_use_ratio) config.optdb.max_use_ratio)
fgraph.remove_feature(change_tracker) fgraph.remove_feature(change_tracker)
return (self, loop_timing, loop_process_count, return (self, loop_timing, loop_process_count,
...@@ -1909,7 +1913,7 @@ class EquilibriumOptimizer(NavigatorOptimizer): ...@@ -1909,7 +1913,7 @@ class EquilibriumOptimizer(NavigatorOptimizer):
def print_summary(self, stream=sys.stdout, level=0, depth=-1): def print_summary(self, stream=sys.stdout, level=0, depth=-1):
name = getattr(self, 'name', None) name = getattr(self, 'name', None)
print("%s%s %s id=%i" % ( print("%s%s %s id=%i" % (
(' ' * level), self.__class__.__name__, name, id(self)), file=stream) (' ' * level), self.__class__.__name__, name, id(self)), file=stream)
if depth != 0: if depth != 0:
for lopt in self.get_local_optimizers(): for lopt in self.get_local_optimizers():
lopt.print_summary(stream, level=(level + 2), lopt.print_summary(stream, level=(level + 2),
...@@ -1925,11 +1929,11 @@ class EquilibriumOptimizer(NavigatorOptimizer): ...@@ -1925,11 +1929,11 @@ class EquilibriumOptimizer(NavigatorOptimizer):
blanc = (' ' * level) blanc = (' ' * level)
print(blanc, "EquilibriumOptimizer", end=' ', file=stream) print(blanc, "EquilibriumOptimizer", end=' ', file=stream)
print(blanc, getattr(opt, "name", print(blanc, getattr(opt, "name",
getattr(opt, "__name__", "")), file=stream) getattr(opt, "__name__", "")), file=stream)
print(blanc, " time %.3fs for %d passes" % ( print(blanc, " time %.3fs for %d passes" % (
sum(loop_timing), len(loop_timing)), file=stream) sum(loop_timing), len(loop_timing)), file=stream)
print(blanc, " nb nodes (start, end, max) %d %d %d" % ( print(blanc, " nb nodes (start, end, max) %d %d %d" % (
start_nb_nodes, end_nb_nodes, max_nb_nodes), file=stream) start_nb_nodes, end_nb_nodes, max_nb_nodes), file=stream)
print(blanc, " time io_toposort %.3fs" % sum( print(blanc, " time io_toposort %.3fs" % sum(
io_toposort_timing), file=stream) io_toposort_timing), file=stream)
s = sum([time_opts[o] for o in opt.get_local_optimizers()]) s = sum([time_opts[o] for o in opt.get_local_optimizers()])
...@@ -1948,12 +1952,12 @@ class EquilibriumOptimizer(NavigatorOptimizer): ...@@ -1948,12 +1952,12 @@ class EquilibriumOptimizer(NavigatorOptimizer):
if len(d) > 5: if len(d) > 5:
lopt += " ..." lopt += " ..."
print(blanc, (' %2d - %.3fs %d (%.3fs in global opts, ' print(blanc, (' %2d - %.3fs %d (%.3fs in global opts, '
'%.3fs io_toposort) - %d nodes - %s' % ( '%.3fs io_toposort) - %d nodes - %s' % (
i, loop_timing[i], i, loop_timing[i],
sum(loop_process_count[i].values()), sum(loop_process_count[i].values()),
global_opt_timing[i], global_opt_timing[i],
io_toposort_timing[i], nb_nodes[i], io_toposort_timing[i], nb_nodes[i],
lopt)), file=stream) lopt)), file=stream)
count_opt = [] count_opt = []
not_used = [] not_used = []
...@@ -1975,8 +1979,9 @@ class EquilibriumOptimizer(NavigatorOptimizer): ...@@ -1975,8 +1979,9 @@ class EquilibriumOptimizer(NavigatorOptimizer):
not_used_time += time_opts[o] not_used_time += time_opts[o]
if count_opt: if count_opt:
print(blanc, \ print(blanc,
' times - times applied - nb node created - name:', file=stream) ' times - times applied - nb node created - name:',
file=stream)
count_opt.sort() count_opt.sort()
for (t, count, n_created, o) in count_opt[::-1]: for (t, count, n_created, o) in count_opt[::-1]:
print(blanc, ' %.3fs - %d - %d - %s' % ( print(blanc, ' %.3fs - %d - %d - %s' % (
...@@ -2010,7 +2015,7 @@ class EquilibriumOptimizer(NavigatorOptimizer): ...@@ -2010,7 +2015,7 @@ class EquilibriumOptimizer(NavigatorOptimizer):
@staticmethod @staticmethod
def merge_profile(prof1, prof2): def merge_profile(prof1, prof2):
#(opt, loop_timing, loop_process_count, max_nb_nodes, # (opt, loop_timing, loop_process_count, max_nb_nodes,
# global_opt_timing, nb_nodes, time_opts, io_toposort_timing) = prof1 # global_opt_timing, nb_nodes, time_opts, io_toposort_timing) = prof1
local_optimizers = OrderedSet(prof1[0].get_local_optimizers()).union( local_optimizers = OrderedSet(prof1[0].get_local_optimizers()).union(
prof2[0].get_local_optimizers()) prof2[0].get_local_optimizers())
...@@ -2085,7 +2090,7 @@ class EquilibriumOptimizer(NavigatorOptimizer): ...@@ -2085,7 +2090,7 @@ class EquilibriumOptimizer(NavigatorOptimizer):
final_sub_profs) final_sub_profs)
################# #################
### Utilities ### # Utilities #
################# #################
...@@ -2096,7 +2101,7 @@ def _check_chain(r, chain): ...@@ -2096,7 +2101,7 @@ def _check_chain(r, chain):
while chain: while chain:
elem = chain.pop() elem = chain.pop()
if elem is None: if elem is None:
if not r.owner is None: if r.owner is not None:
return False return False
elif r.owner is None: elif r.owner is None:
return False return False
...@@ -2105,20 +2110,20 @@ def _check_chain(r, chain): ...@@ -2105,20 +2110,20 @@ def _check_chain(r, chain):
return False return False
else: else:
try: try:
if (issubclass(elem, op.Op) if (issubclass(elem, op.Op) and
and not isinstance(r.owner.op, elem)): not isinstance(r.owner.op, elem)):
return False return False
except TypeError: except TypeError:
return False return False
if chain: if chain:
r = r.owner.inputs[chain.pop()] r = r.owner.inputs[chain.pop()]
# print 'check_chain', _check_chain.n_calls # print 'check_chain', _check_chain.n_calls
#_check_chain.n_calls += 1 # _check_chain.n_calls += 1
# The return value will be used as a Boolean, but some Variables cannot # The return value will be used as a Boolean, but some Variables cannot
# be used as Booleans (the results of comparisons, for instance) # be used as Booleans (the results of comparisons, for instance)
return (r is not None) return (r is not None)
#_check_chain.n_calls = 0 # _check_chain.n_calls = 0
def check_chain(r, *chain): def check_chain(r, *chain):
......
...@@ -244,7 +244,6 @@ whitelist_flake8 = [ ...@@ -244,7 +244,6 @@ whitelist_flake8 = [
"gof/unify.py", "gof/unify.py",
"gof/graph.py", "gof/graph.py",
"gof/__init__.py", "gof/__init__.py",
"gof/opt.py",
"gof/link.py", "gof/link.py",
"gof/fg.py", "gof/fg.py",
"gof/op.py", "gof/op.py",
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论