提交 70d8b3aa authored 作者: Chinnadhurai Sankar's avatar Chinnadhurai Sankar

fix bugs in doc strings/remove final set of flake8 errors

上级 15b8c753
from __future__ import absolute_import, print_function, division
from six.moves import reduce
from six import string_types
if 0:
class _EquilibriumOptimizer(NavigatorOptimizer):
def __init__(self,
local_optimizers,
failure_callback=None,
max_depth=None,
max_use_ratio=None):
super(EquilibriumOptimizer, self).__init__(
None,
ignore_newtrees=False,
failure_callback=failure_callback)
self.local_optimizers = local_optimizers
self.max_depth = max_depth
self.max_use_ratio = max_use_ratio
self.tracks = defaultdict(list)
self.tracks0 = defaultdict(list)
max_depth = 0
for lopt in local_optimizers:
tracks = lopt.tracks()
for track in tracks:
max_depth = max(max_depth, len(track))
if self.max_depth is not None and max_depth > self.max_depth:
raise ValueError('One of the local optimizers exceeds the maximal depth.')
for i, op in enumerate(track):
if i == 0:
self.tracks0[op].append((track, i, lopt))
self.tracks[op].append((track, i, lopt))
def fetch_tracks(self, op):
return self.tracks[op] + self.tracks[None]
def fetch_tracks0(self, op):
return self.tracks0[op] + self.tracks0[None]
def backtrack(self, node, tasks):
candidates = self.fetch_tracks(node.op)
tracks = []
def filter(node, depth):
new_candidates = []
for candidate in candidates:
track, i, lopt = candidate
if i < depth:
pass
elif track[i-depth] in (None, node.op):
if i == depth:
tasks[node].append(lopt)
else:
tracks.append(candidate)
else:
new_candidates.append(candidate)
return new_candidates
depth = 0
nodes = [node]
while candidates:
for node in nodes:
candidates = list(filter(node, depth))
depth += 1
_nodes = nodes
nodes = reduce(list.__iadd__,
[reduce(list.__iadd__,
[[n for n, i in out.clients if not isinstance(n, string_types)] for out in node.outputs],
[]) for node in nodes],
[])
candidates = tracks
tracks = []
def apply(self, fgraph):
tasks = defaultdict(list)
if self.max_use_ratio is not None:
max_uses = self.max_use_ratio * len(fgraph.apply_nodes)
runs = defaultdict(int)
else:
runs = None
def importer(node):
# print 'IMPORTING', node
self.backtrack(node, tasks)
def pruner(node):
try:
del tasks[node]
except KeyError:
pass
def chin(node, i, r, new_r):
if new_r.owner and not r.clients:
self.backtrack(new_r.owner, tasks)
# # == NOT IDEAL == #
# for node in fgraph.apply_nodes:
# importer(node)
for node in fgraph.toposort():
tasks[node].extend(lopt for track, i, lopt in self.fetch_tracks0(node.op))
u = self.attach_updater(fgraph, importer, pruner, chin)
print('KEYS', [hash(t) for t in tasks.keys()])
while tasks:
for node in tasks:
todo = tasks.pop(node)
break
for lopt in todo:
if runs is not None and runs[lopt] >= max_uses:
print('Warning: optimization exceeded its maximal use ratio: %s, %s' % (lopt, max_uses), file=sys.stderr)
continue
success = self.process_node(fgraph, node, lopt)
if success:
if runs is not None: runs[lopt] += 1
break
self.detach_updater(fgraph, u)
# def match(self, node, candidates):
# candidates[:] = [candidate
# for candidate in candidates
# if candidate.current.op is None or candidate.current.op == node.op]
# for candidate in candidates:
# if candidate.current.inputs is not None:
# for in1, in2 in zip(candidate.current.inputs, node.inputs):
# if isinstance(in1, string_types):
# candidate.match[in1] = in2
# for client in node.clients:
# op = node.op
# patterns = self.pattern_base[(depth, op)].union(self.pattern_base[(depth, WILDCARD)])
# if not patterns:
# return patterns
# return self.match(node, depth + 1).intersection(patterns)
# def backtrack(self, node, q):
# for node2, i in node.clients:
# op2 = node2.op
......@@ -15,7 +15,6 @@ from copy import copy
from functools import partial
from theano.gof.utils import ANY_TYPE, comm_guard, FALL_THROUGH, iteritems
################################
......@@ -227,7 +226,7 @@ def unify_walk(a, b, U):
return False
@comm_guard(FreeVariable, ANY_TYPE)
@comm_guard(FreeVariable, ANY_TYPE) # noqa
def unify_walk(fv, o, U):
"""
FreeV is unified to BoundVariable(other_object).
......@@ -237,7 +236,7 @@ def unify_walk(fv, o, U):
return U.merge(v, fv)
@comm_guard(BoundVariable, ANY_TYPE)
@comm_guard(BoundVariable, ANY_TYPE) # noqa
def unify_walk(bv, o, U):
"""
The unification succeed iff BV.value == other_object.
......@@ -249,7 +248,7 @@ def unify_walk(bv, o, U):
return False
@comm_guard(OrVariable, ANY_TYPE)
@comm_guard(OrVariable, ANY_TYPE) # noqa
def unify_walk(ov, o, U):
"""
The unification succeeds iff other_object in OrV.options.
......@@ -262,7 +261,7 @@ def unify_walk(ov, o, U):
return False
@comm_guard(NotVariable, ANY_TYPE)
@comm_guard(NotVariable, ANY_TYPE) # noqa
def unify_walk(nv, o, U):
"""
The unification succeeds iff other_object not in NV.not_options.
......@@ -275,7 +274,7 @@ def unify_walk(nv, o, U):
return U.merge(v, nv)
@comm_guard(FreeVariable, Variable)
@comm_guard(FreeVariable, Variable) # noqa
def unify_walk(fv, v, U):
"""
Both variables are unified.
......@@ -285,7 +284,7 @@ def unify_walk(fv, v, U):
return U.merge(v, fv)
@comm_guard(BoundVariable, Variable)
@comm_guard(BoundVariable, Variable) # noqa
def unify_walk(bv, v, U):
"""
V is unified to BV.value.
......@@ -294,13 +293,13 @@ def unify_walk(bv, v, U):
return unify_walk(v, bv.value, U)
@comm_guard(OrVariable, OrVariable)
@comm_guard(OrVariable, OrVariable) # noqa
def unify_walk(a, b, U):
"""
OrV(list1) == OrV(list2) == OrV(intersection(list1, list2))
"""
opt = intersection(a.options, b.options)
opt = a.options.intersection(b.options)
if not opt:
return False
elif len(opt) == 1:
......@@ -310,18 +309,18 @@ def unify_walk(a, b, U):
return U.merge(v, a, b)
@comm_guard(NotVariable, NotVariable)
@comm_guard(NotVariable, NotVariable) # noqa
def unify_walk(a, b, U):
"""
NV(list1) == NV(list2) == NV(union(list1, list2))
"""
opt = union(a.not_options, b.not_options)
opt = a.not_options.union(b.not_options)
v = NotVariable("?", opt)
return U.merge(v, a, b)
@comm_guard(OrVariable, NotVariable)
@comm_guard(OrVariable, NotVariable) # noqa
def unify_walk(o, n, U):
"""
OrV(list1) == NV(list2) == OrV(list1 \ list2)
......@@ -337,7 +336,7 @@ def unify_walk(o, n, U):
return U.merge(v, o, n)
@comm_guard(VariableInList, (list, tuple))
@comm_guard(VariableInList, (list, tuple)) # noqa
def unify_walk(vil, l, U):
"""
Unifies VIL's inner Variable to OrV(list).
......@@ -348,7 +347,7 @@ def unify_walk(vil, l, U):
return unify_walk(v, ov, U)
@comm_guard((list, tuple), (list, tuple))
@comm_guard((list, tuple), (list, tuple)) # noqa
def unify_walk(l1, l2, U):
"""
Tries to unify each corresponding pair of elements from l1 and l2.
......@@ -363,7 +362,7 @@ def unify_walk(l1, l2, U):
return U
@comm_guard(dict, dict)
@comm_guard(dict, dict) # noqa
def unify_walk(d1, d2, U):
"""
Tries to unify values of corresponding keys.
......@@ -377,7 +376,7 @@ def unify_walk(d1, d2, U):
return U
@comm_guard(ANY_TYPE, ANY_TYPE)
@comm_guard(ANY_TYPE, ANY_TYPE) # noqa
def unify_walk(a, b, U):
"""
Checks for the existence of the __unify_walk__ method for one of
......@@ -392,7 +391,7 @@ def unify_walk(a, b, U):
return FALL_THROUGH
@comm_guard(Variable, ANY_TYPE)
@comm_guard(Variable, ANY_TYPE) # noqa
def unify_walk(v, o, U):
"""
This simply checks if the Var has an unification in U and uses it
......@@ -427,27 +426,27 @@ def unify_merge(a, b, U):
return a
@comm_guard(Variable, ANY_TYPE)
@comm_guard(Variable, ANY_TYPE) # noqa
def unify_merge(v, o, U):
return v
@comm_guard(BoundVariable, ANY_TYPE)
@comm_guard(BoundVariable, ANY_TYPE) # noqa
def unify_merge(bv, o, U):
return bv.value
@comm_guard(VariableInList, (list, tuple))
@comm_guard(VariableInList, (list, tuple)) # noqa
def unify_merge(vil, l, U):
return [unify_merge(x, x, U) for x in l]
@comm_guard((list, tuple), (list, tuple))
@comm_guard((list, tuple), (list, tuple)) # noqa
def unify_merge(l1, l2, U):
return [unify_merge(x1, x2, U) for x1, x2 in zip(l1, l2)]
@comm_guard(dict, dict)
@comm_guard(dict, dict) # noqa
def unify_merge(d1, d2, U):
d = d1.__class__()
for k1, v1 in iteritems(d1):
......@@ -461,12 +460,12 @@ def unify_merge(d1, d2, U):
return d
@comm_guard(FVar, ANY_TYPE)
@comm_guard(FVar, ANY_TYPE) # noqa
def unify_merge(vs, o, U):
return vs(U)
@comm_guard(ANY_TYPE, ANY_TYPE)
@comm_guard(ANY_TYPE, ANY_TYPE) # noqa
def unify_merge(a, b, U):
if (not isinstance(a, Variable) and
not isinstance(b, Variable) and
......@@ -476,7 +475,7 @@ def unify_merge(a, b, U):
return FALL_THROUGH
@comm_guard(Variable, ANY_TYPE)
@comm_guard(Variable, ANY_TYPE) # noqa
def unify_merge(v, o, U):
"""
This simply checks if the Var has an unification in U and uses it
......
......@@ -125,12 +125,11 @@ whitelist_flake8 = [
"sparse/sandbox/test_sp.py",
"sparse/sandbox/sp2.py",
"sparse/sandbox/truedot.py",
"sparse/sandbox/sp.py",
"gof/__init__.py",
"sparse/sandbox/sp.py",
"gof/__init__.py",
"d3viz/__init__.py",
"d3viz/tests/__init__.py",
"gof/tests/__init__.py",
]
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论