提交 be890125 authored 作者: Ian Goodfellow's avatar Ian Goodfellow

made functiongraph more deterministic

上级 42116a64
......@@ -12,6 +12,9 @@ from python25 import all
from theano import config
import warnings
NullType = None
import theano
from python25 import OrderedDict
from theano.misc.ordered_set import OrderedSet
class InconsistencyError(Exception):
"""
......@@ -553,10 +556,44 @@ class FunctionGraph(utils.object2):
# 1-element graphs.
return list(self.apply_nodes)
fg = self
if verbose:
print 'FunctionGraph.toposort multinode'
# DEBUGGING NOTE: this might be the source of non-determinism.
# it definitely contains dicts and lists made
# by iterating over dicts.
# does it affect the final sort order though?
ords = self.orderings()
if verbose:
"""
print 'FunctionGraph inputs'
assert isinstance(fg.inputs, list)
for i, elem in enumerate(fg.inputs):
print '\t%d:' % i
print theano.printing.min_informative_str(elem, indent_level=1)
print 'FunctionGraph outputs'
assert isinstance(fg.outputs, list)
for i, elem in enumerate(fg.outputs):
print '\t%d:' % i
print theano.printing.min_informative_str(elem, indent_level=1)
"""
for i, key in enumerate(ords):
print 'orderings',i
print '\tkey',i,'is',key
v = ords[key]
print '\t\tvalue',i,'is a',type(v)
for j, elem in enumerate(v):
print '\t\t\telem',j,':',elem
order = graph.io_toposort(fg.inputs, fg.outputs, ords)
if verbose:
print 'FunctionGraph.toposort returning order:'
for i, elem in enumerate(order):
print '\t%d:'%i,elem
return order
def orderings(self):
......@@ -571,14 +608,19 @@ class FunctionGraph(utils.object2):
take care of computing dependencies by itself.
"""
ords = {}
ords = OrderedDict()
assert isinstance(self._features, list)
for feature in self._features:
if hasattr(feature, 'orderings'):
for node, prereqs in feature.orderings(self).items():
orderings = feature.orderings(self)
if not isinstance(orderings, OrderedDict):
raise TypeError("Non-deterministic return value from " \
+str(feature.orderings) \
+". Nondeterministic object is "+str(orderings))
ords.setdefault(node, []).extend(prereqs)
# eliminate duplicate prereqs
for (node,prereqs) in ords.items():
ords[node] = list(set(prereqs))
ords[node] = list(OrderedSet(prereqs))
return ords
def nclients(self, r):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论