Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
c104defb
提交
c104defb
authored
3月 11, 2009
作者:
james@X40
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
more changes to Module, more tests
上级
c31223aa
隐藏空白字符变更
内嵌
并排
正在显示
2 个修改的文件
包含
224 行增加
和
93 行删除
+224
-93
module.py
theano/compile/module.py
+135
-74
test_module.py
theano/compile/tests/test_module.py
+89
-19
没有找到文件。
theano/compile/module.py
浏览文件 @
c104defb
...
...
@@ -32,14 +32,14 @@ import function_module as F
from
mode
import
default_mode
def
join
(
*
args
):
def
name_
join
(
*
args
):
"""
Creates a string representation for the given names:
join('a', 'b', 'c') => 'a.b.c'
"""
return
"."
.
join
(
arg
for
arg
in
args
if
arg
)
def
split
(
sym
,
n
=-
1
):
def
name_
split
(
sym
,
n
=-
1
):
"""
Gets the names from their joined representation
split('a.b.c') => ['a', 'b', 'c']
...
...
@@ -55,7 +55,7 @@ def canonicalize(name):
[Fred: why we return the right type? Why int only?]
"""
if
isinstance
(
name
,
str
):
name
=
split
(
name
)
name
=
name_
split
(
name
)
def
convert
(
x
):
try
:
return
int
(
x
)
...
...
@@ -63,7 +63,6 @@ def canonicalize(name):
return
x
return
map
(
convert
,
name
)
class
AllocationError
(
Exception
):
"""
Exception raised when a Result has no associated storage.
...
...
@@ -116,7 +115,7 @@ class Component(object):
else
:
raise
BindError
(
"
%
s is already bound to
%
s as
%
s"
%
(
self
,
self
.
parent
,
self
.
name
))
self
.
parent
=
parent
self
.
name
=
join
(
parent
.
name
,
name
)
self
.
name
=
name_
join
(
parent
.
name
,
name
)
return
self
def
bound
(
self
):
...
...
@@ -303,29 +302,79 @@ class Member(_RComponent):
return
memo
[
self
.
r
]
.
value
class
Method
(
Component
):
"""
Method is a declaration of a function. It contains inputs,
outputs and updates. If the Method is part of a Composite
which holds references to Members, the Method may use them
without declaring them in the inputs, outputs or updates list.
def
__init__
(
self
,
inputs
,
outputs
,
updates
=
{},
kits
=
[],
**
kwupdates
):
"""
Method is a declaration of a function. It contains inputs,
outputs and updates. If the Method is part of a Composite
which holds references to Members, the Method may use them
without declaring them in the inputs, outputs or updates list.
inputs, outputs or updates may be strings. In that case, they
will be resolved in the Composite which is the parent of this
Method.
Method builds a Function (same structure as a call to
theano.function)
"""
inputs
=
[]
"""function inputs (see `compile.function`)
If Module members are named explicitly in this list, then they will not use shared storage.
Storage must be provided either via an `io.In` value argument, or at the point of the
function call.
"""
outputs
=
None
"""function outputs (see `compile.function`)"""
updates
=
{}
"""update expressions for module members
If this method should update the shared storage value for a Module member, then the
update expression must be given in this dictionary.
Keys in this dictionary must be members of the module graph--results for which this Method
will use the shared storage.
The value associated with each key should be a Result (or a string that can be resolved to
a Result) representing the computation of a new value for this shared storage after
each function call.
"""
mode
=
None
"""This will override the Module compilation mode for this Method"""
def
__init__
(
self
,
inputs
,
outputs
,
updates
=
{},
mode
=
None
,
**
kwupdates
):
"""Initialize attributes
:param inputs: value for `Method.inputs`
:param outputs: value for `Method.outputs`
:param updates: value for `Method.updates`
:param kwupdates: additions to `updates`
:param mode: value for `Method.mode`
:type inputs: list of (str or `Result` or `io.In`)
:type outputs: None or str or `Result` or `io.Out` or list of (str or `Result` or
`io.Out`)
[TODO: remove references to kits, for they are not really
needed anymore]
:type updates: dict of `Result` or str -> `Result` or str
inputs, outputs or updates may be strings. In that case, they
will be resolved in the Composite which is the parent of this
Method.
:type kwupdates: extra updates
:type mode: None or any mode accepted by `compile.function`
Method builds a Function (same structure as a call to
theano.function)
"""
super
(
Method
,
self
)
.
__init__
()
self
.
inputs
=
inputs
self
.
outputs
=
outputs
self
.
updates
=
dict
(
updates
,
**
kwupdates
)
self
.
kits
=
list
(
kits
)
self
.
mode
=
mode
def
bind
(
self
,
parent
,
name
,
dup_ok
=
True
):
rval
=
super
(
Method
,
self
)
.
bind
(
parent
,
name
,
dup_ok
=
dup_ok
)
...
...
@@ -333,8 +382,12 @@ class Method(Component):
return
rval
def
resolve
(
self
,
name
):
"""
Resolves the name of an input or output in the parent.
"""Return the Result corresponding to a given name
:param name: the name of a Result in the Module instance containing this Method
:type name: str
:rtype: `Result`
"""
if
not
self
.
bound
():
raise
ValueError
(
'Trying to resolve a name on an unbound Method.'
)
...
...
@@ -343,49 +396,47 @@ class Method(Component):
raise
TypeError
(
'Expected a Component with subtype Member or External.'
)
return
result
def
resolve_result
(
self
,
x
,
passthrough
=
(
gof
.
Result
)):
if
isinstance
(
x
,
passthrough
):
return
x
elif
isinstance
(
x
,
_RComponent
):
return
x
.
r
else
:
return
self
.
resolve
(
x
)
.
r
def
resolve_all
(
self
):
"""Convert all inputs, outputs, and updates specified as strings to Results.
def
resolve_inputs
(
self
):
if
isinstance
(
self
.
inputs
,
(
io
.
In
,
gof
.
Result
,
str
)):
inputs
=
[
self
.
inputs
]
else
:
inputs
=
list
(
self
.
inputs
)
self
.
inputs
=
[
self
.
resolve_result
(
input
,
passthrough
=
(
gof
.
Result
,
io
.
In
))
for
input
in
inputs
]
def
resolve_outputs
(
self
):
if
isinstance
(
self
.
outputs
,
(
io
.
Out
,
gof
.
Result
,
str
)):
output
=
self
.
outputs
self
.
outputs
=
self
.
resolve_result
(
output
,
passthrough
=
(
gof
.
Result
,
io
.
Out
))
else
:
outputs
=
list
(
self
.
outputs
)
self
.
outputs
=
[
self
.
resolve_result
(
output
,
passthrough
=
(
gof
.
Result
,
io
.
Out
))
for
output
in
outputs
]
This works by searching the containing Module for Result attributes by these names.
"""
def
resolve_result
(
x
,
passthrough
=
(
gof
.
Result
)):
if
isinstance
(
x
,
passthrough
):
return
x
elif
isinstance
(
x
,
_RComponent
):
return
x
.
r
else
:
return
self
.
resolve
(
x
)
.
r
def
resolve_updates
(
self
):
updates
=
self
.
updates
self
.
updates
=
{}
for
k
,
v
in
updates
.
iteritems
():
k
,
v
=
self
.
resolve_result
(
k
),
self
.
resolve_result
(
v
)
self
.
updates
[
k
]
=
v
def
resolve_inputs
():
if
isinstance
(
self
.
inputs
,
(
io
.
In
,
gof
.
Result
,
str
)):
inputs
=
[
self
.
inputs
]
else
:
inputs
=
list
(
self
.
inputs
)
self
.
inputs
=
[
resolve_result
(
input
,
passthrough
=
(
gof
.
Result
,
io
.
In
))
for
input
in
inputs
]
def
resolve_outputs
():
if
isinstance
(
self
.
outputs
,
(
io
.
Out
,
gof
.
Result
,
str
,
None
)):
output
=
self
.
outputs
self
.
outputs
=
resolve_result
(
output
,
passthrough
=
(
gof
.
Result
,
io
.
Out
,
None
))
else
:
outputs
=
list
(
self
.
outputs
)
self
.
outputs
=
[
resolve_result
(
output
,
passthrough
=
(
gof
.
Result
,
io
.
Out
))
for
output
in
outputs
]
def
resolve_all
(
self
):
"""
Resolves all inputs, outputs and updates that were given as
strings so that the fields contain the corresponding Result
instances instead.
"""
self
.
resolve_inputs
()
self
.
resolve_outputs
()
self
.
resolve_updates
()
def
resolve_updates
():
updates
=
self
.
updates
self
.
updates
=
{}
for
k
,
v
in
updates
.
iteritems
():
k
,
v
=
resolve_result
(
k
),
resolve_result
(
v
)
self
.
updates
[
k
]
=
v
resolve_inputs
()
resolve_outputs
()
resolve_updates
()
def
allocate
(
self
,
memo
):
"""
...
...
@@ -394,13 +445,21 @@ class Method(Component):
return
None
def
build
(
self
,
mode
,
memo
,
allocate_all
=
False
):
"""
Produces a function. If allocate_all is True, storage will be
allocated for all needed Results, even if there is no
"""Compile a function for this Method.
:param allocate_all: if True, storage will be
allocated for all needed Results even if there is no
associated storage for them in the memo. If allocate_all is
False, storage will only be allocated for Results that are
reachable from the inputs list.
:returns: a function that implements this method
:rtype: `Function` instance
"""
if
self
in
memo
:
return
memo
[
self
]
self
.
resolve_all
()
# resolve all so we don't have to mess with strings
def
get_storage
(
r
,
require
=
False
):
# If require is True, we can only get storage from the memo.
...
...
@@ -430,7 +489,7 @@ class Method(Component):
else
:
raise
TypeError
(
input
,
type
(
input
))
# Deal with updates
# Deal with updates
to shared storage
for
k
,
v
in
self
.
updates
.
iteritems
():
assert
isinstance
(
k
,
gof
.
Result
)
assert
isinstance
(
v
,
gof
.
Result
)
...
...
@@ -441,7 +500,7 @@ class Method(Component):
if
input
.
result
==
k
:
input_k
=
input
print
'METHOD UPDATE'
,
k
,
v
,
input_k
#
print 'METHOD UPDATE', k, v, input_k
if
input_k
is
None
:
# this is an implicit input,
# use shared storage
...
...
@@ -452,10 +511,9 @@ class Method(Component):
mutable
=
True
)
inputs
.
append
(
input_k
)
else
:
# this was an explicit input
# don't use shared storage
input_k
.
update
=
v
input_k
.
mutable
=
True
raise
ValueError
((
'Result listed in both inputs and updates.'
' Use inputs to use your own storage, use updates to '
'work on module-shared storage'
),
k
)
outputs
=
self
.
outputs
_inputs
=
[
x
.
result
for
x
in
inputs
]
...
...
@@ -478,7 +536,10 @@ class Method(Component):
assert
type
(
storage
)
is
io
.
In
inputs
.
append
(
storage
)
return
F
.
function
(
inputs
,
outputs
,
mode
)
effective_mode
=
mode
if
self
.
mode
is
None
else
self
.
mode
rval
=
F
.
function
(
inputs
,
outputs
,
effective_mode
)
memo
[
self
]
=
rval
return
rval
def
pretty
(
self
,
**
kwargs
):
self
.
resolve_all
()
...
...
@@ -507,10 +568,10 @@ class Method(Component):
def
dup
(
self
):
self
.
resolve_all
()
return
self
.
__class__
(
list
(
self
.
inputs
),
list
(
self
.
outputs
)
if
isinstance
(
self
.
outputs
,
list
)
else
self
.
outputs
,
dict
(
self
.
updates
),
list
(
self
.
kits
)
)
return
self
.
__class__
(
inputs
=
list
(
self
.
inputs
),
outputs
=
list
(
self
.
outputs
)
if
isinstance
(
self
.
outputs
,
list
)
else
self
.
outputs
,
updates
=
dict
(
self
.
updates
),
mode
=
self
.
mode
)
def
__call__
(
self
,
*
args
,
**
kwargs
):
raise
TypeError
(
"'Method' object is not callable"
...
...
theano/compile/tests/test_module.py
浏览文件 @
c104defb
...
...
@@ -267,6 +267,7 @@ class T_module(unittest.TestCase):
def
get_element
(
i
):
return
[
i
.
x
,
i
.
lx
[
0
],
i
.
tx
[
0
],
i
.
dx
[
'x'
],
i
.
llx
[
0
][
0
],
i
.
llx
[
1
][
0
],
i
.
ltx
[
0
][
0
],
i
.
ldx
[
0
][
'x'
],
i
.
tlx
[
0
][
0
],
i
.
tlx
[
0
][
0
],
i
.
tdx
[
0
][
'x'
],
i
.
dlx
[
'x'
][
0
],
i
.
dtx
[
'x'
][
0
],
i
.
ddx
[
'x'
][
'x'
]]
m1
=
Module
()
m2
=
Module
()
x
=
T
.
dscalar
()
...
...
@@ -393,7 +394,13 @@ class T_module(unittest.TestCase):
assert
isinstance
(
inst
.
dy
[
'y'
],
theano
.
compile
.
function_module
.
Function
)
assert
isinstance
(
inst
.
tty
[
0
][
0
],
theano
.
compile
.
function_module
.
Function
)
print
>>
sys
.
stderr
,
"MODULE TEST IMPLEMENTED BUT WE DON'T KNOW WHAT WE WANT AS A RESULT"
assert
m1
.
y
is
m1
.
ly
[
0
]
assert
inst
.
y
is
inst
.
ly
[
0
]
assert
inst
.
y
is
inst
.
lly
[
0
][
0
]
assert
inst
.
y
is
inst
.
ty
[
0
]
assert
inst
.
y
is
inst
.
tty
[
0
][
0
]
assert
inst
.
y
is
inst
.
dy
[
'y'
]
def
test_member_method_inputs
(
self
):
"""Test that module Members can be named as Method inputs, in which case the function will
...
...
@@ -416,7 +423,6 @@ class T_module(unittest.TestCase):
assert
m
.
y
==
77
assert
m
.
x
==
1000
def
test_member_input_flags
(
self
):
"""Test that we can manipulate the mutable, strict, etc. flags (see SymbolicInput) of
Method inputs"""
...
...
@@ -448,20 +454,30 @@ class T_module(unittest.TestCase):
m
.
f
([
3
,
2
])
assert
numpy
.
all
(
v0
!=
v0_copy
)
def
test_sanity_check_mode
(
self
):
"""Test that Module.make(self) can take the same list of Modes that function can, so we can
debug modules"""
print
>>
sys
.
stderr
,
"WARNING MODULE TEST NOT IMPLEMENTED"
def
test_member_value
(
self
):
"""Test that module Members of Value work correctly. As Result?"""
print
>>
sys
.
stderr
,
"WARNING MODULE TEST NOT IMPLEMENTED"
M
=
Module
()
x
=
T
.
dscalar
()
M
.
y
=
T
.
value
(
40
)
M
.
f
=
Method
([
x
],
x
+
2
*
M
.
y
)
m
=
M
.
make
()
m
.
y
=
80
assert
m
.
f
(
20
)
==
180
def
test_member_constant
(
self
):
"""Test that module Members of Constant work correctly.
As Result with more optimization?"""
print
>>
sys
.
stderr
,
"WARNING MODULE TEST NOT IMPLEMENTED"
M
=
Module
()
x
=
T
.
dscalar
()
M
.
y
=
T
.
constant
(
40
)
M
.
f
=
Method
([
x
],
x
+
2
*
M
.
y
)
m
=
M
.
make
()
try
:
m
.
y
=
77
#fail?
assert
0
#assign to constant should not have worked
except
:
pass
assert
m
.
f
(
20
)
==
100
def
test_raise_NotImplemented
(
self
):
c
=
Component
()
...
...
@@ -476,18 +492,72 @@ class T_module(unittest.TestCase):
self
.
assertRaises
(
NotImplementedError
,
c
.
get
,
"n"
)
self
.
assertRaises
(
NotImplementedError
,
c
.
set
,
"n"
,
1
)
def
test_tuple_members
(
self
):
def
test_tuple_members
(
):
M
=
Module
()
M
.
a
=
(
1
,
1
)
assert
isinstance
(
M
.
a
,
tuple
)
M
=
Module
()
M
.
a
=
(
1
,
1
)
assert
isinstance
(
M
.
a
,
tuple
)
class
Temp
(
Module
):
def
__init__
(
self
):
self
.
a
=
(
1
,
1
)
M
=
Temp
()
assert
isinstance
(
M
.
a
,
tuple
)
class
Temp
(
Module
):
def
__init__
(
self
):
self
.
a
=
(
1
,
1
)
M
=
Temp
()
assert
isinstance
(
M
.
a
,
tuple
)
def
test_method_updates
():
# updates work
M
=
Module
()
M
.
x
=
T
.
dvector
()
x
=
T
.
dvector
()
xval
=
numpy
.
asarray
([
0
,
0.5
])
M
.
f
=
Method
([
x
],
M
.
x
*
4
,
updates
=
{
M
.
x
:
M
.
x
*
2
},
mode
=
'FAST_COMPILE'
)
m
=
M
.
make
(
mode
=
'FAST_RUN'
)
m
.
x
=
xval
m
.
f
([
9
,
9
])
assert
numpy
.
all
(
m
.
x
==
[
0
,
1
])
assert
numpy
.
all
(
xval
==
[
0
,
0.5
])
# In(update) works
M
=
Module
()
M
.
x
=
T
.
dvector
()
x
=
T
.
dvector
()
M
.
f
=
Method
([
x
,
io
.
In
(
M
.
x
,
value
=
xval
,
update
=
M
.
x
*
2
)],
M
.
x
*
4
)
m
=
M
.
make
()
m
.
f
([
9
,
9
])
assert
m
.
x
is
None
assert
numpy
.
all
(
xval
==
[
0
,
1
])
# when a result is listed explicitly and in an update, then there's a problem.
M
=
Module
()
M
.
x
=
T
.
dvector
()
x
=
T
.
dvector
()
M
.
f
=
Method
([
x
,
io
.
In
(
M
.
x
,
value
=
xval
,
update
=
M
.
x
*
2
)],
M
.
x
*
4
,
updates
=
{
M
.
x
:
M
.
x
*
7
})
try
:
m
=
M
.
make
()
assert
False
except
ValueError
,
e
:
if
str
(
e
[
0
])
.
startswith
(
'Result listed in both inputs and up'
):
pass
else
:
raise
def
test_method_mode
():
"""Test that Methods can override the module build mode"""
M
=
Module
()
M
.
x
=
T
.
dvector
()
M
.
f
=
Method
([
M
.
x
],
M
.
x
*
4
,
mode
=
'FAST_COMPILE'
)
M
.
g
=
Method
([
M
.
x
],
M
.
x
*
4
)
M
.
h
=
Method
([
M
.
x
],
M
.
x
*
4
)
m
=
M
.
make
(
mode
=
'FAST_RUN'
)
assert
m
.
f
.
maker
.
mode
!=
m
.
g
.
maker
.
mode
assert
m
.
h
.
maker
.
mode
==
m
.
g
.
maker
.
mode
assert
numpy
.
all
(
m
.
f
([
1
,
2
])
==
m
.
g
([
1
,
2
]))
def
test_pickle
():
"""Test that a module can be pickled"""
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论