提交 00b36a11 authored 作者: Iulian Vlad Serban's avatar Iulian Vlad Serban

Added thorough documentation to Record and RecordMode classes.

上级 f7687b37
...@@ -8,6 +8,7 @@ from theano.compile import Mode ...@@ -8,6 +8,7 @@ from theano.compile import Mode
import theano import theano
from theano.printing import hex_digest from theano.printing import hex_digest
class MismatchError(Exception): class MismatchError(Exception):
""" """
Raised by Record.handle_line when the Raised by Record.handle_line when the
...@@ -15,8 +16,52 @@ class MismatchError(Exception): ...@@ -15,8 +16,52 @@ class MismatchError(Exception):
of a record. of a record.
""" """
class Record(object): class Record(object):
"""
Records a sequence of strings (from a string buffer). These can then be
compared to another sequence of strings, and if the two sequences don't
match a mismatch exception is raised.
Example:
# Create a Record object and store 'hello world' inside it
output = cStringIO.StringIO()
recorder = Record(file_object=output, replay=False)
recorder.handle_line('hello world \n')
# Store the previous output
output_value = output.getvalue()
output = cStringIO.StringIO(output_value)
# Create another Record object, now in playback mode, and set
# it to the previous sequence of strings
playback_checker = Record(file_object=output, replay=True)
# Check if it matches the previous one
playback_checker.handle_line('hello world \n')
# Now check if it the next item matches something else. This will
# throw an exception because there is no next item
playback_checker.handle_line('hello new world \n')
"""
def __init__(self, file_object=None, file_path=None, replay=False): def __init__(self, file_object=None, file_path=None, replay=False):
"""
Initializes Record object to use file on disc and whether it is in
replay mode or not.
Parameters
----------
file_object : StringIO
The input string buffer.
file_path : string, optional
File to save Record to.
replay : bool, optional
Determines whether or not the object is in playback mode. If not
in playback mode, the content of record will be written to the
file. If in playback mode, the content of file is loaded into the
record.
"""
assert file_object is not None or file_path is not None assert file_object is not None or file_path is not None
...@@ -30,6 +75,18 @@ class Record(object): ...@@ -30,6 +75,18 @@ class Record(object):
self.__dict__.update(locals()) self.__dict__.update(locals())
def handle_line(self, line): def handle_line(self, line):
"""
If not in playback mode, it records a new string. If in playback mode,
it compares the current string to the next element in the sequence.
If these are identical the element is removed and otherwise a mismatch
exception is raised.
Parameters
----------
line : string
The string to record.
"""
assert line.endswith('\n') assert line.endswith('\n')
assert line[:-2].find('\n') == -1 assert line[:-2].find('\n') == -1
if self.replay: if self.replay:
...@@ -50,20 +107,62 @@ class Record(object): ...@@ -50,20 +107,62 @@ class Record(object):
else: else:
self.f.write(line) self.f.write(line)
class RecordMode(Mode): class RecordMode(Mode):
""" """
Records all computations done with a function in a file at output_path Records all computations done with a function in a file at output_path.
Prints the index of each apply node and md5 digests of the numpy ndarrays Writes into the file the index of each apply node and md5 digests of the
it receives as inputs and produces as outputs. numpy ndarrays it receives as inputs and produces as output.
Example:
# We use RecordMode to test that the computation of a function is
identical. Create a Record object and use it to initialize a
RecordMode object.
output = cStringIO.StringIO()
record = Record(file_object=output, replay=False)
record_mode = RecordMode(record)
# Then call the function you wish to test, which uses Apply nodes,
# with record_mode as first parameter to record all the computations
# to file.
...
# Create a new RecordMode object and initialize it with the previous
# record.
output = cStringIO.StringIO(output.getvalue())
playback = Record(file_object=output, replay=True)
playback_mode = RecordMode(playback)
# Call the function to test again with record_mode as first parameter.
# An exception will be thrown if the recorded computations are not
# identical between the two runs.
...
""" """
def set_record(self, record): def set_record(self, record):
"""
Configure object to use an existing Record object.
Parameters
----------
record : Record
The Record object to use.
"""
self.record = record self.record = record
self.known_fgraphs = set([]) self.known_fgraphs = set([])
def __init__(self, record = None, **kwargs): def __init__(self, record = None, **kwargs):
""" """
Takes either a Record object or the keyword arguments to make one. Takes either a Record object or the keyword arguments to make one.
Parameters
----------
record : Record
The existing Record object to use.
kwargs : pointer?
Keyword arguments to construct new object.
""" """
if record is None: if record is None:
...@@ -73,15 +172,26 @@ class RecordMode(Mode): ...@@ -73,15 +172,26 @@ class RecordMode(Mode):
self.set_record(record) self.set_record(record)
def handle_line(line, i, node, fn): def handle_line(line, i, node, fn):
"""
Records new node computation.
Parameters
----------
line : string
Name of function?
i : integer?
unique id of node?
node : ???
fn : ???
"""
try: try:
self.record.handle_line(line) self.record.handle_line(line)
except MismatchError, e: except MismatchError, e:
print 'Got this MismatchError:' print 'Got this MismatchError:'
print e print e
print 'while processing node i='+str(i)+':' print 'while processing node i='+str(i)+':'
print 'str(node):',str(node) print 'str(node):', str(node)
print 'Symbolic inputs: ' print 'Symbolic inputs: '
for elem in node.inputs: for elem in node.inputs:
print theano.printing.min_informative_str(elem) print theano.printing.min_informative_str(elem)
...@@ -94,6 +204,9 @@ class RecordMode(Mode): ...@@ -94,6 +204,9 @@ class RecordMode(Mode):
raise MismatchError("Non-determinism detected by WrapLinker") raise MismatchError("Non-determinism detected by WrapLinker")
def callback(i, node, fn): def callback(i, node, fn):
"""
Function called by Apply nodes at the end of each computation?
"""
fgraph = node.fgraph fgraph = node.fgraph
...@@ -103,17 +216,21 @@ class RecordMode(Mode): ...@@ -103,17 +216,21 @@ class RecordMode(Mode):
"running during the playback.") "running during the playback.")
if fgraph not in self.known_fgraphs: if fgraph not in self.known_fgraphs:
assert not any([elem.name == fgraph.name for elem in self.known_fgraphs]) assert not any([elem.name == fgraph.name
for elem in self.known_fgraphs])
self.known_fgraphs.add(fgraph) self.known_fgraphs.add(fgraph)
num_app = len(fgraph.apply_nodes) num_app = len(fgraph.apply_nodes)
line = 'Function '+fgraph.name+' has '+str(num_app)+' apply nodes.\n' line = 'Function '+fgraph.name+' has ' + str(num_app) \
+ ' apply nodes.\n'
handle_line(line, i, node, fn) handle_line(line, i, node, fn)
line = 'Function name: '+fgraph.name + '\n' line = 'Function name: '+fgraph.name + '\n'
handle_line(line, i, node, fn) handle_line(line, i, node, fn)
line = 'Node '+str(i)+':'+str(node)+'\n' line = 'Node '+str(i)+':'+str(node)+'\n'
handle_line(line, i, node, fn) handle_line(line, i, node, fn)
assert all([isinstance(x, list) and len(x) == 1 for x in fn.inputs]) assert all([isinstance(x, list)
and len(x) == 1 for x in fn.inputs])
def digest(x): def digest(x):
x = x[0] x = x[0]
return hex_digest(x) return hex_digest(x)
...@@ -125,7 +242,7 @@ class RecordMode(Mode): ...@@ -125,7 +242,7 @@ class RecordMode(Mode):
line = 'Outputs: ' + outputs_digest + '\n' line = 'Outputs: ' + outputs_digest + '\n'
handle_line(line, i, node, fn) handle_line(line, i, node, fn)
#linker = theano.gof.OpWiseCLinker() # linker = theano.gof.OpWiseCLinker()
linker = theano.gof.vm.VM_Linker(use_cloop=bool(theano.config.cxx)) linker = theano.gof.vm.VM_Linker(use_cloop=bool(theano.config.cxx))
wrap_linker = theano.gof.WrapLinkerMany([linker], [callback]) wrap_linker = theano.gof.WrapLinkerMany([linker], [callback])
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论