提交 0c11332c authored 作者: Frédéric Bastien's avatar Frédéric Bastien

Merge pull request #2231 from julianser/Fix_2071

Fixed #2071: Moving test_record.py from Pylearn2 to Theano.
...@@ -8,6 +8,7 @@ from theano.compile import Mode ...@@ -8,6 +8,7 @@ from theano.compile import Mode
import theano import theano
from theano.printing import hex_digest from theano.printing import hex_digest
class MismatchError(Exception): class MismatchError(Exception):
""" """
Raised by Record.handle_line when the Raised by Record.handle_line when the
...@@ -15,8 +16,52 @@ class MismatchError(Exception): ...@@ -15,8 +16,52 @@ class MismatchError(Exception):
of a record. of a record.
""" """
class Record(object): class Record(object):
"""
Records a sequence of strings (from a string buffer). These can then be
compared to another sequence of strings, and if the two sequences don't
match a mismatch exception is raised.
Example:
# Create a Record object and store 'hello world' inside it
output = cStringIO.StringIO()
recorder = Record(file_object=output, replay=False)
recorder.handle_line('hello world \n')
# Store the previous output
output_value = output.getvalue()
output = cStringIO.StringIO(output_value)
# Create another Record object, now in playback mode, and set
# it to the previous sequence of strings
playback_checker = Record(file_object=output, replay=True)
# Check if it matches the previous one
playback_checker.handle_line('hello world \n')
# Now check whether the next item matches something else. This will
# raise an exception because there is no next item
playback_checker.handle_line('hello new world \n')
"""
def __init__(self, file_object=None, file_path=None, replay=False): def __init__(self, file_object=None, file_path=None, replay=False):
"""
Initializes Record object to use file on disc and whether it is in
replay mode or not.
Parameters
----------
file_object : StringIO
The input string buffer.
file_path : string, optional
File to save Record to.
replay : bool, optional
Determines whether or not the object is in playback mode. If not
in playback mode, the content of record will be written to the
file. If in playback mode, the content of file is loaded into the
record.
"""
assert file_object is not None or file_path is not None assert file_object is not None or file_path is not None
...@@ -30,6 +75,18 @@ class Record(object): ...@@ -30,6 +75,18 @@ class Record(object):
self.__dict__.update(locals()) self.__dict__.update(locals())
def handle_line(self, line): def handle_line(self, line):
"""
If not in playback mode, it records a new string. If in playback mode,
it compares the current string to the next element in the sequence.
If these are identical the element is removed and otherwise a mismatch
exception is raised.
Parameters
----------
line : string
The string to record.
"""
assert line.endswith('\n') assert line.endswith('\n')
assert line[:-2].find('\n') == -1 assert line[:-2].find('\n') == -1
if self.replay: if self.replay:
...@@ -50,20 +107,67 @@ class Record(object): ...@@ -50,20 +107,67 @@ class Record(object):
else: else:
self.f.write(line) self.f.write(line)
class RecordMode(Mode): class RecordMode(Mode):
""" """
Records all computations done with a function in a file at output_path Records all computations done with a function in a file at output_path.
Prints the index of each apply node and md5 digests of the numpy ndarrays Writes into the file the index of each apply node and md5 digests of the
it receives as inputs and produces as outputs. numpy ndarrays it receives as inputs and produces as output.
Example:
# We use RecordMode to test that the computation of a function is
identical. Create a Record object and use it to initialize a
RecordMode object.
output = cStringIO.StringIO()
record = Record(file_object=output, replay=False)
record_mode = RecordMode(record)
# Then compile and call the function you wish to test, which uses
# Apply nodes with record_mode as first parameter to record all the
# computations to file. For example, call a Theano function with the
# RecordMode object.
x = theano.tensor.dscalar()
f = theano.function([x], 2*x, mode=record_mode)
print f(4)
# Create another RecordMode object and initialize it with the previous
# record.
output = cStringIO.StringIO(output.getvalue())
playback = Record(file_object=output, replay=True)
playback_mode = RecordMode(playback)
# Compile and call the function to test again with record_mode as
# first parameter. An exception will be thrown if the recorded
# computations are not identical between the two runs.
x = theano.tensor.dscalar()
f = theano.function([x], 2*x, mode=playback_mode)
print f(4)
""" """
def set_record(self, record): def set_record(self, record):
"""
Configure object to use an existing Record object.
Parameters
----------
record : Record
The Record object to use.
"""
self.record = record self.record = record
self.known_fgraphs = set([]) self.known_fgraphs = set([])
def __init__(self, record = None, **kwargs): def __init__(self, record = None, **kwargs):
""" """
Takes either a Record object or the keyword arguments to make one. Takes either a Record object or the keyword arguments to make one.
Parameters
----------
record : Record
The existing Record object to use.
kwargs : dict
Keyword arguments used to construct a new Record object.
""" """
if record is None: if record is None:
...@@ -73,15 +177,28 @@ class RecordMode(Mode): ...@@ -73,15 +177,28 @@ class RecordMode(Mode):
self.set_record(record) self.set_record(record)
def handle_line(line, i, node, fn): def handle_line(line, i, node, fn):
"""
Records new node computation.
Parameters
----------
line : string
Line to record. For example, the function name or node name.
i : integer
Node number in the toposort order.
node : Apply
The Apply node which created the entry.
fn : Function
Function related to Apply node.
"""
try: try:
self.record.handle_line(line) self.record.handle_line(line)
except MismatchError, e: except MismatchError, e:
print 'Got this MismatchError:' print 'Got this MismatchError:'
print e print e
print 'while processing node i='+str(i)+':' print 'while processing node i='+str(i)+':'
print 'str(node):',str(node) print 'str(node):', str(node)
print 'Symbolic inputs: ' print 'Symbolic inputs: '
for elem in node.inputs: for elem in node.inputs:
print theano.printing.min_informative_str(elem) print theano.printing.min_informative_str(elem)
...@@ -94,26 +211,33 @@ class RecordMode(Mode): ...@@ -94,26 +211,33 @@ class RecordMode(Mode):
raise MismatchError("Non-determinism detected by WrapLinker") raise MismatchError("Non-determinism detected by WrapLinker")
def callback(i, node, fn): def callback(i, node, fn):
"""
Callback handed to WrapLinkerMany. For the given Apply node it passes lines describing the function, the node and its digests to handle_line.
"""
fgraph = node.fgraph fgraph = node.fgraph
if fgraph.name is None: if fgraph.name is None:
raise ValueError("Un-named functions are not allowed with RecordMode, " raise ValueError("Un-named functions are not allowed with RecordMode, "
"because they make it impossible to tell if the same function is " "because they make it impossible to tell if the same function is "
"running during the playback.") "running during the playback.")
if fgraph not in self.known_fgraphs: if fgraph not in self.known_fgraphs:
assert not any([elem.name == fgraph.name for elem in self.known_fgraphs]) assert not any([elem.name == fgraph.name
for elem in self.known_fgraphs])
self.known_fgraphs.add(fgraph) self.known_fgraphs.add(fgraph)
num_app = len(fgraph.apply_nodes) num_app = len(fgraph.apply_nodes)
line = 'Function '+fgraph.name+' has '+str(num_app)+' apply nodes.\n' line = 'Function '+fgraph.name+' has ' + str(num_app) \
+ ' apply nodes.\n'
handle_line(line, i, node, fn) handle_line(line, i, node, fn)
line = 'Function name: '+fgraph.name + '\n' line = 'Function name: '+fgraph.name + '\n'
handle_line(line, i, node, fn) handle_line(line, i, node, fn)
line = 'Node '+str(i)+':'+str(node)+'\n' line = 'Node '+str(i)+':'+str(node)+'\n'
handle_line(line, i, node, fn) handle_line(line, i, node, fn)
assert all([isinstance(x, list) and len(x) == 1 for x in fn.inputs]) assert all([isinstance(x, list)
and len(x) == 1 for x in fn.inputs])
def digest(x): def digest(x):
x = x[0] x = x[0]
return hex_digest(x) return hex_digest(x)
...@@ -125,7 +249,7 @@ class RecordMode(Mode): ...@@ -125,7 +249,7 @@ class RecordMode(Mode):
line = 'Outputs: ' + outputs_digest + '\n' line = 'Outputs: ' + outputs_digest + '\n'
handle_line(line, i, node, fn) handle_line(line, i, node, fn)
#linker = theano.gof.OpWiseCLinker() # linker = theano.gof.OpWiseCLinker()
linker = theano.gof.vm.VM_Linker(use_cloop=bool(theano.config.cxx)) linker = theano.gof.vm.VM_Linker(use_cloop=bool(theano.config.cxx))
wrap_linker = theano.gof.WrapLinkerMany([linker], [callback]) wrap_linker = theano.gof.WrapLinkerMany([linker], [callback])
......
from theano.tests.record import *
from theano import function
from theano.tensor import iscalar
import cStringIO
def test_record_good():
    """
    Record a sequence of lines, then replay exactly the same sequence.

    Checks that the Record class:
    1) Writes the recorded lines to its file object unchanged
    2) Raises no error when the identical lines are replayed
    """
    num_lines = 10
    lines = [str(i) + '\n' for i in xrange(num_lines)]

    # Recording pass: every line is written to an in-memory buffer.
    output = cStringIO.StringIO()
    recorder = Record(file_object=output, replay=False)
    for line in lines:
        recorder.handle_line(line)

    # The buffer must now hold the lines concatenated in order.
    output_value = output.getvalue()
    assert output_value == ''.join(lines)

    # Playback pass: feeding the same lines back must not raise.
    output = cStringIO.StringIO(output_value)
    playback_checker = Record(file_object=output, replay=True)
    for line in lines:
        playback_checker.handle_line(line)
def test_record_bad():
    """
    Record a sequence of lines, then deviate from it during playback.

    The Record class is expected to raise MismatchError on the first
    line that differs from the recorded sequence.
    """
    num_lines = 10

    # Recording pass.
    output = cStringIO.StringIO()
    recorder = Record(file_object=output, replay=False)
    for i in xrange(num_lines):
        recorder.handle_line(str(i) + '\n')

    # Replay the first half of the record; this part must go through
    # without any error.
    output_value = output.getvalue()
    output = cStringIO.StringIO(output_value)
    playback_checker = Record(file_object=output, replay=True)
    for i in xrange(num_lines // 2):
        playback_checker.handle_line(str(i) + '\n')

    # Lines 0..4 have been replayed, so the record holds '5\n' next and
    # '0\n' has to be rejected.
    try:
        playback_checker.handle_line('0\n')
    except MismatchError:
        pass
    else:
        raise AssertionError("Failed to detect mismatch between recorded "
                             "sequence " " and repetition of it.")
def test_record_mode_good():
    """
    Variant of test_record_good in which part of the record is produced
    by a theano function compiled with RecordMode.

    The exact text of the record is not checked here; the test only
    requires that an identical playback run raises no error.
    """
    num_lines = 10

    # Recording pass: interleave hand-written lines with the lines that
    # RecordMode emits while the compiled function runs.
    output = cStringIO.StringIO()
    recorder = Record(file_object=output, replay=False)
    record_mode = RecordMode(recorder)

    x = iscalar()
    f = function([x], x, mode=record_mode, name='f')

    for step in xrange(num_lines):
        recorder.handle_line(str(step) + '\n')
        f(step)

    # Playback pass: rebuild the same function against the stored record
    # and repeat the same sequence of events.
    output_value = output.getvalue()
    output = cStringIO.StringIO(output_value)
    playback_checker = Record(file_object=output, replay=True)
    playback_mode = RecordMode(playback_checker)

    x = iscalar()
    f = function([x], x, mode=playback_mode, name='f')

    for step in xrange(num_lines):
        playback_checker.handle_line(str(step) + '\n')
        f(step)
def test_record_mode_bad():
    """
    Variant of test_record_bad in which part of the record is produced
    by a theano function compiled with RecordMode, and the out-of-sequence
    event that must trigger MismatchError is itself a call to that
    function.
    """
    num_lines = 10

    # Recording pass: interleave hand-written lines with the lines that
    # RecordMode emits while the compiled function runs.
    output = cStringIO.StringIO()
    recorder = Record(file_object=output, replay=False)
    record_mode = RecordMode(recorder)

    x = iscalar()
    f = function([x], x, mode=record_mode, name='f')

    for step in xrange(num_lines):
        recorder.handle_line(str(step) + '\n')
        f(step)

    # Playback pass: repeat only the first half of the events; this part
    # must go through without any error.
    output_value = output.getvalue()
    output = cStringIO.StringIO(output_value)
    playback_checker = Record(file_object=output, replay=True)
    playback_mode = RecordMode(playback_checker)

    x = iscalar()
    f = function([x], x, mode=playback_mode, name='f')

    for step in xrange(num_lines // 2):
        playback_checker.handle_line(str(step) + '\n')
        f(step)

    # The record holds a hand-written line next, so calling f directly is
    # out of sequence and has to be reported.
    try:
        f(0)
    except MismatchError:
        pass
    else:
        raise AssertionError("Failed to detect a mismatch.")
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论