提交 42116a64 authored 作者: Ian Goodfellow's avatar Ian Goodfellow

made destroyhandler more deterministic

上级 2db0017d
......@@ -12,6 +12,8 @@ import theano
import toolbox
import graph
from theano.gof.python25 import deque
from theano.gof.python25 import OrderedDict
from theano.misc.ordered_set import OrderedSet
from fg import InconsistencyError
......@@ -220,7 +222,7 @@ def add_impact(r, view_o, impact):
add_impact(v, view_o, impact)
def get_impact(root, view_o):
impact = set()
impact = OrderedSet()
add_impact(root, view_o, impact)
return impact
......@@ -320,7 +322,7 @@ if 0:
fgraph.destroy_handler = self
self.fgraph = fgraph
self.destroyers = set() #set of Apply instances with non-null destroy_map
self.destroyers = OrderedSet() #set of Apply instances with non-null destroy_map
self.view_i = {} # variable -> variable used in calculation
self.view_o = {} # variable -> set of variables that use this one as a direct input
#clients: how many times does an apply use a given variable
......@@ -402,7 +404,7 @@ if 0:
o = app.outputs[o_idx]
i = app.inputs[i_idx_list[0]]
self.view_i[o] = i
self.view_o.setdefault(i,set()).add(o)
self.view_o.setdefault(i, OrderedSet()).add(o)
# update self.clients
for i, input in enumerate(app.inputs):
......@@ -420,7 +422,7 @@ if 0:
#self.debug_all_apps.remove(app)
#UPDATE self.clients
for i, input in enumerate(set(app.inputs)):
for i, input in enumerate(OrderedSet(app.inputs)):
del self.clients[input][app]
if getattr(app.op, 'destroy_map', {}):
......@@ -480,7 +482,7 @@ if 0:
if not self.view_o[old_r]:
del self.view_o[old_r]
self.view_o.setdefault(new_r,set()).add(output)
self.view_o.setdefault(new_r, OrderedSet()).add(output)
self.stale_droot = True
......@@ -513,7 +515,7 @@ if 0:
c) an Apply destroys (illegally) one of its own inputs by aliasing
"""
rval = {}
rval = OrderedDict()
if self.destroyers:
# BUILD DATA STRUCTURES
......@@ -574,11 +576,11 @@ if 0:
#CHECK FOR INPUT ALIASING
# OPT: pre-compute this on import
tolerate_same = getattr(app.op, 'destroyhandler_tolerate_same', [])
tolerated = set(idx1 for idx0, idx1 in tolerate_same
tolerated = OrderedSet(idx1 for idx0, idx1 in tolerate_same
if idx0 == destroyed_idx)
tolerated.add(destroyed_idx)
tolerate_aliased = getattr(app.op, 'destroyhandler_tolerate_aliased', [])
ignored = set(idx1 for idx0, idx1 in tolerate_aliased
ignored = OrderedSet(idx1 for idx0, idx1 in tolerate_aliased
if idx0 == destroyed_idx)
#print 'tolerated', tolerated
#print 'ignored', ignored
......@@ -592,7 +594,7 @@ if 0:
# add the rule: app must be preceded by all other Apply instances that
# depend on destroyed_input
root_clients = set()
root_clients = OrderedSet()
for r in root_impact:
assert not [a for a,c in self.clients[r].items() if not c]
root_clients.update([a for a,c in self.clients[r].items() if c])
......@@ -710,14 +712,14 @@ class DestroyHandler(toolbox.Bookkeeper):
fgraph.destroy_handler = self
self.fgraph = fgraph
self.destroyers = set() #set of Apply instances with non-null destroy_map
self.view_i = {} # variable -> variable used in calculation
self.view_o = {} # variable -> set of variables that use this one as a direct input
self.destroyers = OrderedSet() #set of Apply instances with non-null destroy_map
self.view_i = OrderedDict() # variable -> variable used in calculation
self.view_o = OrderedDict() # variable -> set of variables that use this one as a direct input
#clients: how many times does an apply use a given variable
self.clients = {} # variable -> apply -> ninputs
self.clients = OrderedDict() # variable -> apply -> ninputs
self.stale_droot = True
self.debug_all_apps = set()
self.debug_all_apps = OrderedSet()
if self.do_imports_on_attach:
toolbox.Bookkeeper.on_attach(self, fgraph)
......@@ -728,9 +730,9 @@ class DestroyHandler(toolbox.Bookkeeper):
(see docstrings for these properties above)
"""
if self.stale_droot:
droot = {} # destroyed view + nonview variables -> foundation
impact = {} # destroyed nonview variable -> it + all views of it
root_destroyer = {} # root -> destroyer apply
droot = OrderedDict() # destroyed view + nonview variables -> foundation
impact = OrderedDict() # destroyed nonview variable -> it + all views of it
root_destroyer = OrderedDict() # root -> destroyer apply
for app in self.destroyers:
for output_idx, input_idx_list in app.op.destroy_map.items():
......@@ -775,25 +777,25 @@ class DestroyHandler(toolbox.Bookkeeper):
#print 'DH IMPORT', app, id(app), id(self), len(self.debug_all_apps)
# If it's a destructive op, add it to our watch list
if getattr(app.op, 'destroy_map', {}):
if getattr(app.op, 'destroy_map', OrderedDict()):
self.destroyers.add(app)
# add this symbol to the forward and backward maps
for o_idx, i_idx_list in getattr(app.op, 'view_map', {}).items():
for o_idx, i_idx_list in getattr(app.op, 'view_map', OrderedDict()).items():
if len(i_idx_list) > 1:
raise NotImplementedError('destroying this output invalidates multiple inputs', (app.op))
o = app.outputs[o_idx]
i = app.inputs[i_idx_list[0]]
self.view_i[o] = i
self.view_o.setdefault(i,set()).add(o)
self.view_o.setdefault(i, OrderedSet()).add(o)
# update self.clients
for i, input in enumerate(app.inputs):
self.clients.setdefault(input, {}).setdefault(app,0)
self.clients.setdefault(input, OrderedDict()).setdefault(app,0)
self.clients[input][app] += 1
for i, output in enumerate(app.outputs):
self.clients.setdefault(output, {})
self.clients.setdefault(output, OrderedDict())
self.stale_droot = True
......@@ -803,10 +805,10 @@ class DestroyHandler(toolbox.Bookkeeper):
self.debug_all_apps.remove(app)
#UPDATE self.clients
for i, input in enumerate(set(app.inputs)):
for i, input in enumerate(OrderedSet(app.inputs)):
del self.clients[input][app]
if getattr(app.op, 'destroy_map', {}):
if getattr(app.op, 'destroy_map', OrderedDict()):
self.destroyers.remove(app)
# Note: leaving empty client dictionaries in the struct.
......@@ -814,7 +816,7 @@ class DestroyHandler(toolbox.Bookkeeper):
# deleted on_detach().
#UPDATE self.view_i, self.view_o
for o_idx, i_idx_list in getattr(app.op, 'view_map', {}).items():
for o_idx, i_idx_list in getattr(app.op, 'view_map', OrderedDict()).items():
if len(i_idx_list) > 1:
#destroying this output invalidates multiple inputs
raise NotImplementedError()
......@@ -843,11 +845,11 @@ class DestroyHandler(toolbox.Bookkeeper):
if self.clients[old_r][app] == 0:
del self.clients[old_r][app]
self.clients.setdefault(new_r,{}).setdefault(app,0)
self.clients.setdefault(new_r, OrderedDict()).setdefault(app,0)
self.clients[new_r][app] += 1
#UPDATE self.view_i, self.view_o
for o_idx, i_idx_list in getattr(app.op, 'view_map', {}).items():
for o_idx, i_idx_list in getattr(app.op, 'view_map', OrderedDict()).items():
if len(i_idx_list) > 1:
#destroying this output invalidates multiple inputs
raise NotImplementedError()
......@@ -863,7 +865,7 @@ class DestroyHandler(toolbox.Bookkeeper):
if not self.view_o[old_r]:
del self.view_o[old_r]
self.view_o.setdefault(new_r,set()).add(output)
self.view_o.setdefault(new_r, OrderedSet()).add(output)
self.stale_droot = True
......@@ -896,7 +898,7 @@ class DestroyHandler(toolbox.Bookkeeper):
c) an Apply destroys (illegally) one of its own inputs by aliasing
"""
rval = {}
rval = OrderedDict()
if self.destroyers:
# BUILD DATA STRUCTURES
......@@ -956,11 +958,11 @@ class DestroyHandler(toolbox.Bookkeeper):
#CHECK FOR INPUT ALIASING
# OPT: pre-compute this on import
tolerate_same = getattr(app.op, 'destroyhandler_tolerate_same', [])
tolerated = set(idx1 for idx0, idx1 in tolerate_same
tolerated = OrderedSet(idx1 for idx0, idx1 in tolerate_same
if idx0 == destroyed_idx)
tolerated.add(destroyed_idx)
tolerate_aliased = getattr(app.op, 'destroyhandler_tolerate_aliased', [])
ignored = set(idx1 for idx0, idx1 in tolerate_aliased
ignored = OrderedSet(idx1 for idx0, idx1 in tolerate_aliased
if idx0 == destroyed_idx)
#print 'tolerated', tolerated
#print 'ignored', ignored
......@@ -974,7 +976,7 @@ class DestroyHandler(toolbox.Bookkeeper):
# add the rule: app must be preceded by all other Apply instances that
# depend on destroyed_input
root_clients = set()
root_clients = OrderedSet()
for r in root_impact:
assert not [a for a,c in self.clients[r].items() if not c]
root_clients.update([a for a,c in self.clients[r].items() if c])
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论