Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
00b36a11
提交
00b36a11
authored
11月 05, 2014
作者:
Iulian Vlad Serban
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Added thorough documentation to Record and RecordMode classes.
上级
f7687b37
隐藏空白字符变更
内嵌
并排
正在显示
1 个修改的文件
包含
128 行增加
和
11 行删除
+128
-11
record.py
theano/tests/record.py
+128
-11
没有找到文件。
theano/tests/record.py
浏览文件 @
00b36a11
...
...
@@ -8,6 +8,7 @@ from theano.compile import Mode
import
theano
from
theano.printing
import
hex_digest
class
MismatchError
(
Exception
):
"""
Raised by Record.handle_line when the
...
...
@@ -15,8 +16,52 @@ class MismatchError(Exception):
of a record.
"""
class
Record
(
object
):
"""
Records a sequence of strings (from a string buffer). These can then be
compared to another sequence of strings, and if the two sequences don't
match a mismatch exception is raised.
Example:
# Create a Record object and store 'hello world' inside it
output = cStringIO.StringIO()
recorder = Record(file_object=output, replay=False)
recorder.handle_line('hello world
\n
')
# Store the previous output
output_value = output.getvalue()
output = cStringIO.StringIO(output_value)
# Create another Record object, now in playback mode, and set
# it to the previous sequence of strings
playback_checker = Record(file_object=output, replay=True)
# Check if it matches the previous one
playback_checker.handle_line('hello world
\n
')
# Now check if it the next item matches something else. This will
# throw an exception because there is no next item
playback_checker.handle_line('hello new world
\n
')
"""
def
__init__
(
self
,
file_object
=
None
,
file_path
=
None
,
replay
=
False
):
"""
Initializes Record object to use file on disc and whether it is in
replay mode or not.
Parameters
----------
file_object : StringIO
The input string buffer.
file_path : string, optional
File to save Record to.
replay : bool, optional
Determines whether or not the object is in playback mode. If not
in playback mode, the content of record will be written to the
file. If in playback mode, the content of file is loaded into the
record.
"""
assert
file_object
is
not
None
or
file_path
is
not
None
...
...
@@ -30,6 +75,18 @@ class Record(object):
self
.
__dict__
.
update
(
locals
())
def
handle_line
(
self
,
line
):
"""
If not in playback mode, it records a new string. If in playback mode,
it compares the current string to the next element in the sequence.
If these are identical the element is removed and otherwise a mismatch
exception is raised.
Parameters
----------
line : string
The string to record.
"""
assert
line
.
endswith
(
'
\n
'
)
assert
line
[:
-
2
]
.
find
(
'
\n
'
)
==
-
1
if
self
.
replay
:
...
...
@@ -50,20 +107,62 @@ class Record(object):
else
:
self
.
f
.
write
(
line
)
class
RecordMode
(
Mode
):
"""
Records all computations done with a function in a file at output_path
Prints the index of each apply node and md5 digests of the numpy ndarrays
it receives as inputs and produces as outputs.
Records all computations done with a function in a file at output_path.
Writes into the file the index of each apply node and md5 digests of the
numpy ndarrays it receives as inputs and produces as output.
Example:
# We use RecordMode to test that the computation of a function is
identical. Create a Record object and use it to initialize a
RecordMode object.
output = cStringIO.StringIO()
record = Record(file_object=output, replay=False)
record_mode = RecordMode(record)
# Then call the function you wish to test, which uses Apply nodes,
# with record_mode as first parameter to record all the computations
# to file.
...
# Create a new RecordMode object and initialize it with the previous
# record.
output = cStringIO.StringIO(output.getvalue())
playback = Record(file_object=output, replay=True)
playback_mode = RecordMode(playback)
# Call the function to test again with record_mode as first parameter.
# An exception will be thrown if the recorded computations are not
# identical between the two runs.
...
"""
def
set_record
(
self
,
record
):
"""
Configure object to use an existing Record object.
Parameters
----------
record : Record
The Record object to use.
"""
self
.
record
=
record
self
.
known_fgraphs
=
set
([])
def
__init__
(
self
,
record
=
None
,
**
kwargs
):
"""
Takes either a Record object or the keyword arguments to make one.
Parameters
----------
record : Record
The existing Record object to use.
kwargs : pointer?
Keyword arguments to construct new object.
"""
if
record
is
None
:
...
...
@@ -73,15 +172,26 @@ class RecordMode(Mode):
self
.
set_record
(
record
)
def
handle_line
(
line
,
i
,
node
,
fn
):
"""
Records new node computation.
Parameters
----------
line : string
Name of function?
i : integer?
unique id of node?
node : ???
fn : ???
"""
try
:
self
.
record
.
handle_line
(
line
)
except
MismatchError
,
e
:
print
'Got this MismatchError:'
print
e
print
'while processing node i='
+
str
(
i
)
+
':'
print
'str(node):'
,
str
(
node
)
print
'str(node):'
,
str
(
node
)
print
'Symbolic inputs: '
for
elem
in
node
.
inputs
:
print
theano
.
printing
.
min_informative_str
(
elem
)
...
...
@@ -94,26 +204,33 @@ class RecordMode(Mode):
raise
MismatchError
(
"Non-determinism detected by WrapLinker"
)
def
callback
(
i
,
node
,
fn
):
"""
Function called by Apply nodes at the end of each computation?
"""
fgraph
=
node
.
fgraph
if
fgraph
.
name
is
None
:
raise
ValueError
(
"Un-named functions are not allowed with RecordMode, "
"because they make it impossible to tell if the same function is "
"running during the playback."
)
"because they make it impossible to tell if the same function is "
"running during the playback."
)
if
fgraph
not
in
self
.
known_fgraphs
:
assert
not
any
([
elem
.
name
==
fgraph
.
name
for
elem
in
self
.
known_fgraphs
])
assert
not
any
([
elem
.
name
==
fgraph
.
name
for
elem
in
self
.
known_fgraphs
])
self
.
known_fgraphs
.
add
(
fgraph
)
num_app
=
len
(
fgraph
.
apply_nodes
)
line
=
'Function '
+
fgraph
.
name
+
' has '
+
str
(
num_app
)
+
' apply nodes.
\n
'
line
=
'Function '
+
fgraph
.
name
+
' has '
+
str
(
num_app
)
\
+
' apply nodes.
\n
'
handle_line
(
line
,
i
,
node
,
fn
)
line
=
'Function name: '
+
fgraph
.
name
+
'
\n
'
handle_line
(
line
,
i
,
node
,
fn
)
line
=
'Node '
+
str
(
i
)
+
':'
+
str
(
node
)
+
'
\n
'
handle_line
(
line
,
i
,
node
,
fn
)
assert
all
([
isinstance
(
x
,
list
)
and
len
(
x
)
==
1
for
x
in
fn
.
inputs
])
assert
all
([
isinstance
(
x
,
list
)
and
len
(
x
)
==
1
for
x
in
fn
.
inputs
])
def
digest
(
x
):
x
=
x
[
0
]
return
hex_digest
(
x
)
...
...
@@ -125,7 +242,7 @@ class RecordMode(Mode):
line
=
'Outputs: '
+
outputs_digest
+
'
\n
'
handle_line
(
line
,
i
,
node
,
fn
)
#linker = theano.gof.OpWiseCLinker()
#
linker = theano.gof.OpWiseCLinker()
linker
=
theano
.
gof
.
vm
.
VM_Linker
(
use_cloop
=
bool
(
theano
.
config
.
cxx
))
wrap_linker
=
theano
.
gof
.
WrapLinkerMany
([
linker
],
[
callback
])
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论