提交 b412823a authored 作者: nouiz's avatar nouiz

Merge pull request #1107 from goodfeli/determinism_2

Ready to merge: Determinism 2
......@@ -12,6 +12,9 @@ from python25 import all
from theano import config
import warnings
NullType = None
import theano
from python25 import OrderedDict
from theano.misc.ordered_set import OrderedSet
class InconsistencyError(Exception):
"""
......@@ -557,6 +560,7 @@ class FunctionGraph(utils.object2):
ords = self.orderings()
order = graph.io_toposort(fg.inputs, fg.outputs, ords)
return order
def orderings(self):
......@@ -571,14 +575,24 @@ class FunctionGraph(utils.object2):
take care of computing dependencies by itself.
"""
ords = {}
ords = OrderedDict()
assert isinstance(self._features, list)
for feature in self._features:
if hasattr(feature, 'orderings'):
for node, prereqs in feature.orderings(self).items():
orderings = feature.orderings(self)
if not isinstance(orderings, OrderedDict):
raise TypeError("Non-deterministic return value from " \
+str(feature.orderings) \
+". Nondeterministic object is "+str(orderings))
for node, prereqs in orderings.items():
if not isinstance(prereqs, (list, OrderedSet)):
raise TypeError("prereqs must be a type with a "
"deterministic iteration order, or toposort "
" will be non-deterministic.")
ords.setdefault(node, []).extend(prereqs)
# eliminate duplicate prereqs
for (node,prereqs) in ords.items():
ords[node] = list(set(prereqs))
ords[node] = list(OrderedSet(prereqs))
return ords
def nclients(self, r):
......
......@@ -15,6 +15,7 @@ import theano
import warnings
from theano.gof import utils
from theano.gof.python25 import any, deque
from theano.misc.ordered_set import OrderedSet
# Lazy imports to avoid circular dependencies.
is_same_graph_with_merge = None
......@@ -736,6 +737,9 @@ def general_toposort(r_out, deps, debug_print=False):
if io not in deps_cache:
d = deps(io)
if d:
if not isinstance(d, (list, OrderedSet)):
raise TypeError("Non-deterministic collections here make"
" toposort non-deterministic.")
deps_cache[io] = list(d)
else:
deps_cache[io] = d
......
......@@ -518,7 +518,8 @@ def add_clear_storage(f, computed, storage_map):
class WrapLinker(Linker):
""" WRITEME
"""
WRITEME
This class makes it easier to run several L{LocalLinker}s in parallel, and
offers some control over how each thunk is run.
......@@ -646,7 +647,8 @@ class WrapLinker(Linker):
return f, inputs0, outputs0
def WrapLinkerMany(linkers, wrappers):
""" Variant on WrapLinker that runs a series of wrapper functions instead of
"""
Variant on WrapLinker that runs a series of wrapper functions instead of
just one.
"""
def wrapper(*args):
......
......@@ -2,10 +2,12 @@ import sys
import time
from theano.gof.python25 import partial
from theano.gof.python25 import OrderedDict
import graph
class AlreadyThere(Exception):
"""Raised by a Feature's on_attach callback method if the FunctionGraph
attempting to attach the feature already has a functionally identical
......@@ -89,7 +91,7 @@ class Feature(object):
If you raise an exception in this function, the state of the graph
might be broken for all intents and purposes.
"""
return {}
return OrderedDict()
class Bookkeeper(Feature):
......
"""
VMs that run Theano graph computations.
A VM is not actually different from a Linker, we just decided
VM was a better name at some point
VM was a better name at some point.
"""
import link
import logging
......
MutableSet = None
try:
from collections import MutableSet
except ImportError:
# Python 2.4
pass
from theano.gof.python25 import OrderedDict
import types
def check_deterministic(iterable):
# Most places where OrderedSet is used, theano interprets any exception
# whatsoever as a problem that an optimization introduced into the graph.
# If I raise a TypeError when the DestoryHandler tries to do something
# non-deterministic, it will just result in optimizations getting ignored.
# So I must use an assert here. In the long term we should fix the rest of
# theano to use exceptions correctly, so that this can be a TypeError.
if iterable is not None:
assert isinstance(iterable, (list, tuple, OrderedSet, types.GeneratorType))
if MutableSet is not None:
# From http://code.activestate.com/recipes/576694/
# Copyright (C) 2009 Raymond Hettinger
# Permission is hereby granted, free of charge, to any person obtaining a
# copy of this software and associated documentation files (the
# "Software"), to deal in the Software without restriction, including
# without limitation the rights to use, copy, modify, merge, publish,
# distribute, sublicense, and/or sell copies of the Software, and to permit
# persons to whom the Software is furnished to do so, subject to the
# following conditions:
# The above copyright notice and this permission notice shall be included
# in all copies or substantial portions of the Software.
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS
# OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.
# IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY
# CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT,
# TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE
# SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
KEY, PREV, NEXT = range(3)
class OrderedSet(MutableSet):
# Added by IG-- pre-existing theano code expected sets
# to have this method
def update(self, iterable):
check_deterministic(iterable)
self |= iterable
def __init__(self, iterable=None):
# Checks added by IG
check_deterministic(iterable)
self.end = end = []
end += [None, end, end] # sentinel node for doubly linked list
self.map = {} # key --> [key, prev, next]
if iterable is not None:
self |= iterable
def __len__(self):
return len(self.map)
def __contains__(self, key):
return key in self.map
def add(self, key):
if key not in self.map:
end = self.end
curr = end[PREV]
curr[NEXT] = end[PREV] = self.map[key] = [key, curr, end]
def discard(self, key):
if key in self.map:
key, prev, next = self.map.pop(key)
prev[NEXT] = next
next[PREV] = prev
def __iter__(self):
end = self.end
curr = end[NEXT]
while curr is not end:
yield curr[KEY]
curr = curr[NEXT]
def __reversed__(self):
end = self.end
curr = end[PREV]
while curr is not end:
yield curr[KEY]
curr = curr[PREV]
def pop(self, last=True):
if not self:
raise KeyError('set is empty')
if last:
key = next(reversed(self))
else:
key = next(iter(self))
self.discard(key)
return key
def __repr__(self):
if not self:
return '%s()' % (self.__class__.__name__,)
return '%s(%r)' % (self.__class__.__name__, list(self))
def __eq__(self, other):
if isinstance(other, OrderedSet):
return len(self) == len(other) and list(self) == list(other)
return set(self) == set(other)
def __del__(self):
self.clear() # remove circular references
else:
# Python 2.4
class OrderedSet(object):
"""
An implementation of OrderedSet based on the keys of
an OrderedDict.
"""
def __init__(self, iterable=None):
self.data = OrderedDict()
if iterable is not None:
self.update(iterable)
def update(self, container):
check_deterministic(container)
for elem in container:
self.add(elem)
def add(self, key):
self.data[key] = None
def __len__(self):
return len(self.data)
def __contains__(self, key):
return key in self.data
def discard(self, key):
if key in self.data:
del self.data[key]
def remove(self, key):
if key in self.data:
del self.data[key]
else:
raise KeyError(key)
def __iter__(self):
return self.data.keys().__iter__()
def __reversed__(self):
return self.data.__reversed__()
def pop(self, last=True):
raise NotImplementedError()
def __eq__(self, other):
return type(self) == type(other) and \
self.data == other.data
def __del__(self):
# Remove circular references
self.data.clear()
if __name__ == '__main__':
print(OrderedSet('abracadaba'))
print(OrderedSet('simsalabim'))
......@@ -12,6 +12,7 @@ import sys
hashlib = None
import numpy
np = numpy
try:
import pydot as pd
......@@ -1111,14 +1112,17 @@ def var_descriptor(obj, _prev_obs=None, _tag_generator=None):
name = '<ndarray:'
name += 'strides=['+','.join(str(stride) for stride in obj.strides)+']'
name += ',digest='+hashlib.md5(obj).hexdigest()+'>'
elif hasattr(obj, 'name') and obj.name is not None:
name = obj.name
elif hasattr(obj, 'owner') and obj.owner is not None:
name = str(obj.owner.op) + '('
name += ','.join(var_descriptor(ipt,
_prev_obs=_prev_obs, _tag_generator=_tag_generator) for ipt
in obj.owner.inputs)
name += ')'
elif hasattr(obj, 'name') and obj.name is not None:
# Only print the name if there is no owner.
# This way adding a name to an intermediate node can't make
# a deeper graph get the same descriptor as a shallower one
name = obj.name
else:
name = str(obj)
if ' at 0x' in name:
......@@ -1144,6 +1148,17 @@ def position_independent_str(obj):
return rval
def hex_digest(x):
"""
Returns a short, mostly hexadecimal hash of a numpy ndarray
"""
assert isinstance(x, np.ndarray)
rval = hashlib.md5(x.tostring()).hexdigest()
# hex digest must be annotated with strides to avoid collisions
# because the buffer interface only exposes the raw data, not
# any info about the semantics of how that data should be arranged
# into a tensor
rval = rval + '|strides=[' + ','.join(str(stride) for stride in x.strides) + ']'
rval = rval + '|shape=[' + ','.join(str(s) for s in x.shape) + ']'
return rval
__authors__ = "Ian Goodfellow"
__copyright__ = "Copyright 2010-2012, Universite de Montreal"
__credits__ = ["Ian Goodfellow"]
__license__ = "3-clause BSD"
__maintainer__ = "Ian Goodfellow"
__email__ = "goodfeli@iro"
from datetime import datetime
def disturb_mem():
# Allocate a time-dependent amount of objects to increase
# chances of subsequently objects' ids changing from run
# to run. This is useful for exposing issues that cause
# non-deterministic behavior due to dependence on memory
# addresses, like iterating over a dict or a set.
global l
now = datetime.now()
ms = now.microsecond
ms = int(ms)
n = ms % 1000
m = ms / 1000
l = [[0]*m for i in xrange(n)]
__authors__ = "Ian Goodfellow"
__copyright__ = "Copyright 2010-2012, Universite de Montreal"
__credits__ = ["Ian Goodfellow"]
__license__ = "3-clause BSD"
__maintainer__ = "Ian Goodfellow"
__email__ = "goodfeli@iro"
from theano.compile import Mode
import theano
from theano.printing import hex_digest
class MismatchError(Exception):
"""
Raised by Record.handle_line when the
current execution doesn't match the replay
of a record.
"""
class Record(object):
def __init__(self, file_object=None, file_path=None, replay=False):
assert file_object is not None or file_path is not None
if replay and file_object is None:
self.f = open(file_path, 'r')
elif (not replay) and file_object is None:
self.f = open(file_path, 'w')
else:
self.f = file_object
self.__dict__.update(locals())
def handle_line(self, line):
assert line.endswith('\n')
assert line[:-2].find('\n') == -1
if self.replay:
old_line = self.f.readline()
if old_line != line:
msg = 'Replay detected mismatch.\n'
msg += ' I wanted to write:\n'
if len(line) > 100:
msg += line[0:100]+'...'
else:
msg += line
msg += '\nwhen previous job wrote:\n'
if len(old_line) > 100:
msg += old_line[0:100]+'...'
else:
msg += old_line
raise MismatchError(msg)
else:
self.f.write(line)
class RecordMode(Mode):
"""
Records all computations done with a function in a file at output_path
Prints the index of each apply node and md5 digests of the numpy ndarrays
it receives as inputs and produces as outputs.
"""
def set_record(self, record):
self.record = record
self.known_fgraphs = set([])
def __init__(self, record = None, **kwargs):
"""
Takes either a Record object or the keyword arguments to make one.
"""
if record is None:
record = Record(**kwargs)
else:
assert len(kwargs.keys()) == 0
self.set_record(record)
def handle_line(line, i, node, fn):
try:
self.record.handle_line(line)
except MismatchError, e:
print 'Got this MismatchError:'
print e
print 'while processing node i='+str(i)+':'
print 'str(node):',str(node)
print 'Symbolic inputs: '
for elem in node.inputs:
print theano.printing.min_informative_str(elem)
print 'str(output) of outputs: '
for elem in fn.outputs:
assert isinstance(elem, list)
elem, = elem
print str(elem)
print 'function name: '+node.fgraph.name
raise MismatchError("Non-determinism detected by WrapLinker")
def callback(i, node, fn):
fgraph = node.fgraph
if fgraph.name is None:
raise ValueError("Un-named functions are not allowed with RecordMode, "
"because they make it impossible to tell if the same function is "
"running during the playback.")
if fgraph not in self.known_fgraphs:
assert not any([elem.name == fgraph.name for elem in self.known_fgraphs])
self.known_fgraphs.add(fgraph)
num_app = len(fgraph.apply_nodes)
line = 'Function '+fgraph.name+' has '+str(num_app)+' apply nodes.\n'
handle_line(line, i, node, fn)
line = 'Function name: '+fgraph.name + '\n'
handle_line(line, i, node, fn)
line = 'Node '+str(i)+':'+str(node)+'\n'
handle_line(line, i, node, fn)
assert all([isinstance(x, list) and len(x) == 1 for x in fn.inputs])
def digest(x):
x = x[0]
return hex_digest(x)
inputs_digest = ' '.join([digest(x) for x in fn.inputs])
line = 'Inputs: ' + inputs_digest + '\n'
handle_line(line, i, node, fn)
fn()
outputs_digest = ' '.join([digest(x) for x in fn.outputs])
line = 'Outputs: ' + outputs_digest + '\n'
handle_line(line, i, node, fn)
#linker = theano.gof.OpWiseCLinker()
linker = theano.gof.vm.VM_Linker(use_cloop=True)
wrap_linker = theano.gof.WrapLinkerMany([linker], [callback])
super(RecordMode, self).__init__(wrap_linker, optimizer='fast_run')
from theano.tests.record import RecordMode
from theano.tests.record import Record
from theano.gof.python25 import OrderedDict
from theano.tests import disturb_mem
import numpy as np
import theano
from theano.printing import var_descriptor
from cStringIO import StringIO
from theano import config
from theano import shared
def sharedX(x, name=None):
x = np.cast[config.floatX](x)
return shared(x, name)
def test_determinism_1():
# Tests that repeatedly running a script that compiles and
# runs a function does exactly the same thing every time it
# is run, even when the memory addresses of the objects involved
# change.
# This specific script is capable of catching a bug where
# FunctionGraph.toposort was non-deterministic.
def run(replay, log = None):
if not replay:
log = StringIO()
else:
log = StringIO(log)
record = Record(replay=replay, file_object=log)
disturb_mem.disturb_mem()
mode = RecordMode(record=record)
b = sharedX(np.zeros((2,)), name='b')
channels = OrderedDict()
disturb_mem.disturb_mem()
v_max = b.max(axis=0)
v_min = b.min(axis=0)
v_range = v_max - v_min
updates = []
for i, val in enumerate([
v_max.max(),
v_max.min(),
v_range.max(),
]):
disturb_mem.disturb_mem()
s = sharedX(0., name='s_'+str(i))
updates.append((s, val))
for var in theano.gof.graph.ancestors(update for var, update in updates):
if var.name is not None and var.name is not 'b':
if var.name[0] != 's' or len(var.name) != 2:
var.name = None
for key in channels:
updates.append((s, channels[key]))
f = theano.function([], mode=mode, updates=updates, on_unused_input='ignore', name='f')
for output in f.maker.fgraph.outputs:
mode.record.handle_line(var_descriptor(output)+'\n')
disturb_mem.disturb_mem()
f()
mode.record.f.flush()
if not replay:
return log.getvalue()
log = run(0)
# Do several trials, since failure doesn't always occur
# (Sometimes you sample the same outcome twice in a row)
for i in xrange(10):
run(1, log)
if __name__ == '__main__':
test_determinism_1()
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论