Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
0c11332c
提交
0c11332c
authored
1月 06, 2015
作者:
Frédéric Bastien
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #2231 from julianser/Fix_2071
Fixed #2071: Moving test_record.py from Pylearn2 to Theano.
上级
9048d63f
cb682122
隐藏空白字符变更
内嵌
并排
正在显示
2 个修改的文件
包含
298 行增加
和
11 行删除
+298
-11
record.py
theano/tests/record.py
+135
-11
test_record.py
theano/tests/test_record.py
+163
-0
没有找到文件。
theano/tests/record.py
浏览文件 @
0c11332c
...
@@ -8,6 +8,7 @@ from theano.compile import Mode
...
@@ -8,6 +8,7 @@ from theano.compile import Mode
import
theano
import
theano
from
theano.printing
import
hex_digest
from
theano.printing
import
hex_digest
class
MismatchError
(
Exception
):
class
MismatchError
(
Exception
):
"""
"""
Raised by Record.handle_line when the
Raised by Record.handle_line when the
...
@@ -15,8 +16,52 @@ class MismatchError(Exception):
...
@@ -15,8 +16,52 @@ class MismatchError(Exception):
of a record.
of a record.
"""
"""
class
Record
(
object
):
class
Record
(
object
):
"""
Records a sequence of strings (from a string buffer). These can then be
compared to another sequence of strings, and if the two sequences don't
match a mismatch exception is raised.
Example:
# Create a Record object and store 'hello world' inside it
output = cStringIO.StringIO()
recorder = Record(file_object=output, replay=False)
recorder.handle_line('hello world
\n
')
# Store the previous output
output_value = output.getvalue()
output = cStringIO.StringIO(output_value)
# Create another Record object, now in playback mode, and set
# it to the previous sequence of strings
playback_checker = Record(file_object=output, replay=True)
# Check if it matches the previous one
playback_checker.handle_line('hello world
\n
')
# Now check if it the next item matches something else. This will
# throw an exception because there is no next item
playback_checker.handle_line('hello new world
\n
')
"""
def
__init__
(
self
,
file_object
=
None
,
file_path
=
None
,
replay
=
False
):
def
__init__
(
self
,
file_object
=
None
,
file_path
=
None
,
replay
=
False
):
"""
Initializes Record object to use file on disc and whether it is in
replay mode or not.
Parameters
----------
file_object : StringIO
The input string buffer.
file_path : string, optional
File to save Record to.
replay : bool, optional
Determines whether or not the object is in playback mode. If not
in playback mode, the content of record will be written to the
file. If in playback mode, the content of file is loaded into the
record.
"""
assert
file_object
is
not
None
or
file_path
is
not
None
assert
file_object
is
not
None
or
file_path
is
not
None
...
@@ -30,6 +75,18 @@ class Record(object):
...
@@ -30,6 +75,18 @@ class Record(object):
self
.
__dict__
.
update
(
locals
())
self
.
__dict__
.
update
(
locals
())
def
handle_line
(
self
,
line
):
def
handle_line
(
self
,
line
):
"""
If not in playback mode, it records a new string. If in playback mode,
it compares the current string to the next element in the sequence.
If these are identical the element is removed and otherwise a mismatch
exception is raised.
Parameters
----------
line : string
The string to record.
"""
assert
line
.
endswith
(
'
\n
'
)
assert
line
.
endswith
(
'
\n
'
)
assert
line
[:
-
2
]
.
find
(
'
\n
'
)
==
-
1
assert
line
[:
-
2
]
.
find
(
'
\n
'
)
==
-
1
if
self
.
replay
:
if
self
.
replay
:
...
@@ -50,20 +107,67 @@ class Record(object):
...
@@ -50,20 +107,67 @@ class Record(object):
else
:
else
:
self
.
f
.
write
(
line
)
self
.
f
.
write
(
line
)
class
RecordMode
(
Mode
):
class
RecordMode
(
Mode
):
"""
"""
Records all computations done with a function in a file at output_path
Records all computations done with a function in a file at output_path.
Prints the index of each apply node and md5 digests of the numpy ndarrays
Writes into the file the index of each apply node and md5 digests of the
it receives as inputs and produces as outputs.
numpy ndarrays it receives as inputs and produces as output.
Example:
# We use RecordMode to test that the computation of a function is
identical. Create a Record object and use it to initialize a
RecordMode object.
output = cStringIO.StringIO()
record = Record(file_object=output, replay=False)
record_mode = RecordMode(record)
# Then compile and call the function you wish to test, which uses
# Apply nodes with record_mode as first parameter to record all the
# computations to file. For example, call a Theano function with the
# RecordMode object.
x = theano.tensor.dscalar()
f = theano.function([x], 2*x, mode=record_mode)
print f(4)
# Create another RecordMode object and initialize it with the previous
# record.
output = cStringIO.StringIO(output.getvalue())
playback = Record(file_object=output, replay=True)
playback_mode = RecordMode(playback)
# Compile and call the function to test again with record_mode as
# first parameter. An exception will be thrown if the recorded
# computations are not identical between the two runs.
x = theano.tensor.dscalar()
f = theano.function([x], 2*x, mode=playback_mode)
print f(4)
"""
"""
def
set_record
(
self
,
record
):
def
set_record
(
self
,
record
):
"""
Configure object to use an existing Record object.
Parameters
----------
record : Record
The Record object to use.
"""
self
.
record
=
record
self
.
record
=
record
self
.
known_fgraphs
=
set
([])
self
.
known_fgraphs
=
set
([])
def
__init__
(
self
,
record
=
None
,
**
kwargs
):
def
__init__
(
self
,
record
=
None
,
**
kwargs
):
"""
"""
Takes either a Record object or the keyword arguments to make one.
Takes either a Record object or the keyword arguments to make one.
Parameters
----------
record : Record
The existing Record object to use.
kwargs : pointer?
Keyword arguments to construct new object.
"""
"""
if
record
is
None
:
if
record
is
None
:
...
@@ -73,15 +177,28 @@ class RecordMode(Mode):
...
@@ -73,15 +177,28 @@ class RecordMode(Mode):
self
.
set_record
(
record
)
self
.
set_record
(
record
)
def
handle_line
(
line
,
i
,
node
,
fn
):
def
handle_line
(
line
,
i
,
node
,
fn
):
"""
Records new node computation.
Parameters
----------
line : string
Line to record. For example, the function name or node name.
i : integer
Node number in the toposort order.
node : Apply,
The Apply node which created the entry.
fn : Function,
Function related to Apply node.
"""
try
:
try
:
self
.
record
.
handle_line
(
line
)
self
.
record
.
handle_line
(
line
)
except
MismatchError
,
e
:
except
MismatchError
,
e
:
print
'Got this MismatchError:'
print
'Got this MismatchError:'
print
e
print
e
print
'while processing node i='
+
str
(
i
)
+
':'
print
'while processing node i='
+
str
(
i
)
+
':'
print
'str(node):'
,
str
(
node
)
print
'str(node):'
,
str
(
node
)
print
'Symbolic inputs: '
print
'Symbolic inputs: '
for
elem
in
node
.
inputs
:
for
elem
in
node
.
inputs
:
print
theano
.
printing
.
min_informative_str
(
elem
)
print
theano
.
printing
.
min_informative_str
(
elem
)
...
@@ -94,26 +211,33 @@ class RecordMode(Mode):
...
@@ -94,26 +211,33 @@ class RecordMode(Mode):
raise
MismatchError
(
"Non-determinism detected by WrapLinker"
)
raise
MismatchError
(
"Non-determinism detected by WrapLinker"
)
def
callback
(
i
,
node
,
fn
):
def
callback
(
i
,
node
,
fn
):
"""
Function called by Apply nodes at the end of each computation?
"""
fgraph
=
node
.
fgraph
fgraph
=
node
.
fgraph
if
fgraph
.
name
is
None
:
if
fgraph
.
name
is
None
:
raise
ValueError
(
"Un-named functions are not allowed with RecordMode, "
raise
ValueError
(
"Un-named functions are not allowed with RecordMode, "
"because they make it impossible to tell if the same function is "
"because they make it impossible to tell if the same function is "
"running during the playback."
)
"running during the playback."
)
if
fgraph
not
in
self
.
known_fgraphs
:
if
fgraph
not
in
self
.
known_fgraphs
:
assert
not
any
([
elem
.
name
==
fgraph
.
name
for
elem
in
self
.
known_fgraphs
])
assert
not
any
([
elem
.
name
==
fgraph
.
name
for
elem
in
self
.
known_fgraphs
])
self
.
known_fgraphs
.
add
(
fgraph
)
self
.
known_fgraphs
.
add
(
fgraph
)
num_app
=
len
(
fgraph
.
apply_nodes
)
num_app
=
len
(
fgraph
.
apply_nodes
)
line
=
'Function '
+
fgraph
.
name
+
' has '
+
str
(
num_app
)
+
' apply nodes.
\n
'
line
=
'Function '
+
fgraph
.
name
+
' has '
+
str
(
num_app
)
\
+
' apply nodes.
\n
'
handle_line
(
line
,
i
,
node
,
fn
)
handle_line
(
line
,
i
,
node
,
fn
)
line
=
'Function name: '
+
fgraph
.
name
+
'
\n
'
line
=
'Function name: '
+
fgraph
.
name
+
'
\n
'
handle_line
(
line
,
i
,
node
,
fn
)
handle_line
(
line
,
i
,
node
,
fn
)
line
=
'Node '
+
str
(
i
)
+
':'
+
str
(
node
)
+
'
\n
'
line
=
'Node '
+
str
(
i
)
+
':'
+
str
(
node
)
+
'
\n
'
handle_line
(
line
,
i
,
node
,
fn
)
handle_line
(
line
,
i
,
node
,
fn
)
assert
all
([
isinstance
(
x
,
list
)
and
len
(
x
)
==
1
for
x
in
fn
.
inputs
])
assert
all
([
isinstance
(
x
,
list
)
and
len
(
x
)
==
1
for
x
in
fn
.
inputs
])
def
digest
(
x
):
def
digest
(
x
):
x
=
x
[
0
]
x
=
x
[
0
]
return
hex_digest
(
x
)
return
hex_digest
(
x
)
...
@@ -125,7 +249,7 @@ class RecordMode(Mode):
...
@@ -125,7 +249,7 @@ class RecordMode(Mode):
line
=
'Outputs: '
+
outputs_digest
+
'
\n
'
line
=
'Outputs: '
+
outputs_digest
+
'
\n
'
handle_line
(
line
,
i
,
node
,
fn
)
handle_line
(
line
,
i
,
node
,
fn
)
#linker = theano.gof.OpWiseCLinker()
#
linker = theano.gof.OpWiseCLinker()
linker
=
theano
.
gof
.
vm
.
VM_Linker
(
use_cloop
=
bool
(
theano
.
config
.
cxx
))
linker
=
theano
.
gof
.
vm
.
VM_Linker
(
use_cloop
=
bool
(
theano
.
config
.
cxx
))
wrap_linker
=
theano
.
gof
.
WrapLinkerMany
([
linker
],
[
callback
])
wrap_linker
=
theano
.
gof
.
WrapLinkerMany
([
linker
],
[
callback
])
...
...
theano/tests/test_record.py
0 → 100644
浏览文件 @
0c11332c
from
theano.tests.record
import
*
from
theano
import
function
from
theano.tensor
import
iscalar
import
cStringIO
def
test_record_good
():
"""
Tests that when we record a sequence of events, then
repeat it exactly, the Record class:
1) Records it correctly
2) Does not raise any errors
"""
# Record a sequence of events
output
=
cStringIO
.
StringIO
()
recorder
=
Record
(
file_object
=
output
,
replay
=
False
)
num_lines
=
10
for
i
in
xrange
(
num_lines
):
recorder
.
handle_line
(
str
(
i
)
+
'
\n
'
)
# Make sure they were recorded correctly
output_value
=
output
.
getvalue
()
assert
output_value
==
''
.
join
(
str
(
i
)
+
'
\n
'
for
i
in
xrange
(
num_lines
))
# Make sure that the playback functionality doesn't raise any errors
# when we repeat them
output
=
cStringIO
.
StringIO
(
output_value
)
playback_checker
=
Record
(
file_object
=
output
,
replay
=
True
)
for
i
in
xrange
(
num_lines
):
playback_checker
.
handle_line
(
str
(
i
)
+
'
\n
'
)
def
test_record_bad
():
"""
Tests that when we record a sequence of events, then
do something different on playback, the Record class catches it.
"""
# Record a sequence of events
output
=
cStringIO
.
StringIO
()
recorder
=
Record
(
file_object
=
output
,
replay
=
False
)
num_lines
=
10
for
i
in
xrange
(
num_lines
):
recorder
.
handle_line
(
str
(
i
)
+
'
\n
'
)
# Make sure that the playback functionality doesn't raise any errors
# when we repeat some of them
output_value
=
output
.
getvalue
()
output
=
cStringIO
.
StringIO
(
output_value
)
playback_checker
=
Record
(
file_object
=
output
,
replay
=
True
)
for
i
in
xrange
(
num_lines
//
2
):
playback_checker
.
handle_line
(
str
(
i
)
+
'
\n
'
)
# Make sure it raises an error when we deviate from the recorded sequence
try
:
playback_checker
.
handle_line
(
'0
\n
'
)
except
MismatchError
:
return
raise
AssertionError
(
"Failed to detect mismatch between recorded sequence "
" and repetition of it."
)
def
test_record_mode_good
():
"""
Like test_record_good, but some events are recorded by the
theano RecordMode. We don't attempt to check the
exact string value of the record in this case.
"""
# Record a sequence of events
output
=
cStringIO
.
StringIO
()
recorder
=
Record
(
file_object
=
output
,
replay
=
False
)
record_mode
=
RecordMode
(
recorder
)
i
=
iscalar
()
f
=
function
([
i
],
i
,
mode
=
record_mode
,
name
=
'f'
)
num_lines
=
10
for
i
in
xrange
(
num_lines
):
recorder
.
handle_line
(
str
(
i
)
+
'
\n
'
)
f
(
i
)
# Make sure that the playback functionality doesn't raise any errors
# when we repeat them
output_value
=
output
.
getvalue
()
output
=
cStringIO
.
StringIO
(
output_value
)
playback_checker
=
Record
(
file_object
=
output
,
replay
=
True
)
playback_mode
=
RecordMode
(
playback_checker
)
i
=
iscalar
()
f
=
function
([
i
],
i
,
mode
=
playback_mode
,
name
=
'f'
)
for
i
in
xrange
(
num_lines
):
playback_checker
.
handle_line
(
str
(
i
)
+
'
\n
'
)
f
(
i
)
def
test_record_mode_bad
():
"""
Like test_record_bad, but some events are recorded by the
theano RecordMode, as is the event that triggers the mismatch
error.
"""
# Record a sequence of events
output
=
cStringIO
.
StringIO
()
recorder
=
Record
(
file_object
=
output
,
replay
=
False
)
record_mode
=
RecordMode
(
recorder
)
i
=
iscalar
()
f
=
function
([
i
],
i
,
mode
=
record_mode
,
name
=
'f'
)
num_lines
=
10
for
i
in
xrange
(
num_lines
):
recorder
.
handle_line
(
str
(
i
)
+
'
\n
'
)
f
(
i
)
# Make sure that the playback functionality doesn't raise any errors
# when we repeat them
output_value
=
output
.
getvalue
()
output
=
cStringIO
.
StringIO
(
output_value
)
playback_checker
=
Record
(
file_object
=
output
,
replay
=
True
)
playback_mode
=
RecordMode
(
playback_checker
)
i
=
iscalar
()
f
=
function
([
i
],
i
,
mode
=
playback_mode
,
name
=
'f'
)
for
i
in
xrange
(
num_lines
//
2
):
playback_checker
.
handle_line
(
str
(
i
)
+
'
\n
'
)
f
(
i
)
# Make sure a wrong event causes a MismatchError
try
:
f
(
0
)
except
MismatchError
:
return
raise
AssertionError
(
"Failed to detect a mismatch."
)
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论