提交 b6f3c469 authored 作者: abergeron's avatar abergeron

Merge pull request #1911 from nouiz/cycle

Fix crash du to opt that introduce cycle in the graph
......@@ -11,6 +11,7 @@ from theano.misc.ordered_set import OrderedSet
from fg import InconsistencyError
class ProtocolError(Exception):
"""Raised when FunctionGraph calls DestroyHandler callbacks in
an invalid way, for example, pruning or changing a node that has
......@@ -18,6 +19,7 @@ class ProtocolError(Exception):
"""
pass
def _contains_cycle(fgraph, orderings):
"""
......@@ -44,7 +46,6 @@ def _contains_cycle(fgraph, orderings):
inputs = fgraph.inputs
outputs = fgraph.outputs
# this is hard-coded reimplementation of functions from graph.py
# reason: go faster, prepare for port to C.
# specifically, it could be replaced with a wrapper
......@@ -55,7 +56,6 @@ def _contains_cycle(fgraph, orderings):
# this is performance-critical code. it is the largest single-function
# bottleneck when compiling large graphs.
assert isinstance(outputs, (tuple, list, deque))
# TODO: For more speed - use a defaultdict for the orderings
......@@ -111,7 +111,7 @@ def _contains_cycle(fgraph, orderings):
# this is faster than calling get_parents
owner = var.owner
if owner:
parents = [ owner ]
parents = [owner]
else:
parents = []
......@@ -172,16 +172,16 @@ def _contains_cycle(fgraph, orderings):
# and increment the visited node count without double-counting
node = visitable.popleft()
visited += 1
for client in node_to_children.get(node,[]):
for client in node_to_children.get(node, []):
parent_counts[client] -= 1
# If all of a node's parents have been visited,
# it may now be visited too
if not parent_counts[client]:
visitable.append(client)
return visited != len(parent_counts)
def getroot(r, view_i):
"""
TODO: what is view_i ? based on add_impact's docstring, IG is guessing
......@@ -197,6 +197,7 @@ def getroot(r, view_i):
except KeyError:
return r
def add_impact(r, view_o, impact):
"""
In opposition to getroot, which finds the variable that is viewed *by* r, this function
......@@ -211,15 +212,17 @@ def add_impact(r, view_o, impact):
IG thinks so, based on reading the code. It looks like get_impact
does what this docstring said this function does.
"""
for v in view_o.get(r,[]):
for v in view_o.get(r, []):
impact.add(v)
add_impact(v, view_o, impact)
def get_impact(root, view_o):
impact = OrderedSet()
add_impact(root, view_o, impact)
return impact
def fast_inplace_check(inputs):
""" Return the variables in inputs that are posible candidate for as inputs of inplace operation
......@@ -227,12 +230,14 @@ def fast_inplace_check(inputs):
:param inputs: inputs Variable that you want to use as inplace destination
"""
fgraph = inputs[0].fgraph
protected_inputs = [f.protected for f in fgraph._features if isinstance(f,theano.compile.function_module.Supervisor)]
protected_inputs = sum(protected_inputs,[])#flatten the list
Supervisor = theano.compile.function_module.Supervisor
protected_inputs = [f.protected for f in fgraph._features
if isinstance(f, Supervisor)]
protected_inputs = sum(protected_inputs, []) # flatten the list
protected_inputs.extend(fgraph.outputs)
inputs = [i for i in inputs if
not isinstance(i,graph.Constant)
not isinstance(i, graph.Constant)
and not fgraph.destroyers(i)
and i not in protected_inputs]
return inputs
......@@ -294,14 +299,18 @@ if 0:
if self.fgraph is fgraph:
already_there = True
if self.fgraph not in [None, fgraph]:
raise Exception("A DestroyHandler instance can only serve one FunctionGraph. (Matthew 6:24)")
raise Exception("A DestroyHandler instance can only serve"
" one FunctionGraph. (Matthew 6:24)")
for attr in ('destroyers', 'destroy_handler'):
if hasattr(fgraph, attr):
already_there = True
if already_there:
# FunctionGraph.attach_feature catches AlreadyThere and cancels the attachment
raise toolbox.AlreadyThere("DestroyHandler feature is already present or in conflict with another plugin.")
# FunctionGraph.attach_feature catches AlreadyThere
# and cancels the attachment
raise toolbox.AlreadyThere(
"DestroyHandler feature is already present or in"
" conflict with another plugin.")
####### end of checking ############
......@@ -342,7 +351,7 @@ if 0:
def _build_droot_impact(self):
droot = {} # destroyed view + nonview variables -> foundation
impact = {} # destroyed nonview variable -> it + all views of it
root_destroyer = {} # root -> destroyer apply
root_destroyer = {} # root -> destroyer apply
for app in self.destroyers:
for output_idx, input_idx_list in app.op.destroy_map.items():
......@@ -352,7 +361,8 @@ if 0:
input = app.inputs[input_idx]
input_root = getroot(input, self.view_i)
if input_root in droot:
raise InconsistencyError("Multiple destroyers of %s" % input_root)
raise InconsistencyError(
"Multiple destroyers of %s" % input_root)
droot[input_root] = input_root
root_destroyer[input_root] = app
#input_impact = set([input_root])
......@@ -394,7 +404,9 @@ if 0:
# add this symbol to the forward and backward maps
for o_idx, i_idx_list in getattr(app.op, 'view_map', {}).items():
if len(i_idx_list) > 1:
raise NotImplementedError('destroying this output invalidates multiple inputs', (app.op))
raise NotImplementedError(
'destroying this output invalidates multiple inputs',
(app. op))
o = app.outputs[o_idx]
i = app.inputs[i_idx_list[0]]
self.view_i[o] = i
......@@ -402,7 +414,7 @@ if 0:
# update self.clients
for i, input in enumerate(app.inputs):
self.clients.setdefault(input, {}).setdefault(app,0)
self.clients.setdefault(input, {}).setdefault(app, 0)
self.clients[input][app] += 1
for i, output in enumerate(app.outputs):
......@@ -460,7 +472,8 @@ if 0:
self.clients[new_r][app] += 1
#UPDATE self.view_i, self.view_o
for o_idx, i_idx_list in getattr(app.op, 'view_map', {}).items():
for o_idx, i_idx_list in getattr(app.op, 'view_map',
{}).items():
if len(i_idx_list) > 1:
#destroying this output invalidates multiple inputs
raise NotImplementedError()
......@@ -493,7 +506,8 @@ if 0:
ords = self.orderings(fgraph)
if _contains_cycle(fgraph, ords):
raise InconsistencyError("Dependency graph contains cycles")
raise InconsistencyError(
"Dependency graph contains cycles")
else:
#James's Conjecture:
#If there are no destructive ops, then there can be no cycles.
......@@ -518,13 +532,14 @@ if 0:
droot, impact, __ignore = self.refresh_droot_impact()
# check for destruction of constants
illegal_destroy = [r for r in droot if \
getattr(r.tag,'indestructible', False) or \
illegal_destroy = [r for r in droot if
getattr(r.tag,'indestructible', False) or
isinstance(r, graph.Constant)]
if illegal_destroy:
#print 'destroying illegally'
raise InconsistencyError("Attempting to destroy indestructible variables: %s" %
illegal_destroy)
raise InconsistencyError(
"Attempting to destroy indestructible variables: %s" %
illegal_destroy)
# add destroyed variable clients as computational dependencies
for app in self.destroyers:
......@@ -569,15 +584,20 @@ if 0:
#CHECK FOR INPUT ALIASING
# OPT: pre-compute this on import
tolerate_same = getattr(app.op, 'destroyhandler_tolerate_same', [])
tolerate_same = getattr(app.op,
'destroyhandler_tolerate_same',
[])
assert isinstance(tolerate_same, list)
tolerated = OrderedSet(idx1 for idx0, idx1 in tolerate_same
if idx0 == destroyed_idx)
tolerated = OrderedSet(idx1 for idx0, idx1 in
tolerate_same
if idx0 == destroyed_idx)
tolerated.add(destroyed_idx)
tolerate_aliased = getattr(app.op, 'destroyhandler_tolerate_aliased', [])
tolerate_aliased = getattr(
app.op, 'destroyhandler_tolerate_aliased', [])
assert isinstance(tolerate_aliased, list)
ignored = OrderedSet(idx1 for idx0, idx1 in tolerate_aliased
if idx0 == destroyed_idx)
ignored = OrderedSet(idx1 for idx0, idx1
in tolerate_aliased
if idx0 == destroyed_idx)
#print 'tolerated', tolerated
#print 'ignored', ignored
for i, input in enumerate(app.inputs):
......@@ -600,6 +620,7 @@ if 0:
return rval
class DestroyHandler(toolbox.Bookkeeper):
"""
The DestroyHandler class detects when a graph is impossible to evaluate
......@@ -642,7 +663,6 @@ class DestroyHandler(toolbox.Bookkeeper):
<unknown>
"""
def __init__(self, do_imports_on_attach=True):
self.fgraph = None
self.do_imports_on_attach = do_imports_on_attach
......@@ -686,14 +706,18 @@ class DestroyHandler(toolbox.Bookkeeper):
if self.fgraph is fgraph:
already_there = True
if self.fgraph is not None:
raise Exception("A DestroyHandler instance can only serve one FunctionGraph. (Matthew 6:24)")
raise Exception(
"A DestroyHandler instance can only serve one"
" FunctionGraph. (Matthew 6:24)")
for attr in ('destroyers', 'destroy_handler'):
if hasattr(fgraph, attr):
already_there = True
if already_there:
# FunctionGraph.attach_feature catches AlreadyThere and cancels the attachment
raise toolbox.AlreadyThere("DestroyHandler feature is already present or in conflict with another plugin.")
raise toolbox.AlreadyThere(
"DestroyHandler feature is already present"
" or in conflict with another plugin.")
####### Annotate the FunctionGraph ############
......@@ -738,7 +762,8 @@ class DestroyHandler(toolbox.Bookkeeper):
input = app.inputs[input_idx]
input_root = getroot(input, self.view_i)
if input_root in droot:
raise InconsistencyError("Multiple destroyers of %s" % input_root)
raise InconsistencyError(
"Multiple destroyers of %s" % input_root)
droot[input_root] = input_root
root_destroyer[input_root] = app
input_impact = get_impact(input_root, self.view_o)
......@@ -768,7 +793,8 @@ class DestroyHandler(toolbox.Bookkeeper):
def on_import(self, fgraph, app, reason):
"""Add Apply instance to set which must be computed"""
if app in self.debug_all_apps: raise ProtocolError("double import")
if app in self.debug_all_apps:
raise ProtocolError("double import")
self.debug_all_apps.add(app)
#print 'DH IMPORT', app, id(app), id(self), len(self.debug_all_apps)
......@@ -777,9 +803,12 @@ class DestroyHandler(toolbox.Bookkeeper):
self.destroyers.add(app)
# add this symbol to the forward and backward maps
for o_idx, i_idx_list in getattr(app.op, 'view_map', OrderedDict()).items():
for o_idx, i_idx_list in getattr(app.op, 'view_map',
OrderedDict()).items():
if len(i_idx_list) > 1:
raise NotImplementedError('destroying this output invalidates multiple inputs', (app.op))
raise NotImplementedError(
'destroying this output invalidates multiple inputs',
(app. op))
o = app.outputs[o_idx]
i = app.inputs[i_idx_list[0]]
self.view_i[o] = i
......@@ -787,7 +816,7 @@ class DestroyHandler(toolbox.Bookkeeper):
# update self.clients
for i, input in enumerate(app.inputs):
self.clients.setdefault(input, OrderedDict()).setdefault(app,0)
self.clients.setdefault(input, OrderedDict()).setdefault(app, 0)
self.clients[input][app] += 1
for i, output in enumerate(app.outputs):
......@@ -797,7 +826,8 @@ class DestroyHandler(toolbox.Bookkeeper):
def on_prune(self, fgraph, app, reason):
"""Remove Apply instance from set which must be computed"""
if app not in self.debug_all_apps: raise ProtocolError("prune without import")
if app not in self.debug_all_apps:
raise ProtocolError("prune without import")
self.debug_all_apps.remove(app)
#UPDATE self.clients
......@@ -812,7 +842,8 @@ class DestroyHandler(toolbox.Bookkeeper):
# deleted on_detach().
#UPDATE self.view_i, self.view_o
for o_idx, i_idx_list in getattr(app.op, 'view_map', OrderedDict()).items():
for o_idx, i_idx_list in getattr(app.op, 'view_map',
OrderedDict()).items():
if len(i_idx_list) > 1:
#destroying this output invalidates multiple inputs
raise NotImplementedError()
......@@ -834,7 +865,8 @@ class DestroyHandler(toolbox.Bookkeeper):
# considered 'outputs' of the graph.
pass
else:
if app not in self.debug_all_apps: raise ProtocolError("change without import")
if app not in self.debug_all_apps:
raise ProtocolError("change without import")
#UPDATE self.clients
self.clients[old_r][app] -= 1
......@@ -845,7 +877,8 @@ class DestroyHandler(toolbox.Bookkeeper):
self.clients[new_r][app] += 1
#UPDATE self.view_i, self.view_o
for o_idx, i_idx_list in getattr(app.op, 'view_map', OrderedDict()).items():
for o_idx, i_idx_list in getattr(app.op, 'view_map',
OrderedDict()).items():
if len(i_idx_list) > 1:
#destroying this output invalidates multiple inputs
raise NotImplementedError()
......@@ -882,6 +915,14 @@ class DestroyHandler(toolbox.Bookkeeper):
else:
#James's Conjecture:
#If there are no destructive ops, then there can be no cycles.
#FB: This isn't always True. It can happend that
#optimization introduce node that depend on itself. This
#is very rare and should not happen in general. It will be
#caught later. The error will be far from the source. But
#doing this conjecture should speed up compilation most of
#the time. The user should create such dependency except
#if he mess too much with the internal.
pass
return True
......
......@@ -593,8 +593,27 @@ class MergeOptimizer(Optimizer):
pairs_list = sched.pop()
success = True
for pairs in pairs_list:
# We must check again the equivalence, as the graph
# can have changed. If so, doing the replacement can
# introduce node that depend on itself. Doing the
# full check of such cycle everytimes is very time
# consumming. I think this double check is faster then
# doing the full cycle check. The full cycle check is
# skipped by validate() if the graph don't contain
# destroyers.
node = pairs[0][0]
candidate = pairs[0][1]
if node.owner and candidate.owner:
node = node.owner
candidate = candidate.owner
inputs_match = all(node_in is cand_in
for node_in, cand_in in zip(
node.inputs, candidate.inputs))
# No need to compare the op again, as it don't change.
if not inputs_match:
continue
try:
fgraph.replace_all_validate(pairs, 'Merge')
fgraph.replace_all_validate(pairs, 'MergeOptimizer')
except InconsistencyError:
success = False
nb_fail += 1
......
......@@ -41,19 +41,19 @@ class DB(object):
raise ValueError('The name of the object cannot be an existing'
' tag or the name of another existing object.',
obj, name)
if self.name is not None:
tags = tags + (self.name,)
obj.name = name
# This restriction is there because in many place we suppose that
# something in the DB is there only once.
if getattr(obj, 'name', "") in self.__db__:
if obj.name in self.__db__:
raise ValueError('''You can\'t register the same optimization
multiple time in a DB. Tryed to register "%s" again under the new name "%s".
Use theano.gof.ProxyDB to work around that''' % (obj.name, name))
if self.name is not None:
tags = tags + (self.name,)
obj.name = name
self.__db__[name] = set([obj])
self._names.add(name)
self.__db__[obj.__class__.__name__].add(obj)
self.add_tags(name, *tags)
def add_tags(self, name, *tags):
......
......@@ -204,6 +204,7 @@ if __name__ == "__main__":
cuda version 6.0 5.5 5.0 4.2 4.1 4.0 3.2 3.0 # note
gpu
K6000/NOECC 0.06s
K40 0.07s
K20m/ECC 0.07s
K20/NOECC 0.07s
M2090 0.19s
......@@ -213,6 +214,7 @@ if __name__ == "__main__":
M2070-Q 0.48s 0.27s 0.32s
M2050(Amazon) 0.25s
C1060 0.46s
K600 1.04s
GTX Titan Black 0.05s
GTX Titan(D15U-50) 0.06s 0.06s don't work
......
......@@ -115,7 +115,7 @@ class InputToGpuOptimizer(Optimizer):
if new_input.type == input.type:
fgraph.replace_validate(input, new_input,
"InputToGpuOptimizer")
"InputToGpuOptimizer")
except TypeError:
#as we currently only support float32, this can fail.
#Using try except make that we won't need
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论