提交 b129fb77 authored 作者: Iban Harlouchet's avatar Iban Harlouchet

flake8 theano/gof/opt.py; 2 E left

上级 de08b763
......@@ -29,6 +29,7 @@ from . import destroyhandler as dh
_logger = logging.getLogger('theano.gof.opt')
_optimizer_idx = [0]
def _list_of_nodes(fgraph):
return list(graph.io_toposort(fgraph.inputs, fgraph.outputs))
......@@ -729,6 +730,7 @@ def pre_constant_merge(vars):
const_sig_inv = {}
if isinstance(vars, graph.Variable):
vars = [vars]
def recursive_merge(var):
if var in seen_var:
return var
......@@ -761,7 +763,7 @@ def pre_constant_merge(vars):
########################
### Local Optimizers ###
# Local Optimizers #
########################
class LocalOptimizer(object):
......@@ -820,9 +822,11 @@ class LocalOptimizer(object):
(' ' * level), self.__class__.__name__, id(self)), file=stream)
theano.configparser.AddConfigVar('metaopt.verbose',
theano.configparser.AddConfigVar(
'metaopt.verbose',
"Enable verbose output for meta optimizers",
theano.configparser.BoolParam(False), in_c_key=False)
theano.configparser.BoolParam(False),
in_c_key=False)
class LocalMetaOptimizer(LocalOptimizer):
......@@ -1217,6 +1221,7 @@ class PatternSub(LocalOptimizer):
if node.op != self.op:
return False
# TODO: if we remove pdb, do this speed things up?
def match(pattern, expr, u, allow_multiple_clients=False, pdb=False):
# TODO move outside match
def retry_with_equiv():
......@@ -1233,9 +1238,8 @@ class PatternSub(LocalOptimizer):
if isinstance(pattern, (list, tuple)):
if expr.owner is None:
return False
if (not (expr.owner.op == pattern[0])
or (not allow_multiple_clients
and len(expr.clients) > 1)):
if (not (expr.owner.op == pattern[0]) or
(not allow_multiple_clients and len(expr.clients) > 1)):
return retry_with_equiv()
if len(pattern) - 1 != len(expr.owner.inputs):
return retry_with_equiv()
......@@ -1263,16 +1267,16 @@ class PatternSub(LocalOptimizer):
return retry_with_equiv()
else:
u = u.merge(expr, v)
elif (isinstance(pattern, (int, float))
and isinstance(expr, graph.Constant)):
elif (isinstance(pattern, (int, float)) and
isinstance(expr, graph.Constant)):
if numpy.all(
theano.tensor.constant(pattern).value == expr.value):
return u
else:
return retry_with_equiv()
elif (isinstance(pattern, graph.Constant)
and isinstance(expr, graph.Constant)
and pattern.equals(expr)):
elif (isinstance(pattern, graph.Constant) and
isinstance(expr, graph.Constant) and
pattern.equals(expr)):
return u
else:
return retry_with_equiv()
......@@ -1335,7 +1339,7 @@ class PatternSub(LocalOptimizer):
##################
### Navigators ###
# Navigators #
##################
# Use the following classes to apply LocalOptimizers
......@@ -1811,8 +1815,8 @@ class EquilibriumOptimizer(NavigatorOptimizer):
node_created[gopt] += change_tracker.nb_imported - nb
if global_process_count[gopt] > max_use:
max_use_abort = True
opt_name = (getattr(gopt, "name", None)
or getattr(gopt, "__name__", ""))
opt_name = (getattr(gopt, "name", None) or
getattr(gopt, "__name__", ""))
global_sub_profs.append(sub_profs)
global_opt_timing.append(float(time.time() - t0))
......@@ -1858,8 +1862,8 @@ class EquilibriumOptimizer(NavigatorOptimizer):
node_created[lopt] += change_tracker.nb_imported - nb
if global_process_count[lopt] > max_use:
max_use_abort = True
opt_name = (getattr(lopt, "name", None)
or getattr(lopt, "__name__", ""))
opt_name = (getattr(lopt, "name", None) or
getattr(lopt, "__name__", ""))
if node not in fgraph.apply_nodes:
# go to next node
break
......@@ -1884,8 +1888,8 @@ class EquilibriumOptimizer(NavigatorOptimizer):
node_created[gopt] += change_tracker.nb_imported - nb
if global_process_count[gopt] > max_use:
max_use_abort = True
opt_name = (getattr(gopt, "name", None)
or getattr(gopt, "__name__", ""))
opt_name = (getattr(gopt, "name", None) or
getattr(gopt, "__name__", ""))
final_sub_profs.append(sub_profs)
global_opt_timing[-1] += time.time() - t_before_final_opt
......@@ -1896,9 +1900,9 @@ class EquilibriumOptimizer(NavigatorOptimizer):
end_nb_nodes = len(fgraph.apply_nodes)
if max_use_abort:
_logger.error("EquilibriumOptimizer max'ed out by '%s'" % opt_name
+ ". You can safely raise the current threshold of "
+ "%f with the theano flag 'optdb.max_use_ratio'." %
_logger.error("EquilibriumOptimizer max'ed out by '%s'" % opt_name +
". You can safely raise the current threshold of " +
"%f with the theano flag 'optdb.max_use_ratio'." %
config.optdb.max_use_ratio)
fgraph.remove_feature(change_tracker)
return (self, loop_timing, loop_process_count,
......@@ -1975,8 +1979,9 @@ class EquilibriumOptimizer(NavigatorOptimizer):
not_used_time += time_opts[o]
if count_opt:
print(blanc, \
' times - times applied - nb node created - name:', file=stream)
print(blanc,
' times - times applied - nb node created - name:',
file=stream)
count_opt.sort()
for (t, count, n_created, o) in count_opt[::-1]:
print(blanc, ' %.3fs - %d - %d - %s' % (
......@@ -2010,7 +2015,7 @@ class EquilibriumOptimizer(NavigatorOptimizer):
@staticmethod
def merge_profile(prof1, prof2):
#(opt, loop_timing, loop_process_count, max_nb_nodes,
# (opt, loop_timing, loop_process_count, max_nb_nodes,
# global_opt_timing, nb_nodes, time_opts, io_toposort_timing) = prof1
local_optimizers = OrderedSet(prof1[0].get_local_optimizers()).union(
prof2[0].get_local_optimizers())
......@@ -2085,7 +2090,7 @@ class EquilibriumOptimizer(NavigatorOptimizer):
final_sub_profs)
#################
### Utilities ###
# Utilities #
#################
......@@ -2096,7 +2101,7 @@ def _check_chain(r, chain):
while chain:
elem = chain.pop()
if elem is None:
if not r.owner is None:
if r.owner is not None:
return False
elif r.owner is None:
return False
......@@ -2105,20 +2110,20 @@ def _check_chain(r, chain):
return False
else:
try:
if (issubclass(elem, op.Op)
and not isinstance(r.owner.op, elem)):
if (issubclass(elem, op.Op) and
not isinstance(r.owner.op, elem)):
return False
except TypeError:
return False
if chain:
r = r.owner.inputs[chain.pop()]
# print 'check_chain', _check_chain.n_calls
#_check_chain.n_calls += 1
# _check_chain.n_calls += 1
# The return value will be used as a Boolean, but some Variables cannot
# be used as Booleans (the results of comparisons, for instance)
return (r is not None)
#_check_chain.n_calls = 0
# _check_chain.n_calls = 0
def check_chain(r, *chain):
......
......@@ -244,7 +244,6 @@ whitelist_flake8 = [
"gof/unify.py",
"gof/graph.py",
"gof/__init__.py",
"gof/opt.py",
"gof/link.py",
"gof/fg.py",
"gof/op.py",
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论