提交 f7687b37 authored 作者: Iulian Vlad Serban's avatar Iulian Vlad Serban

Fixed #2071: Moving test_record.py from Pylearn2 to Theano.

上级 7fb90052
from theano.tests.record import *
from theano import function
from theano.tensor import iscalar
import cStringIO
def test_record_good():
"""
Tests that when we record a sequence of events, then
repeat it exactly, the Record class:
1) Records it correctly
2) Does not raise any errors
"""
# Record a sequence of events
output = cStringIO.StringIO()
recorder = Record(file_object=output, replay=False)
num_lines = 10
for i in xrange(num_lines):
recorder.handle_line(str(i)+'\n')
# Make sure they were recorded correctly
output_value = output.getvalue()
assert output_value == ''.join(str(i)+'\n' for i in xrange(num_lines))
# Make sure that the playback functionality doesn't raise any errors
# when we repeat them
output = cStringIO.StringIO(output_value)
playback_checker = Record(file_object=output, replay=True)
for i in xrange(num_lines):
playback_checker.handle_line(str(i)+'\n')
def test_record_bad():
"""
Tests that when we record a sequence of events, then
do something different on playback, the Record class catches it.
"""
# Record a sequence of events
output = cStringIO.StringIO()
recorder = Record(file_object=output, replay=False)
num_lines = 10
for i in xrange(num_lines):
recorder.handle_line(str(i)+'\n')
# Make sure that the playback functionality doesn't raise any errors
# when we repeat some of them
output_value = output.getvalue()
output = cStringIO.StringIO(output_value)
playback_checker = Record(file_object=output, replay=True)
for i in xrange(num_lines // 2):
playback_checker.handle_line(str(i)+'\n')
# Make sure it raises an error when we deviate from the recorded sequence
try:
playback_checker.handle_line('0\n')
except MismatchError:
return
raise AssertionError("Failed to detect mismatch between recorded sequence "
" and repetition of it.")
def test_record_mode_good():
"""
Like test_record_good, but some events are recorded by the
theano RecordMode. We don't attempt to check the
exact string value of the record in this case.
"""
# Record a sequence of events
output = cStringIO.StringIO()
recorder = Record(file_object=output, replay=False)
record_mode = RecordMode(recorder)
i = iscalar()
f = function([i], i, mode=record_mode, name='f')
num_lines = 10
for i in xrange(num_lines):
recorder.handle_line(str(i)+'\n')
f(i)
# Make sure that the playback functionality doesn't raise any errors
# when we repeat them
output_value = output.getvalue()
output = cStringIO.StringIO(output_value)
playback_checker = Record(file_object=output, replay=True)
playback_mode = RecordMode(playback_checker)
i = iscalar()
f = function([i], i, mode=playback_mode, name='f')
for i in xrange(num_lines):
playback_checker.handle_line(str(i)+'\n')
f(i)
def test_record_mode_bad():
"""
Like test_record_bad, but some events are recorded by the
theano RecordMode, as is the event that triggers the mismatch
error.
"""
# Record a sequence of events
output = cStringIO.StringIO()
recorder = Record(file_object=output, replay=False)
record_mode = RecordMode(recorder)
i = iscalar()
f = function([i], i, mode=record_mode, name='f')
num_lines = 10
for i in xrange(num_lines):
recorder.handle_line(str(i)+'\n')
f(i)
# Make sure that the playback functionality doesn't raise any errors
# when we repeat them
output_value = output.getvalue()
output = cStringIO.StringIO(output_value)
playback_checker = Record(file_object=output, replay=True)
playback_mode = RecordMode(playback_checker)
i = iscalar()
f = function([i], i, mode=playback_mode, name='f')
for i in xrange(num_lines // 2):
playback_checker.handle_line(str(i)+'\n')
f(i)
# Make sure a wrong event causes a MismatchError
try:
f(0)
except MismatchError:
return
raise AssertionError("Failed to detect a mismatch.")
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论