提交 b412823a authored 作者: nouiz's avatar nouiz

Merge pull request #1107 from goodfeli/determinism_2

Ready to merge: Determinism 2
...@@ -12,6 +12,9 @@ from python25 import all ...@@ -12,6 +12,9 @@ from python25 import all
from theano import config from theano import config
import warnings import warnings
NullType = None NullType = None
import theano
from python25 import OrderedDict
from theano.misc.ordered_set import OrderedSet
class InconsistencyError(Exception): class InconsistencyError(Exception):
""" """
...@@ -557,6 +560,7 @@ class FunctionGraph(utils.object2): ...@@ -557,6 +560,7 @@ class FunctionGraph(utils.object2):
ords = self.orderings() ords = self.orderings()
order = graph.io_toposort(fg.inputs, fg.outputs, ords) order = graph.io_toposort(fg.inputs, fg.outputs, ords)
return order return order
def orderings(self): def orderings(self):
...@@ -571,14 +575,24 @@ class FunctionGraph(utils.object2): ...@@ -571,14 +575,24 @@ class FunctionGraph(utils.object2):
take care of computing dependencies by itself. take care of computing dependencies by itself.
""" """
ords = {} ords = OrderedDict()
assert isinstance(self._features, list)
for feature in self._features: for feature in self._features:
if hasattr(feature, 'orderings'): if hasattr(feature, 'orderings'):
for node, prereqs in feature.orderings(self).items(): orderings = feature.orderings(self)
if not isinstance(orderings, OrderedDict):
raise TypeError("Non-deterministic return value from " \
+str(feature.orderings) \
+". Nondeterministic object is "+str(orderings))
for node, prereqs in orderings.items():
if not isinstance(prereqs, (list, OrderedSet)):
raise TypeError("prereqs must be a type with a "
"deterministic iteration order, or toposort "
" will be non-deterministic.")
ords.setdefault(node, []).extend(prereqs) ords.setdefault(node, []).extend(prereqs)
# eliminate duplicate prereqs # eliminate duplicate prereqs
for (node,prereqs) in ords.items(): for (node,prereqs) in ords.items():
ords[node] = list(set(prereqs)) ords[node] = list(OrderedSet(prereqs))
return ords return ords
def nclients(self, r): def nclients(self, r):
......
...@@ -15,6 +15,7 @@ import theano ...@@ -15,6 +15,7 @@ import theano
import warnings import warnings
from theano.gof import utils from theano.gof import utils
from theano.gof.python25 import any, deque from theano.gof.python25 import any, deque
from theano.misc.ordered_set import OrderedSet
# Lazy imports to avoid circular dependencies. # Lazy imports to avoid circular dependencies.
is_same_graph_with_merge = None is_same_graph_with_merge = None
...@@ -736,6 +737,9 @@ def general_toposort(r_out, deps, debug_print=False): ...@@ -736,6 +737,9 @@ def general_toposort(r_out, deps, debug_print=False):
if io not in deps_cache: if io not in deps_cache:
d = deps(io) d = deps(io)
if d: if d:
if not isinstance(d, (list, OrderedSet)):
raise TypeError("Non-deterministic collections here make"
" toposort non-deterministic.")
deps_cache[io] = list(d) deps_cache[io] = list(d)
else: else:
deps_cache[io] = d deps_cache[io] = d
......
...@@ -518,7 +518,8 @@ def add_clear_storage(f, computed, storage_map): ...@@ -518,7 +518,8 @@ def add_clear_storage(f, computed, storage_map):
class WrapLinker(Linker): class WrapLinker(Linker):
""" WRITEME """
WRITEME
This class makes it easier to run several L{LocalLinker}s in parallel, and This class makes it easier to run several L{LocalLinker}s in parallel, and
offers some control over how each thunk is run. offers some control over how each thunk is run.
...@@ -646,7 +647,8 @@ class WrapLinker(Linker): ...@@ -646,7 +647,8 @@ class WrapLinker(Linker):
return f, inputs0, outputs0 return f, inputs0, outputs0
def WrapLinkerMany(linkers, wrappers): def WrapLinkerMany(linkers, wrappers):
""" Variant on WrapLinker that runs a series of wrapper functions instead of """
Variant on WrapLinker that runs a series of wrapper functions instead of
just one. just one.
""" """
def wrapper(*args): def wrapper(*args):
......
...@@ -2,10 +2,12 @@ import sys ...@@ -2,10 +2,12 @@ import sys
import time import time
from theano.gof.python25 import partial from theano.gof.python25 import partial
from theano.gof.python25 import OrderedDict
import graph import graph
class AlreadyThere(Exception): class AlreadyThere(Exception):
"""Raised by a Feature's on_attach callback method if the FunctionGraph """Raised by a Feature's on_attach callback method if the FunctionGraph
attempting to attach the feature already has a functionally identical attempting to attach the feature already has a functionally identical
...@@ -89,7 +91,7 @@ class Feature(object): ...@@ -89,7 +91,7 @@ class Feature(object):
If you raise an exception in this function, the state of the graph If you raise an exception in this function, the state of the graph
might be broken for all intents and purposes. might be broken for all intents and purposes.
""" """
return {} return OrderedDict()
class Bookkeeper(Feature): class Bookkeeper(Feature):
......
""" """
VMs that run Theano graph computations. VMs that run Theano graph computations.
A VM is not actually different from a Linker, we just decided A VM is not actually different from a Linker, we just decided
VM was a better name at some point VM was a better name at some point.
""" """
import link import link
import logging import logging
......
MutableSet = None
try:
from collections import MutableSet
except ImportError:
# Python 2.4
pass
from theano.gof.python25 import OrderedDict
import types
def check_deterministic(iterable):
# Most places where OrderedSet is used, theano interprets any exception
# whatsoever as a problem that an optimization introduced into the graph.
# If I raise a TypeError when the DestoryHandler tries to do something
# non-deterministic, it will just result in optimizations getting ignored.
# So I must use an assert here. In the long term we should fix the rest of
# theano to use exceptions correctly, so that this can be a TypeError.
if iterable is not None:
assert isinstance(iterable, (list, tuple, OrderedSet, types.GeneratorType))
if MutableSet is not None:
# From http://code.activestate.com/recipes/576694/
# Copyright (C) 2009 Raymond Hettinger
# Permission is hereby granted, free of charge, to any person obtaining a
# copy of this software and associated documentation files (the
# "Software"), to deal in the Software without restriction, including
# without limitation the rights to use, copy, modify, merge, publish,
# distribute, sublicense, and/or sell copies of the Software, and to permit
# persons to whom the Software is furnished to do so, subject to the
# following conditions:
# The above copyright notice and this permission notice shall be included
# in all copies or substantial portions of the Software.
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
# OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
# IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
# CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
# TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
# SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
KEY, PREV, NEXT = range(3)
class OrderedSet(MutableSet):
# Added by IG-- pre-existing theano code expected sets
# to have this method
def update(self, iterable):
check_deterministic(iterable)
self |= iterable
def __init__(self, iterable=None):
# Checks added by IG
check_deterministic(iterable)
self.end = end = []
end += [None, end, end] # sentinel node for doubly linked list
self.map = {} # key --> [key, prev, next]
if iterable is not None:
self |= iterable
def __len__(self):
return len(self.map)
def __contains__(self, key):
return key in self.map
def add(self, key):
if key not in self.map:
end = self.end
curr = end[PREV]
curr[NEXT] = end[PREV] = self.map[key] = [key, curr, end]
def discard(self, key):
if key in self.map:
key, prev, next = self.map.pop(key)
prev[NEXT] = next
next[PREV] = prev
def __iter__(self):
end = self.end
curr = end[NEXT]
while curr is not end:
yield curr[KEY]
curr = curr[NEXT]
def __reversed__(self):
end = self.end
curr = end[PREV]
while curr is not end:
yield curr[KEY]
curr = curr[PREV]
def pop(self, last=True):
if not self:
raise KeyError('set is empty')
if last:
key = next(reversed(self))
else:
key = next(iter(self))
self.discard(key)
return key
def __repr__(self):
if not self:
return '%s()' % (self.__class__.__name__,)
return '%s(%r)' % (self.__class__.__name__, list(self))
def __eq__(self, other):
if isinstance(other, OrderedSet):
return len(self) == len(other) and list(self) == list(other)
return set(self) == set(other)
def __del__(self):
self.clear() # remove circular references
else:
# Python 2.4
class OrderedSet(object):
"""
An implementation of OrderedSet based on the keys of
an OrderedDict.
"""
def __init__(self, iterable=None):
self.data = OrderedDict()
if iterable is not None:
self.update(iterable)
def update(self, container):
check_deterministic(container)
for elem in container:
self.add(elem)
def add(self, key):
self.data[key] = None
def __len__(self):
return len(self.data)
def __contains__(self, key):
return key in self.data
def discard(self, key):
if key in self.data:
del self.data[key]
def remove(self, key):
if key in self.data:
del self.data[key]
else:
raise KeyError(key)
def __iter__(self):
return self.data.keys().__iter__()
def __reversed__(self):
return self.data.__reversed__()
def pop(self, last=True):
raise NotImplementedError()
def __eq__(self, other):
return type(self) == type(other) and \
self.data == other.data
def __del__(self):
# Remove circular references
self.data.clear()
if __name__ == '__main__':
print(OrderedSet('abracadaba'))
print(OrderedSet('simsalabim'))
...@@ -12,6 +12,7 @@ import sys ...@@ -12,6 +12,7 @@ import sys
hashlib = None hashlib = None
import numpy import numpy
np = numpy
try: try:
import pydot as pd import pydot as pd
...@@ -1111,14 +1112,17 @@ def var_descriptor(obj, _prev_obs=None, _tag_generator=None): ...@@ -1111,14 +1112,17 @@ def var_descriptor(obj, _prev_obs=None, _tag_generator=None):
name = '<ndarray:' name = '<ndarray:'
name += 'strides=['+','.join(str(stride) for stride in obj.strides)+']' name += 'strides=['+','.join(str(stride) for stride in obj.strides)+']'
name += ',digest='+hashlib.md5(obj).hexdigest()+'>' name += ',digest='+hashlib.md5(obj).hexdigest()+'>'
elif hasattr(obj, 'name') and obj.name is not None:
name = obj.name
elif hasattr(obj, 'owner') and obj.owner is not None: elif hasattr(obj, 'owner') and obj.owner is not None:
name = str(obj.owner.op) + '(' name = str(obj.owner.op) + '('
name += ','.join(var_descriptor(ipt, name += ','.join(var_descriptor(ipt,
_prev_obs=_prev_obs, _tag_generator=_tag_generator) for ipt _prev_obs=_prev_obs, _tag_generator=_tag_generator) for ipt
in obj.owner.inputs) in obj.owner.inputs)
name += ')' name += ')'
elif hasattr(obj, 'name') and obj.name is not None:
# Only print the name if there is no owner.
# This way adding a name to an intermediate node can't make
# a deeper graph get the same descriptor as a shallower one
name = obj.name
else: else:
name = str(obj) name = str(obj)
if ' at 0x' in name: if ' at 0x' in name:
...@@ -1144,6 +1148,17 @@ def position_independent_str(obj): ...@@ -1144,6 +1148,17 @@ def position_independent_str(obj):
return rval return rval
def hex_digest(x):
"""
Returns a short, mostly hexadecimal hash of a numpy ndarray
"""
assert isinstance(x, np.ndarray)
rval = hashlib.md5(x.tostring()).hexdigest()
# hex digest must be annotated with strides to avoid collisions
# because the buffer interface only exposes the raw data, not
# any info about the semantics of how that data should be arranged
# into a tensor
rval = rval + '|strides=[' + ','.join(str(stride) for stride in x.strides) + ']'
rval = rval + '|shape=[' + ','.join(str(s) for s in x.shape) + ']'
return rval
__authors__ = "Ian Goodfellow"
__copyright__ = "Copyright 2010-2012, Universite de Montreal"
__credits__ = ["Ian Goodfellow"]
__license__ = "3-clause BSD"
__maintainer__ = "Ian Goodfellow"
__email__ = "goodfeli@iro"
from datetime import datetime
def disturb_mem():
# Allocate a time-dependent amount of objects to increase
# chances of subsequently objects' ids changing from run
# to run. This is useful for exposing issues that cause
# non-deterministic behavior due to dependence on memory
# addresses, like iterating over a dict or a set.
global l
now = datetime.now()
ms = now.microsecond
ms = int(ms)
n = ms % 1000
m = ms / 1000
l = [[0]*m for i in xrange(n)]
__authors__ = "Ian Goodfellow"
__copyright__ = "Copyright 2010-2012, Universite de Montreal"
__credits__ = ["Ian Goodfellow"]
__license__ = "3-clause BSD"
__maintainer__ = "Ian Goodfellow"
__email__ = "goodfeli@iro"
from theano.compile import Mode
import theano
from theano.printing import hex_digest
class MismatchError(Exception):
"""
Raised by Record.handle_line when the
current execution doesn't match the replay
of a record.
"""
class Record(object):
def __init__(self, file_object=None, file_path=None, replay=False):
assert file_object is not None or file_path is not None
if replay and file_object is None:
self.f = open(file_path, 'r')
elif (not replay) and file_object is None:
self.f = open(file_path, 'w')
else:
self.f = file_object
self.__dict__.update(locals())
def handle_line(self, line):
assert line.endswith('\n')
assert line[:-2].find('\n') == -1
if self.replay:
old_line = self.f.readline()
if old_line != line:
msg = 'Replay detected mismatch.\n'
msg += ' I wanted to write:\n'
if len(line) > 100:
msg += line[0:100]+'...'
else:
msg += line
msg += '\nwhen previous job wrote:\n'
if len(old_line) > 100:
msg += old_line[0:100]+'...'
else:
msg += old_line
raise MismatchError(msg)
else:
self.f.write(line)
class RecordMode(Mode):
"""
Records all computations done with a function in a file at output_path
Prints the index of each apply node and md5 digests of the numpy ndarrays
it receives as inputs and produces as outputs.
"""
def set_record(self, record):
self.record = record
self.known_fgraphs = set([])
def __init__(self, record = None, **kwargs):
"""
Takes either a Record object or the keyword arguments to make one.
"""
if record is None:
record = Record(**kwargs)
else:
assert len(kwargs.keys()) == 0
self.set_record(record)
def handle_line(line, i, node, fn):
try:
self.record.handle_line(line)
except MismatchError, e:
print 'Got this MismatchError:'
print e
print 'while processing node i='+str(i)+':'
print 'str(node):',str(node)
print 'Symbolic inputs: '
for elem in node.inputs:
print theano.printing.min_informative_str(elem)
print 'str(output) of outputs: '
for elem in fn.outputs:
assert isinstance(elem, list)
elem, = elem
print str(elem)
print 'function name: '+node.fgraph.name
raise MismatchError("Non-determinism detected by WrapLinker")
def callback(i, node, fn):
fgraph = node.fgraph
if fgraph.name is None:
raise ValueError("Un-named functions are not allowed with RecordMode, "
"because they make it impossible to tell if the same function is "
"running during the playback.")
if fgraph not in self.known_fgraphs:
assert not any([elem.name == fgraph.name for elem in self.known_fgraphs])
self.known_fgraphs.add(fgraph)
num_app = len(fgraph.apply_nodes)
line = 'Function '+fgraph.name+' has '+str(num_app)+' apply nodes.\n'
handle_line(line, i, node, fn)
line = 'Function name: '+fgraph.name + '\n'
handle_line(line, i, node, fn)
line = 'Node '+str(i)+':'+str(node)+'\n'
handle_line(line, i, node, fn)
assert all([isinstance(x, list) and len(x) == 1 for x in fn.inputs])
def digest(x):
x = x[0]
return hex_digest(x)
inputs_digest = ' '.join([digest(x) for x in fn.inputs])
line = 'Inputs: ' + inputs_digest + '\n'
handle_line(line, i, node, fn)
fn()
outputs_digest = ' '.join([digest(x) for x in fn.outputs])
line = 'Outputs: ' + outputs_digest + '\n'
handle_line(line, i, node, fn)
#linker = theano.gof.OpWiseCLinker()
linker = theano.gof.vm.VM_Linker(use_cloop=True)
wrap_linker = theano.gof.WrapLinkerMany([linker], [callback])
super(RecordMode, self).__init__(wrap_linker, optimizer='fast_run')
from theano.tests.record import RecordMode
from theano.tests.record import Record
from theano.gof.python25 import OrderedDict
from theano.tests import disturb_mem
import numpy as np
import theano
from theano.printing import var_descriptor
from cStringIO import StringIO
from theano import config
from theano import shared
def sharedX(x, name=None):
x = np.cast[config.floatX](x)
return shared(x, name)
def test_determinism_1():
# Tests that repeatedly running a script that compiles and
# runs a function does exactly the same thing every time it
# is run, even when the memory addresses of the objects involved
# change.
# This specific script is capable of catching a bug where
# FunctionGraph.toposort was non-deterministic.
def run(replay, log = None):
if not replay:
log = StringIO()
else:
log = StringIO(log)
record = Record(replay=replay, file_object=log)
disturb_mem.disturb_mem()
mode = RecordMode(record=record)
b = sharedX(np.zeros((2,)), name='b')
channels = OrderedDict()
disturb_mem.disturb_mem()
v_max = b.max(axis=0)
v_min = b.min(axis=0)
v_range = v_max - v_min
updates = []
for i, val in enumerate([
v_max.max(),
v_max.min(),
v_range.max(),
]):
disturb_mem.disturb_mem()
s = sharedX(0., name='s_'+str(i))
updates.append((s, val))
for var in theano.gof.graph.ancestors(update for var, update in updates):
if var.name is not None and var.name is not 'b':
if var.name[0] != 's' or len(var.name) != 2:
var.name = None
for key in channels:
updates.append((s, channels[key]))
f = theano.function([], mode=mode, updates=updates, on_unused_input='ignore', name='f')
for output in f.maker.fgraph.outputs:
mode.record.handle_line(var_descriptor(output)+'\n')
disturb_mem.disturb_mem()
f()
mode.record.f.flush()
if not replay:
return log.getvalue()
log = run(0)
# Do several trials, since failure doesn't always occur
# (Sometimes you sample the same outcome twice in a row)
for i in xrange(10):
run(1, log)
if __name__ == '__main__':
test_determinism_1()
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论