提交 42116a64 authored 作者: Ian Goodfellow's avatar Ian Goodfellow

made destroyhandler more deterministic

上级 2db0017d
...@@ -12,6 +12,8 @@ import theano ...@@ -12,6 +12,8 @@ import theano
import toolbox import toolbox
import graph import graph
from theano.gof.python25 import deque from theano.gof.python25 import deque
from theano.gof.python25 import OrderedDict
from theano.misc.ordered_set import OrderedSet
from fg import InconsistencyError from fg import InconsistencyError
...@@ -220,7 +222,7 @@ def add_impact(r, view_o, impact): ...@@ -220,7 +222,7 @@ def add_impact(r, view_o, impact):
add_impact(v, view_o, impact) add_impact(v, view_o, impact)
def get_impact(root, view_o): def get_impact(root, view_o):
impact = set() impact = OrderedSet()
add_impact(root, view_o, impact) add_impact(root, view_o, impact)
return impact return impact
...@@ -320,7 +322,7 @@ if 0: ...@@ -320,7 +322,7 @@ if 0:
fgraph.destroy_handler = self fgraph.destroy_handler = self
self.fgraph = fgraph self.fgraph = fgraph
self.destroyers = set() #set of Apply instances with non-null destroy_map self.destroyers = OrderedSet() #set of Apply instances with non-null destroy_map
self.view_i = {} # variable -> variable used in calculation self.view_i = {} # variable -> variable used in calculation
self.view_o = {} # variable -> set of variables that use this one as a direct input self.view_o = {} # variable -> set of variables that use this one as a direct input
#clients: how many times does an apply use a given variable #clients: how many times does an apply use a given variable
...@@ -402,7 +404,7 @@ if 0: ...@@ -402,7 +404,7 @@ if 0:
o = app.outputs[o_idx] o = app.outputs[o_idx]
i = app.inputs[i_idx_list[0]] i = app.inputs[i_idx_list[0]]
self.view_i[o] = i self.view_i[o] = i
self.view_o.setdefault(i,set()).add(o) self.view_o.setdefault(i, OrderedSet()).add(o)
# update self.clients # update self.clients
for i, input in enumerate(app.inputs): for i, input in enumerate(app.inputs):
...@@ -420,7 +422,7 @@ if 0: ...@@ -420,7 +422,7 @@ if 0:
#self.debug_all_apps.remove(app) #self.debug_all_apps.remove(app)
#UPDATE self.clients #UPDATE self.clients
for i, input in enumerate(set(app.inputs)): for i, input in enumerate(OrderedSet(app.inputs)):
del self.clients[input][app] del self.clients[input][app]
if getattr(app.op, 'destroy_map', {}): if getattr(app.op, 'destroy_map', {}):
...@@ -480,7 +482,7 @@ if 0: ...@@ -480,7 +482,7 @@ if 0:
if not self.view_o[old_r]: if not self.view_o[old_r]:
del self.view_o[old_r] del self.view_o[old_r]
self.view_o.setdefault(new_r,set()).add(output) self.view_o.setdefault(new_r, OrderedSet()).add(output)
self.stale_droot = True self.stale_droot = True
...@@ -513,7 +515,7 @@ if 0: ...@@ -513,7 +515,7 @@ if 0:
c) an Apply destroys (illegally) one of its own inputs by aliasing c) an Apply destroys (illegally) one of its own inputs by aliasing
""" """
rval = {} rval = OrderedDict()
if self.destroyers: if self.destroyers:
# BUILD DATA STRUCTURES # BUILD DATA STRUCTURES
...@@ -574,11 +576,11 @@ if 0: ...@@ -574,11 +576,11 @@ if 0:
#CHECK FOR INPUT ALIASING #CHECK FOR INPUT ALIASING
# OPT: pre-compute this on import # OPT: pre-compute this on import
tolerate_same = getattr(app.op, 'destroyhandler_tolerate_same', []) tolerate_same = getattr(app.op, 'destroyhandler_tolerate_same', [])
tolerated = set(idx1 for idx0, idx1 in tolerate_same tolerated = OrderedSet(idx1 for idx0, idx1 in tolerate_same
if idx0 == destroyed_idx) if idx0 == destroyed_idx)
tolerated.add(destroyed_idx) tolerated.add(destroyed_idx)
tolerate_aliased = getattr(app.op, 'destroyhandler_tolerate_aliased', []) tolerate_aliased = getattr(app.op, 'destroyhandler_tolerate_aliased', [])
ignored = set(idx1 for idx0, idx1 in tolerate_aliased ignored = OrderedSet(idx1 for idx0, idx1 in tolerate_aliased
if idx0 == destroyed_idx) if idx0 == destroyed_idx)
#print 'tolerated', tolerated #print 'tolerated', tolerated
#print 'ignored', ignored #print 'ignored', ignored
...@@ -592,7 +594,7 @@ if 0: ...@@ -592,7 +594,7 @@ if 0:
# add the rule: app must be preceded by all other Apply instances that # add the rule: app must be preceded by all other Apply instances that
# depend on destroyed_input # depend on destroyed_input
root_clients = set() root_clients = OrderedSet()
for r in root_impact: for r in root_impact:
assert not [a for a,c in self.clients[r].items() if not c] assert not [a for a,c in self.clients[r].items() if not c]
root_clients.update([a for a,c in self.clients[r].items() if c]) root_clients.update([a for a,c in self.clients[r].items() if c])
...@@ -710,14 +712,14 @@ class DestroyHandler(toolbox.Bookkeeper): ...@@ -710,14 +712,14 @@ class DestroyHandler(toolbox.Bookkeeper):
fgraph.destroy_handler = self fgraph.destroy_handler = self
self.fgraph = fgraph self.fgraph = fgraph
self.destroyers = set() #set of Apply instances with non-null destroy_map self.destroyers = OrderedSet() #set of Apply instances with non-null destroy_map
self.view_i = {} # variable -> variable used in calculation self.view_i = OrderedDict() # variable -> variable used in calculation
self.view_o = {} # variable -> set of variables that use this one as a direct input self.view_o = OrderedDict() # variable -> set of variables that use this one as a direct input
#clients: how many times does an apply use a given variable #clients: how many times does an apply use a given variable
self.clients = {} # variable -> apply -> ninputs self.clients = OrderedDict() # variable -> apply -> ninputs
self.stale_droot = True self.stale_droot = True
self.debug_all_apps = set() self.debug_all_apps = OrderedSet()
if self.do_imports_on_attach: if self.do_imports_on_attach:
toolbox.Bookkeeper.on_attach(self, fgraph) toolbox.Bookkeeper.on_attach(self, fgraph)
...@@ -728,9 +730,9 @@ class DestroyHandler(toolbox.Bookkeeper): ...@@ -728,9 +730,9 @@ class DestroyHandler(toolbox.Bookkeeper):
(see docstrings for these properties above) (see docstrings for these properties above)
""" """
if self.stale_droot: if self.stale_droot:
droot = {} # destroyed view + nonview variables -> foundation droot = OrderedDict() # destroyed view + nonview variables -> foundation
impact = {} # destroyed nonview variable -> it + all views of it impact = OrderedDict() # destroyed nonview variable -> it + all views of it
root_destroyer = {} # root -> destroyer apply root_destroyer = OrderedDict() # root -> destroyer apply
for app in self.destroyers: for app in self.destroyers:
for output_idx, input_idx_list in app.op.destroy_map.items(): for output_idx, input_idx_list in app.op.destroy_map.items():
...@@ -775,25 +777,25 @@ class DestroyHandler(toolbox.Bookkeeper): ...@@ -775,25 +777,25 @@ class DestroyHandler(toolbox.Bookkeeper):
#print 'DH IMPORT', app, id(app), id(self), len(self.debug_all_apps) #print 'DH IMPORT', app, id(app), id(self), len(self.debug_all_apps)
# If it's a destructive op, add it to our watch list # If it's a destructive op, add it to our watch list
if getattr(app.op, 'destroy_map', {}): if getattr(app.op, 'destroy_map', OrderedDict()):
self.destroyers.add(app) self.destroyers.add(app)
# add this symbol to the forward and backward maps # add this symbol to the forward and backward maps
for o_idx, i_idx_list in getattr(app.op, 'view_map', {}).items(): for o_idx, i_idx_list in getattr(app.op, 'view_map', OrderedDict()).items():
if len(i_idx_list) > 1: if len(i_idx_list) > 1:
raise NotImplementedError('destroying this output invalidates multiple inputs', (app.op)) raise NotImplementedError('destroying this output invalidates multiple inputs', (app.op))
o = app.outputs[o_idx] o = app.outputs[o_idx]
i = app.inputs[i_idx_list[0]] i = app.inputs[i_idx_list[0]]
self.view_i[o] = i self.view_i[o] = i
self.view_o.setdefault(i,set()).add(o) self.view_o.setdefault(i, OrderedSet()).add(o)
# update self.clients # update self.clients
for i, input in enumerate(app.inputs): for i, input in enumerate(app.inputs):
self.clients.setdefault(input, {}).setdefault(app,0) self.clients.setdefault(input, OrderedDict()).setdefault(app,0)
self.clients[input][app] += 1 self.clients[input][app] += 1
for i, output in enumerate(app.outputs): for i, output in enumerate(app.outputs):
self.clients.setdefault(output, {}) self.clients.setdefault(output, OrderedDict())
self.stale_droot = True self.stale_droot = True
...@@ -803,10 +805,10 @@ class DestroyHandler(toolbox.Bookkeeper): ...@@ -803,10 +805,10 @@ class DestroyHandler(toolbox.Bookkeeper):
self.debug_all_apps.remove(app) self.debug_all_apps.remove(app)
#UPDATE self.clients #UPDATE self.clients
for i, input in enumerate(set(app.inputs)): for i, input in enumerate(OrderedSet(app.inputs)):
del self.clients[input][app] del self.clients[input][app]
if getattr(app.op, 'destroy_map', {}): if getattr(app.op, 'destroy_map', OrderedDict()):
self.destroyers.remove(app) self.destroyers.remove(app)
# Note: leaving empty client dictionaries in the struct. # Note: leaving empty client dictionaries in the struct.
...@@ -814,7 +816,7 @@ class DestroyHandler(toolbox.Bookkeeper): ...@@ -814,7 +816,7 @@ class DestroyHandler(toolbox.Bookkeeper):
# deleted on_detach(). # deleted on_detach().
#UPDATE self.view_i, self.view_o #UPDATE self.view_i, self.view_o
for o_idx, i_idx_list in getattr(app.op, 'view_map', {}).items(): for o_idx, i_idx_list in getattr(app.op, 'view_map', OrderedDict()).items():
if len(i_idx_list) > 1: if len(i_idx_list) > 1:
#destroying this output invalidates multiple inputs #destroying this output invalidates multiple inputs
raise NotImplementedError() raise NotImplementedError()
...@@ -843,11 +845,11 @@ class DestroyHandler(toolbox.Bookkeeper): ...@@ -843,11 +845,11 @@ class DestroyHandler(toolbox.Bookkeeper):
if self.clients[old_r][app] == 0: if self.clients[old_r][app] == 0:
del self.clients[old_r][app] del self.clients[old_r][app]
self.clients.setdefault(new_r,{}).setdefault(app,0) self.clients.setdefault(new_r, OrderedDict()).setdefault(app,0)
self.clients[new_r][app] += 1 self.clients[new_r][app] += 1
#UPDATE self.view_i, self.view_o #UPDATE self.view_i, self.view_o
for o_idx, i_idx_list in getattr(app.op, 'view_map', {}).items(): for o_idx, i_idx_list in getattr(app.op, 'view_map', OrderedDict()).items():
if len(i_idx_list) > 1: if len(i_idx_list) > 1:
#destroying this output invalidates multiple inputs #destroying this output invalidates multiple inputs
raise NotImplementedError() raise NotImplementedError()
...@@ -863,7 +865,7 @@ class DestroyHandler(toolbox.Bookkeeper): ...@@ -863,7 +865,7 @@ class DestroyHandler(toolbox.Bookkeeper):
if not self.view_o[old_r]: if not self.view_o[old_r]:
del self.view_o[old_r] del self.view_o[old_r]
self.view_o.setdefault(new_r,set()).add(output) self.view_o.setdefault(new_r, OrderedSet()).add(output)
self.stale_droot = True self.stale_droot = True
...@@ -896,7 +898,7 @@ class DestroyHandler(toolbox.Bookkeeper): ...@@ -896,7 +898,7 @@ class DestroyHandler(toolbox.Bookkeeper):
c) an Apply destroys (illegally) one of its own inputs by aliasing c) an Apply destroys (illegally) one of its own inputs by aliasing
""" """
rval = {} rval = OrderedDict()
if self.destroyers: if self.destroyers:
# BUILD DATA STRUCTURES # BUILD DATA STRUCTURES
...@@ -956,11 +958,11 @@ class DestroyHandler(toolbox.Bookkeeper): ...@@ -956,11 +958,11 @@ class DestroyHandler(toolbox.Bookkeeper):
#CHECK FOR INPUT ALIASING #CHECK FOR INPUT ALIASING
# OPT: pre-compute this on import # OPT: pre-compute this on import
tolerate_same = getattr(app.op, 'destroyhandler_tolerate_same', []) tolerate_same = getattr(app.op, 'destroyhandler_tolerate_same', [])
tolerated = set(idx1 for idx0, idx1 in tolerate_same tolerated = OrderedSet(idx1 for idx0, idx1 in tolerate_same
if idx0 == destroyed_idx) if idx0 == destroyed_idx)
tolerated.add(destroyed_idx) tolerated.add(destroyed_idx)
tolerate_aliased = getattr(app.op, 'destroyhandler_tolerate_aliased', []) tolerate_aliased = getattr(app.op, 'destroyhandler_tolerate_aliased', [])
ignored = set(idx1 for idx0, idx1 in tolerate_aliased ignored = OrderedSet(idx1 for idx0, idx1 in tolerate_aliased
if idx0 == destroyed_idx) if idx0 == destroyed_idx)
#print 'tolerated', tolerated #print 'tolerated', tolerated
#print 'ignored', ignored #print 'ignored', ignored
...@@ -974,7 +976,7 @@ class DestroyHandler(toolbox.Bookkeeper): ...@@ -974,7 +976,7 @@ class DestroyHandler(toolbox.Bookkeeper):
# add the rule: app must be preceded by all other Apply instances that # add the rule: app must be preceded by all other Apply instances that
# depend on destroyed_input # depend on destroyed_input
root_clients = set() root_clients = OrderedSet()
for r in root_impact: for r in root_impact:
assert not [a for a,c in self.clients[r].items() if not c] assert not [a for a,c in self.clients[r].items() if not c]
root_clients.update([a for a,c in self.clients[r].items() if c]) root_clients.update([a for a,c in self.clients[r].items() if c])
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论