提交 b6f3c469 authored 作者: abergeron's avatar abergeron

Merge pull request #1911 from nouiz/cycle

Fix crash due to an optimization that introduces a cycle in the graph
...@@ -11,6 +11,7 @@ from theano.misc.ordered_set import OrderedSet ...@@ -11,6 +11,7 @@ from theano.misc.ordered_set import OrderedSet
from fg import InconsistencyError from fg import InconsistencyError
class ProtocolError(Exception): class ProtocolError(Exception):
"""Raised when FunctionGraph calls DestroyHandler callbacks in """Raised when FunctionGraph calls DestroyHandler callbacks in
an invalid way, for example, pruning or changing a node that has an invalid way, for example, pruning or changing a node that has
...@@ -18,6 +19,7 @@ class ProtocolError(Exception): ...@@ -18,6 +19,7 @@ class ProtocolError(Exception):
""" """
pass pass
def _contains_cycle(fgraph, orderings): def _contains_cycle(fgraph, orderings):
""" """
...@@ -44,7 +46,6 @@ def _contains_cycle(fgraph, orderings): ...@@ -44,7 +46,6 @@ def _contains_cycle(fgraph, orderings):
inputs = fgraph.inputs inputs = fgraph.inputs
outputs = fgraph.outputs outputs = fgraph.outputs
# this is hard-coded reimplementation of functions from graph.py # this is hard-coded reimplementation of functions from graph.py
# reason: go faster, prepare for port to C. # reason: go faster, prepare for port to C.
# specifically, it could be replaced with a wrapper # specifically, it could be replaced with a wrapper
...@@ -55,7 +56,6 @@ def _contains_cycle(fgraph, orderings): ...@@ -55,7 +56,6 @@ def _contains_cycle(fgraph, orderings):
# this is performance-critical code. it is the largest single-function # this is performance-critical code. it is the largest single-function
# bottleneck when compiling large graphs. # bottleneck when compiling large graphs.
assert isinstance(outputs, (tuple, list, deque)) assert isinstance(outputs, (tuple, list, deque))
# TODO: For more speed - use a defaultdict for the orderings # TODO: For more speed - use a defaultdict for the orderings
...@@ -111,7 +111,7 @@ def _contains_cycle(fgraph, orderings): ...@@ -111,7 +111,7 @@ def _contains_cycle(fgraph, orderings):
# this is faster than calling get_parents # this is faster than calling get_parents
owner = var.owner owner = var.owner
if owner: if owner:
parents = [ owner ] parents = [owner]
else: else:
parents = [] parents = []
...@@ -172,16 +172,16 @@ def _contains_cycle(fgraph, orderings): ...@@ -172,16 +172,16 @@ def _contains_cycle(fgraph, orderings):
# and increment the visited node count without double-counting # and increment the visited node count without double-counting
node = visitable.popleft() node = visitable.popleft()
visited += 1 visited += 1
for client in node_to_children.get(node,[]): for client in node_to_children.get(node, []):
parent_counts[client] -= 1 parent_counts[client] -= 1
# If all of a node's parents have been visited, # If all of a node's parents have been visited,
# it may now be visited too # it may now be visited too
if not parent_counts[client]: if not parent_counts[client]:
visitable.append(client) visitable.append(client)
return visited != len(parent_counts) return visited != len(parent_counts)
def getroot(r, view_i): def getroot(r, view_i):
""" """
TODO: what is view_i ? based on add_impact's docstring, IG is guessing TODO: what is view_i ? based on add_impact's docstring, IG is guessing
...@@ -197,6 +197,7 @@ def getroot(r, view_i): ...@@ -197,6 +197,7 @@ def getroot(r, view_i):
except KeyError: except KeyError:
return r return r
def add_impact(r, view_o, impact): def add_impact(r, view_o, impact):
""" """
In opposition to getroot, which finds the variable that is viewed *by* r, this function In opposition to getroot, which finds the variable that is viewed *by* r, this function
...@@ -211,15 +212,17 @@ def add_impact(r, view_o, impact): ...@@ -211,15 +212,17 @@ def add_impact(r, view_o, impact):
IG thinks so, based on reading the code. It looks like get_impact IG thinks so, based on reading the code. It looks like get_impact
does what this docstring said this function does. does what this docstring said this function does.
""" """
for v in view_o.get(r,[]): for v in view_o.get(r, []):
impact.add(v) impact.add(v)
add_impact(v, view_o, impact) add_impact(v, view_o, impact)
def get_impact(root, view_o): def get_impact(root, view_o):
impact = OrderedSet() impact = OrderedSet()
add_impact(root, view_o, impact) add_impact(root, view_o, impact)
return impact return impact
def fast_inplace_check(inputs): def fast_inplace_check(inputs):
""" Return the variables in inputs that are posible candidate for as inputs of inplace operation """ Return the variables in inputs that are posible candidate for as inputs of inplace operation
...@@ -227,12 +230,14 @@ def fast_inplace_check(inputs): ...@@ -227,12 +230,14 @@ def fast_inplace_check(inputs):
:param inputs: inputs Variable that you want to use as inplace destination :param inputs: inputs Variable that you want to use as inplace destination
""" """
fgraph = inputs[0].fgraph fgraph = inputs[0].fgraph
protected_inputs = [f.protected for f in fgraph._features if isinstance(f,theano.compile.function_module.Supervisor)] Supervisor = theano.compile.function_module.Supervisor
protected_inputs = sum(protected_inputs,[])#flatten the list protected_inputs = [f.protected for f in fgraph._features
if isinstance(f, Supervisor)]
protected_inputs = sum(protected_inputs, []) # flatten the list
protected_inputs.extend(fgraph.outputs) protected_inputs.extend(fgraph.outputs)
inputs = [i for i in inputs if inputs = [i for i in inputs if
not isinstance(i,graph.Constant) not isinstance(i, graph.Constant)
and not fgraph.destroyers(i) and not fgraph.destroyers(i)
and i not in protected_inputs] and i not in protected_inputs]
return inputs return inputs
...@@ -294,14 +299,18 @@ if 0: ...@@ -294,14 +299,18 @@ if 0:
if self.fgraph is fgraph: if self.fgraph is fgraph:
already_there = True already_there = True
if self.fgraph not in [None, fgraph]: if self.fgraph not in [None, fgraph]:
raise Exception("A DestroyHandler instance can only serve one FunctionGraph. (Matthew 6:24)") raise Exception("A DestroyHandler instance can only serve"
" one FunctionGraph. (Matthew 6:24)")
for attr in ('destroyers', 'destroy_handler'): for attr in ('destroyers', 'destroy_handler'):
if hasattr(fgraph, attr): if hasattr(fgraph, attr):
already_there = True already_there = True
if already_there: if already_there:
# FunctionGraph.attach_feature catches AlreadyThere and cancels the attachment # FunctionGraph.attach_feature catches AlreadyThere
raise toolbox.AlreadyThere("DestroyHandler feature is already present or in conflict with another plugin.") # and cancels the attachment
raise toolbox.AlreadyThere(
"DestroyHandler feature is already present or in"
" conflict with another plugin.")
####### end of checking ############ ####### end of checking ############
...@@ -352,7 +361,8 @@ if 0: ...@@ -352,7 +361,8 @@ if 0:
input = app.inputs[input_idx] input = app.inputs[input_idx]
input_root = getroot(input, self.view_i) input_root = getroot(input, self.view_i)
if input_root in droot: if input_root in droot:
raise InconsistencyError("Multiple destroyers of %s" % input_root) raise InconsistencyError(
"Multiple destroyers of %s" % input_root)
droot[input_root] = input_root droot[input_root] = input_root
root_destroyer[input_root] = app root_destroyer[input_root] = app
#input_impact = set([input_root]) #input_impact = set([input_root])
...@@ -394,7 +404,9 @@ if 0: ...@@ -394,7 +404,9 @@ if 0:
# add this symbol to the forward and backward maps # add this symbol to the forward and backward maps
for o_idx, i_idx_list in getattr(app.op, 'view_map', {}).items(): for o_idx, i_idx_list in getattr(app.op, 'view_map', {}).items():
if len(i_idx_list) > 1: if len(i_idx_list) > 1:
raise NotImplementedError('destroying this output invalidates multiple inputs', (app.op)) raise NotImplementedError(
'destroying this output invalidates multiple inputs',
(app. op))
o = app.outputs[o_idx] o = app.outputs[o_idx]
i = app.inputs[i_idx_list[0]] i = app.inputs[i_idx_list[0]]
self.view_i[o] = i self.view_i[o] = i
...@@ -402,7 +414,7 @@ if 0: ...@@ -402,7 +414,7 @@ if 0:
# update self.clients # update self.clients
for i, input in enumerate(app.inputs): for i, input in enumerate(app.inputs):
self.clients.setdefault(input, {}).setdefault(app,0) self.clients.setdefault(input, {}).setdefault(app, 0)
self.clients[input][app] += 1 self.clients[input][app] += 1
for i, output in enumerate(app.outputs): for i, output in enumerate(app.outputs):
...@@ -460,7 +472,8 @@ if 0: ...@@ -460,7 +472,8 @@ if 0:
self.clients[new_r][app] += 1 self.clients[new_r][app] += 1
#UPDATE self.view_i, self.view_o #UPDATE self.view_i, self.view_o
for o_idx, i_idx_list in getattr(app.op, 'view_map', {}).items(): for o_idx, i_idx_list in getattr(app.op, 'view_map',
{}).items():
if len(i_idx_list) > 1: if len(i_idx_list) > 1:
#destroying this output invalidates multiple inputs #destroying this output invalidates multiple inputs
raise NotImplementedError() raise NotImplementedError()
...@@ -493,7 +506,8 @@ if 0: ...@@ -493,7 +506,8 @@ if 0:
ords = self.orderings(fgraph) ords = self.orderings(fgraph)
if _contains_cycle(fgraph, ords): if _contains_cycle(fgraph, ords):
raise InconsistencyError("Dependency graph contains cycles") raise InconsistencyError(
"Dependency graph contains cycles")
else: else:
#James's Conjecture: #James's Conjecture:
#If there are no destructive ops, then there can be no cycles. #If there are no destructive ops, then there can be no cycles.
...@@ -518,12 +532,13 @@ if 0: ...@@ -518,12 +532,13 @@ if 0:
droot, impact, __ignore = self.refresh_droot_impact() droot, impact, __ignore = self.refresh_droot_impact()
# check for destruction of constants # check for destruction of constants
illegal_destroy = [r for r in droot if \ illegal_destroy = [r for r in droot if
getattr(r.tag,'indestructible', False) or \ getattr(r.tag,'indestructible', False) or
isinstance(r, graph.Constant)] isinstance(r, graph.Constant)]
if illegal_destroy: if illegal_destroy:
#print 'destroying illegally' #print 'destroying illegally'
raise InconsistencyError("Attempting to destroy indestructible variables: %s" % raise InconsistencyError(
"Attempting to destroy indestructible variables: %s" %
illegal_destroy) illegal_destroy)
# add destroyed variable clients as computational dependencies # add destroyed variable clients as computational dependencies
...@@ -569,14 +584,19 @@ if 0: ...@@ -569,14 +584,19 @@ if 0:
#CHECK FOR INPUT ALIASING #CHECK FOR INPUT ALIASING
# OPT: pre-compute this on import # OPT: pre-compute this on import
tolerate_same = getattr(app.op, 'destroyhandler_tolerate_same', []) tolerate_same = getattr(app.op,
'destroyhandler_tolerate_same',
[])
assert isinstance(tolerate_same, list) assert isinstance(tolerate_same, list)
tolerated = OrderedSet(idx1 for idx0, idx1 in tolerate_same tolerated = OrderedSet(idx1 for idx0, idx1 in
tolerate_same
if idx0 == destroyed_idx) if idx0 == destroyed_idx)
tolerated.add(destroyed_idx) tolerated.add(destroyed_idx)
tolerate_aliased = getattr(app.op, 'destroyhandler_tolerate_aliased', []) tolerate_aliased = getattr(
app.op, 'destroyhandler_tolerate_aliased', [])
assert isinstance(tolerate_aliased, list) assert isinstance(tolerate_aliased, list)
ignored = OrderedSet(idx1 for idx0, idx1 in tolerate_aliased ignored = OrderedSet(idx1 for idx0, idx1
in tolerate_aliased
if idx0 == destroyed_idx) if idx0 == destroyed_idx)
#print 'tolerated', tolerated #print 'tolerated', tolerated
#print 'ignored', ignored #print 'ignored', ignored
...@@ -600,6 +620,7 @@ if 0: ...@@ -600,6 +620,7 @@ if 0:
return rval return rval
class DestroyHandler(toolbox.Bookkeeper): class DestroyHandler(toolbox.Bookkeeper):
""" """
The DestroyHandler class detects when a graph is impossible to evaluate The DestroyHandler class detects when a graph is impossible to evaluate
...@@ -642,7 +663,6 @@ class DestroyHandler(toolbox.Bookkeeper): ...@@ -642,7 +663,6 @@ class DestroyHandler(toolbox.Bookkeeper):
<unknown> <unknown>
""" """
def __init__(self, do_imports_on_attach=True): def __init__(self, do_imports_on_attach=True):
self.fgraph = None self.fgraph = None
self.do_imports_on_attach = do_imports_on_attach self.do_imports_on_attach = do_imports_on_attach
...@@ -686,14 +706,18 @@ class DestroyHandler(toolbox.Bookkeeper): ...@@ -686,14 +706,18 @@ class DestroyHandler(toolbox.Bookkeeper):
if self.fgraph is fgraph: if self.fgraph is fgraph:
already_there = True already_there = True
if self.fgraph is not None: if self.fgraph is not None:
raise Exception("A DestroyHandler instance can only serve one FunctionGraph. (Matthew 6:24)") raise Exception(
"A DestroyHandler instance can only serve one"
" FunctionGraph. (Matthew 6:24)")
for attr in ('destroyers', 'destroy_handler'): for attr in ('destroyers', 'destroy_handler'):
if hasattr(fgraph, attr): if hasattr(fgraph, attr):
already_there = True already_there = True
if already_there: if already_there:
# FunctionGraph.attach_feature catches AlreadyThere and cancels the attachment # FunctionGraph.attach_feature catches AlreadyThere and cancels the attachment
raise toolbox.AlreadyThere("DestroyHandler feature is already present or in conflict with another plugin.") raise toolbox.AlreadyThere(
"DestroyHandler feature is already present"
" or in conflict with another plugin.")
####### Annotate the FunctionGraph ############ ####### Annotate the FunctionGraph ############
...@@ -738,7 +762,8 @@ class DestroyHandler(toolbox.Bookkeeper): ...@@ -738,7 +762,8 @@ class DestroyHandler(toolbox.Bookkeeper):
input = app.inputs[input_idx] input = app.inputs[input_idx]
input_root = getroot(input, self.view_i) input_root = getroot(input, self.view_i)
if input_root in droot: if input_root in droot:
raise InconsistencyError("Multiple destroyers of %s" % input_root) raise InconsistencyError(
"Multiple destroyers of %s" % input_root)
droot[input_root] = input_root droot[input_root] = input_root
root_destroyer[input_root] = app root_destroyer[input_root] = app
input_impact = get_impact(input_root, self.view_o) input_impact = get_impact(input_root, self.view_o)
...@@ -768,7 +793,8 @@ class DestroyHandler(toolbox.Bookkeeper): ...@@ -768,7 +793,8 @@ class DestroyHandler(toolbox.Bookkeeper):
def on_import(self, fgraph, app, reason): def on_import(self, fgraph, app, reason):
"""Add Apply instance to set which must be computed""" """Add Apply instance to set which must be computed"""
if app in self.debug_all_apps: raise ProtocolError("double import") if app in self.debug_all_apps:
raise ProtocolError("double import")
self.debug_all_apps.add(app) self.debug_all_apps.add(app)
#print 'DH IMPORT', app, id(app), id(self), len(self.debug_all_apps) #print 'DH IMPORT', app, id(app), id(self), len(self.debug_all_apps)
...@@ -777,9 +803,12 @@ class DestroyHandler(toolbox.Bookkeeper): ...@@ -777,9 +803,12 @@ class DestroyHandler(toolbox.Bookkeeper):
self.destroyers.add(app) self.destroyers.add(app)
# add this symbol to the forward and backward maps # add this symbol to the forward and backward maps
for o_idx, i_idx_list in getattr(app.op, 'view_map', OrderedDict()).items(): for o_idx, i_idx_list in getattr(app.op, 'view_map',
OrderedDict()).items():
if len(i_idx_list) > 1: if len(i_idx_list) > 1:
raise NotImplementedError('destroying this output invalidates multiple inputs', (app.op)) raise NotImplementedError(
'destroying this output invalidates multiple inputs',
(app. op))
o = app.outputs[o_idx] o = app.outputs[o_idx]
i = app.inputs[i_idx_list[0]] i = app.inputs[i_idx_list[0]]
self.view_i[o] = i self.view_i[o] = i
...@@ -787,7 +816,7 @@ class DestroyHandler(toolbox.Bookkeeper): ...@@ -787,7 +816,7 @@ class DestroyHandler(toolbox.Bookkeeper):
# update self.clients # update self.clients
for i, input in enumerate(app.inputs): for i, input in enumerate(app.inputs):
self.clients.setdefault(input, OrderedDict()).setdefault(app,0) self.clients.setdefault(input, OrderedDict()).setdefault(app, 0)
self.clients[input][app] += 1 self.clients[input][app] += 1
for i, output in enumerate(app.outputs): for i, output in enumerate(app.outputs):
...@@ -797,7 +826,8 @@ class DestroyHandler(toolbox.Bookkeeper): ...@@ -797,7 +826,8 @@ class DestroyHandler(toolbox.Bookkeeper):
def on_prune(self, fgraph, app, reason): def on_prune(self, fgraph, app, reason):
"""Remove Apply instance from set which must be computed""" """Remove Apply instance from set which must be computed"""
if app not in self.debug_all_apps: raise ProtocolError("prune without import") if app not in self.debug_all_apps:
raise ProtocolError("prune without import")
self.debug_all_apps.remove(app) self.debug_all_apps.remove(app)
#UPDATE self.clients #UPDATE self.clients
...@@ -812,7 +842,8 @@ class DestroyHandler(toolbox.Bookkeeper): ...@@ -812,7 +842,8 @@ class DestroyHandler(toolbox.Bookkeeper):
# deleted on_detach(). # deleted on_detach().
#UPDATE self.view_i, self.view_o #UPDATE self.view_i, self.view_o
for o_idx, i_idx_list in getattr(app.op, 'view_map', OrderedDict()).items(): for o_idx, i_idx_list in getattr(app.op, 'view_map',
OrderedDict()).items():
if len(i_idx_list) > 1: if len(i_idx_list) > 1:
#destroying this output invalidates multiple inputs #destroying this output invalidates multiple inputs
raise NotImplementedError() raise NotImplementedError()
...@@ -834,7 +865,8 @@ class DestroyHandler(toolbox.Bookkeeper): ...@@ -834,7 +865,8 @@ class DestroyHandler(toolbox.Bookkeeper):
# considered 'outputs' of the graph. # considered 'outputs' of the graph.
pass pass
else: else:
if app not in self.debug_all_apps: raise ProtocolError("change without import") if app not in self.debug_all_apps:
raise ProtocolError("change without import")
#UPDATE self.clients #UPDATE self.clients
self.clients[old_r][app] -= 1 self.clients[old_r][app] -= 1
...@@ -845,7 +877,8 @@ class DestroyHandler(toolbox.Bookkeeper): ...@@ -845,7 +877,8 @@ class DestroyHandler(toolbox.Bookkeeper):
self.clients[new_r][app] += 1 self.clients[new_r][app] += 1
#UPDATE self.view_i, self.view_o #UPDATE self.view_i, self.view_o
for o_idx, i_idx_list in getattr(app.op, 'view_map', OrderedDict()).items(): for o_idx, i_idx_list in getattr(app.op, 'view_map',
OrderedDict()).items():
if len(i_idx_list) > 1: if len(i_idx_list) > 1:
#destroying this output invalidates multiple inputs #destroying this output invalidates multiple inputs
raise NotImplementedError() raise NotImplementedError()
...@@ -882,6 +915,14 @@ class DestroyHandler(toolbox.Bookkeeper): ...@@ -882,6 +915,14 @@ class DestroyHandler(toolbox.Bookkeeper):
else: else:
#James's Conjecture: #James's Conjecture:
#If there are no destructive ops, then there can be no cycles. #If there are no destructive ops, then there can be no cycles.
#FB: This isn't always true. It can happen that an
#optimization introduces a node that depends on itself. This
#is very rare and should not happen in general. It will be
#caught later, but the error will be raised far from its
#source. Relying on this conjecture speeds up compilation
#most of the time. The user should not create such a
#dependency unless they tamper with the internals.
pass pass
return True return True
......
...@@ -593,8 +593,27 @@ class MergeOptimizer(Optimizer): ...@@ -593,8 +593,27 @@ class MergeOptimizer(Optimizer):
pairs_list = sched.pop() pairs_list = sched.pop()
success = True success = True
for pairs in pairs_list: for pairs in pairs_list:
# We must check the equivalence again, as the graph
# may have changed. If it has, doing the replacement can
# introduce a node that depends on itself. Doing the
# full check for such a cycle every time is very time
# consuming. I think this double check is faster than
# doing the full cycle check. The full cycle check is
# skipped by validate() if the graph doesn't contain
# destroyers.
node = pairs[0][0]
candidate = pairs[0][1]
if node.owner and candidate.owner:
node = node.owner
candidate = candidate.owner
inputs_match = all(node_in is cand_in
for node_in, cand_in in zip(
node.inputs, candidate.inputs))
# No need to compare the op again, as it doesn't change.
if not inputs_match:
continue
try: try:
fgraph.replace_all_validate(pairs, 'Merge') fgraph.replace_all_validate(pairs, 'MergeOptimizer')
except InconsistencyError: except InconsistencyError:
success = False success = False
nb_fail += 1 nb_fail += 1
......
...@@ -41,19 +41,19 @@ class DB(object): ...@@ -41,19 +41,19 @@ class DB(object):
raise ValueError('The name of the object cannot be an existing' raise ValueError('The name of the object cannot be an existing'
' tag or the name of another existing object.', ' tag or the name of another existing object.',
obj, name) obj, name)
if self.name is not None:
tags = tags + (self.name,)
obj.name = name
# This restriction is there because in many place we suppose that # This restriction is there because in many place we suppose that
# something in the DB is there only once. # something in the DB is there only once.
if getattr(obj, 'name', "") in self.__db__: if obj.name in self.__db__:
raise ValueError('''You can\'t register the same optimization raise ValueError('''You can\'t register the same optimization
multiple time in a DB. Tryed to register "%s" again under the new name "%s". multiple time in a DB. Tryed to register "%s" again under the new name "%s".
Use theano.gof.ProxyDB to work around that''' % (obj.name, name)) Use theano.gof.ProxyDB to work around that''' % (obj.name, name))
if self.name is not None:
tags = tags + (self.name,)
obj.name = name
self.__db__[name] = set([obj]) self.__db__[name] = set([obj])
self._names.add(name) self._names.add(name)
self.__db__[obj.__class__.__name__].add(obj)
self.add_tags(name, *tags) self.add_tags(name, *tags)
def add_tags(self, name, *tags): def add_tags(self, name, *tags):
......
...@@ -204,6 +204,7 @@ if __name__ == "__main__": ...@@ -204,6 +204,7 @@ if __name__ == "__main__":
cuda version 6.0 5.5 5.0 4.2 4.1 4.0 3.2 3.0 # note cuda version 6.0 5.5 5.0 4.2 4.1 4.0 3.2 3.0 # note
gpu gpu
K6000/NOECC 0.06s K6000/NOECC 0.06s
K40 0.07s
K20m/ECC 0.07s K20m/ECC 0.07s
K20/NOECC 0.07s K20/NOECC 0.07s
M2090 0.19s M2090 0.19s
...@@ -213,6 +214,7 @@ if __name__ == "__main__": ...@@ -213,6 +214,7 @@ if __name__ == "__main__":
M2070-Q 0.48s 0.27s 0.32s M2070-Q 0.48s 0.27s 0.32s
M2050(Amazon) 0.25s M2050(Amazon) 0.25s
C1060 0.46s C1060 0.46s
K600 1.04s
GTX Titan Black 0.05s GTX Titan Black 0.05s
GTX Titan(D15U-50) 0.06s 0.06s don't work GTX Titan(D15U-50) 0.06s 0.06s don't work
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论