Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
7f862fdc
提交
7f862fdc
authored
2月 06, 2012
作者:
Olivier Delalleau
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
PEP8 fixes
上级
5c576b65
隐藏空白字符变更
内嵌
并排
正在显示
2 个修改的文件
包含
81 行增加
和
84 行删除
+81
-84
io.py
theano/compile/io.py
+5
-5
pfunc.py
theano/compile/pfunc.py
+76
-79
没有找到文件。
theano/compile/io.py
浏览文件 @
7f862fdc
...
...
@@ -5,7 +5,8 @@ from theano import gof
from
sharedvalue
import
SharedVariable
import
logging
_logger
=
logging
.
getLogger
(
"theano.compile.io"
)
_logger
=
logging
.
getLogger
(
"theano.compile.io"
)
class
SymbolicInput
(
object
):
"""
...
...
@@ -49,7 +50,7 @@ class SymbolicInput(object):
def
__init__
(
self
,
variable
,
name
=
None
,
update
=
None
,
mutable
=
None
,
strict
=
False
,
allow_downcast
=
None
,
autoname
=
True
,
implicit
=
False
):
assert
implicit
is
not
None
# Safety check.
assert
implicit
is
not
None
# Safety check.
self
.
variable
=
variable
if
(
autoname
and
name
is
None
):
self
.
name
=
variable
.
name
...
...
@@ -194,8 +195,7 @@ class In(SymbolicInput):
# try to keep it synchronized.
def
__init__
(
self
,
variable
,
name
=
None
,
value
=
None
,
update
=
None
,
mutable
=
None
,
strict
=
False
,
allow_downcast
=
None
,
autoname
=
True
,
implicit
=
None
,
borrow
=
None
,
shared
=
False
):
implicit
=
None
,
borrow
=
None
,
shared
=
False
):
#if shared, an input's value comes from its persistent storage, not from a default stored
#in the function or from the caller
...
...
@@ -206,7 +206,7 @@ class In(SymbolicInput):
# mutable=True should require borrow=True. Raise warning when borrow is explicitely set
# to False with mutable=True.
if
mutable
:
if
borrow
==
False
:
if
borrow
==
False
:
_logger
.
warning
(
"Symbolic input for variable
%
s (name=
%
s) has "
"flags mutable=True, borrow=False. This combination is "
"incompatible since mutable=True implies that the "
...
...
theano/compile/pfunc.py
浏览文件 @
7f862fdc
"""Provide a simple user friendly API """
__docformat__
=
'restructuredtext en'
import
numpy
# for backport to 2.4, to get any().
from
profiling
import
ProfileStats
from
theano.gof
import
Container
,
Variable
,
generic
,
graph
,
Constant
,
Value
from
theano
import
config
from
theano.compile
import
orig_function
,
In
,
Out
from
theano.compile.sharedvalue
import
SharedVariable
,
shared
from
theano
import
config
from
theano.gof
import
Container
,
Variable
,
generic
,
graph
,
Constant
,
Value
from
theano.gof.python25
import
any
import
logging
_logger
=
logging
.
getLogger
(
"theano.compile.pfunc"
)
def
rebuild_collect_shared
(
outputs
,
inputs
=
None
,
replace
=
None
,
updates
=
None
,
rebuild_strict
=
True
,
copy_inputs_over
=
True
,
no_default_updates
=
False
_logger
=
logging
.
getLogger
(
"theano.compile.pfunc"
)
def
rebuild_collect_shared
(
outputs
,
inputs
=
None
,
replace
=
None
,
updates
=
None
,
rebuild_strict
=
True
,
copy_inputs_over
=
True
,
no_default_updates
=
False
,
):
"""
Function that allows replacing subgraphs of a computational
...
...
@@ -60,7 +63,7 @@ def rebuild_collect_shared( outputs
"""
if
isinstance
(
outputs
,
tuple
):
if
isinstance
(
outputs
,
tuple
):
outputs
=
list
(
outputs
)
## This function implements similar functionality as graph.clone
...
...
@@ -71,7 +74,6 @@ def rebuild_collect_shared( outputs
# list of shared inputs that are used as inputs of the graph
shared_inputs
=
[]
def
clone_v_get_shared_updates
(
v
,
copy_inputs_over
):
'''
Clones a variable and its inputs recursively until all are in
...
...
@@ -88,36 +90,34 @@ def rebuild_collect_shared( outputs
return
clone_d
[
v
]
if
v
.
owner
:
clone_a
(
v
.
owner
,
copy_inputs_over
)
return
clone_d
.
setdefault
(
v
,
v
)
return
clone_d
.
setdefault
(
v
,
v
)
elif
isinstance
(
v
,
SharedVariable
):
if
v
not
in
shared_inputs
:
shared_inputs
.
append
(
v
)
if
hasattr
(
v
,
'default_update'
):
# Check that v should not be excluded from the default
# updates list
if
(
no_default_updates
is
False
or
(
isinstance
(
no_default_updates
,
list
)
and
v
not
in
no_default_updates
)
):
if
(
no_default_updates
is
False
or
(
isinstance
(
no_default_updates
,
list
)
and
v
not
in
no_default_updates
)):
# Do not use default_update if a "real" update was
# provided
if
v
not
in
update_d
:
v_update
=
v
.
type
.
filter_variable
(
v
.
default_update
)
if
v_update
.
type
!=
v
.
type
:
raise
TypeError
(
(
'an update must have the same type as '
'the original shared variable'
)
,
(
v
,
v
.
type
,
v_update
,
v_update
.
type
))
'an update must have the same type as '
'the original shared variable'
,
(
v
,
v
.
type
,
v_update
,
v_update
.
type
))
update_d
[
v
]
=
v_update
update_expr
.
append
((
v
,
v_update
))
if
not
copy_inputs_over
or
(
isinstance
(
v
,
Constant
)
and
hasattr
(
v
,
'env'
)):
hasattr
(
v
,
'env'
)):
### Cloning shared variables implies copying their underlying
### memory buffer ?? No.
return
clone_d
.
setdefault
(
v
,
v
.
clone
())
return
clone_d
.
setdefault
(
v
,
v
.
clone
())
else
:
return
clone_d
.
setdefault
(
v
,
v
)
return
clone_d
.
setdefault
(
v
,
v
)
def
clone_a
(
a
,
copy_inputs_over
):
'''
...
...
@@ -132,12 +132,11 @@ def rebuild_collect_shared( outputs
clone_d
[
a
]
=
a
.
clone_with_new_inputs
([
clone_d
[
i
]
for
i
in
a
.
inputs
],
strict
=
rebuild_strict
)
strict
=
rebuild_strict
)
for
old_o
,
new_o
in
zip
(
a
.
outputs
,
clone_d
[
a
]
.
outputs
):
clone_d
.
setdefault
(
old_o
,
new_o
)
clone_d
.
setdefault
(
old_o
,
new_o
)
return
clone_d
[
a
]
# intialize the clone_d mapping with the replace dictionary
if
replace
is
None
:
replace
=
[]
...
...
@@ -147,9 +146,9 @@ def rebuild_collect_shared( outputs
replace_pairs
=
replace
for
v_orig
,
v_repl
in
replace_pairs
:
if
not
isinstance
(
v_orig
,
Variable
):
if
not
isinstance
(
v_orig
,
Variable
):
raise
TypeError
(
'given keys must be Variable'
,
v_orig
)
if
not
isinstance
(
v_repl
,
Variable
):
if
not
isinstance
(
v_repl
,
Variable
):
v_repl
=
shared
(
v_repl
)
assert
v_orig
not
in
clone_d
clone_d
[
v_orig
]
=
clone_v_get_shared_updates
(
v_repl
,
...
...
@@ -160,9 +159,9 @@ def rebuild_collect_shared( outputs
def
clone_inputs
(
i
):
if
not
copy_inputs_over
:
return
clone_d
.
setdefault
(
i
,
i
.
clone
())
return
clone_d
.
setdefault
(
i
,
i
.
clone
())
else
:
return
clone_d
.
setdefault
(
i
,
i
)
return
clone_d
.
setdefault
(
i
,
i
)
input_variables
=
[
clone_inputs
(
i
)
for
i
in
inputs
]
...
...
@@ -171,7 +170,7 @@ def rebuild_collect_shared( outputs
# it is also not clear when/how to use the value of that shared
# variable (is it a default? ignored?, if the shared variable changes,
# does that function default also change?).
if
numpy
.
any
([
isinstance
(
v
,
SharedVariable
)
for
v
in
input_variables
]):
if
any
([
isinstance
(
v
,
SharedVariable
)
for
v
in
input_variables
]):
raise
TypeError
((
'Cannot use a shared variable (
%
s) as explicit '
'input. Consider substituting a non-shared'
' variable via the `givens` parameter'
)
%
v
)
...
...
@@ -181,25 +180,25 @@ def rebuild_collect_shared( outputs
updates
=
[]
for
(
store_into
,
update_val
)
in
iter_over_pairs
(
updates
):
if
not
isinstance
(
store_into
,
SharedVariable
):
raise
TypeError
(
'update target must be a SharedVariable'
,
store_into
)
raise
TypeError
(
'update target must be a SharedVariable'
,
store_into
)
if
store_into
in
update_d
:
raise
ValueError
(
(
'this shared variable already has an update '
'expression'
)
,
(
store_into
,
update_d
[
store_into
]))
raise
ValueError
(
'this shared variable already has an update '
'expression'
,
(
store_into
,
update_d
[
store_into
]))
# filter_variable ensure smooth conversion of cpu/gpu Types
update_val
=
store_into
.
type
.
filter_variable
(
update_val
)
if
update_val
.
type
!=
store_into
.
type
:
err_msg
=
(
'an update must have the same type as the '
'original shared variable(dest, dest.type, '
'update_val, update_val.type)'
)
err_arg
=
(
store_into
,
store_into
.
type
,
update_val
,
update_val
.
type
)
raise
TypeError
(
err_msg
,
err_arg
)
err_msg
=
(
'an update must have the same type as the '
'original shared variable(dest, dest.type, '
'update_val, update_val.type)'
)
err_arg
=
(
store_into
,
store_into
.
type
,
update_val
,
update_val
.
type
)
raise
TypeError
(
err_msg
,
err_arg
)
update_d
[
store_into
]
=
update_val
update_expr
.
append
((
store_into
,
update_val
))
...
...
@@ -215,8 +214,8 @@ def rebuild_collect_shared( outputs
copy_inputs_over
)
cloned_outputs
.
append
(
Out
(
cloned_v
,
borrow
=
v
.
borrow
))
else
:
raise
TypeError
(
(
'outputs must be theano Variable or '
'Out instances'
)
,
v
)
raise
TypeError
(
'outputs must be theano Variable or '
'Out instances'
,
v
)
#computed_list.append(cloned_v)
else
:
if
isinstance
(
outputs
,
Variable
):
...
...
@@ -229,12 +228,11 @@ def rebuild_collect_shared( outputs
cloned_outputs
=
Out
(
cloned_v
,
borrow
=
outputs
.
borrow
)
#computed_list.append(cloned_v)
elif
outputs
is
None
:
cloned_outputs
=
[]
# TODO: get Function.__call__ to return None
cloned_outputs
=
[]
# TODO: get Function.__call__ to return None
else
:
raise
TypeError
(
(
'output must be a theano Variable or Out '
'instance (or list of them)'
)
,
outputs
)
raise
TypeError
(
'output must be a theano Variable or Out '
'instance (or list of them)'
,
outputs
)
# Iterate over update_expr, cloning its elements, and updating
# shared_inputs, update_d and update_expr from the SharedVariables
...
...
@@ -244,7 +242,7 @@ def rebuild_collect_shared( outputs
# Note: we extend update_expr while iterating over it.
i
=
0
while
i
<
len
(
update_expr
):
while
i
<
len
(
update_expr
):
v
,
v_update
=
update_expr
[
i
]
cloned_v_update
=
clone_v_get_shared_updates
(
v_update
,
copy_inputs_over
)
...
...
@@ -253,12 +251,13 @@ def rebuild_collect_shared( outputs
shared_inputs
.
append
(
v
)
i
+=
1
return
(
input_variables
,
cloned_outputs
,
[
clone_d
,
update_d
,
update_expr
,
shared_inputs
]
)
return
(
input_variables
,
cloned_outputs
,
[
clone_d
,
update_d
,
update_expr
,
shared_inputs
])
class
Param
(
object
):
def
__init__
(
self
,
variable
,
default
=
None
,
name
=
None
,
mutable
=
False
,
strict
=
False
,
allow_downcast
=
None
,
implicit
=
None
,
borrow
=
None
):
strict
=
False
,
allow_downcast
=
None
,
implicit
=
None
,
borrow
=
None
):
"""
:param variable: A variable in an expression graph to use as a compiled-function parameter
...
...
@@ -295,7 +294,7 @@ class Param(object):
# mutable=True should require borrow=True. Raise warning when borrow is explicitely set
# to False with mutable=True.
if
mutable
:
if
borrow
==
False
:
if
not
borrow
:
_logger
.
warning
(
"Symbolic input for variable
%
s (name=
%
s) has "
"flags mutable=True, borrow=False. This combination is "
"incompatible since mutable=True implies that the "
...
...
@@ -308,6 +307,7 @@ class Param(object):
self
.
implicit
=
implicit
self
.
borrow
=
borrow
def
pfunc
(
params
,
outputs
=
None
,
mode
=
None
,
updates
=
[],
givens
=
[],
no_default_updates
=
False
,
accept_inplace
=
False
,
name
=
None
,
rebuild_strict
=
True
,
allow_input_downcast
=
None
,
...
...
@@ -398,27 +398,25 @@ def pfunc(params, outputs=None, mode=None, updates=[], givens=[],
# No need to block other objects being passed through though. It might be
# useful.
if
not
isinstance
(
params
,
(
list
,
tuple
)):
if
not
isinstance
(
params
,
(
list
,
tuple
)):
raise
Exception
(
"in pfunc() the first argument must be a list or a tuple"
)
if
not
isinstance
(
no_default_updates
,
bool
)
\
and
not
isinstance
(
no_default_updates
,
list
):
raise
TypeError
(
"no_default_update should be either a boolean or a list"
)
# transform params into theano.compile.In objects.
inputs
=
[
_pfunc_param_to_in
(
p
,
allow_downcast
=
allow_input_downcast
)
for
p
in
params
]
in_variables
=
[
input
.
variable
for
input
in
inputs
]
output_vars
=
rebuild_collect_shared
(
outputs
,
in_variables
,
replace
=
givens
,
updates
=
updates
,
rebuild_strict
=
True
,
copy_inputs_over
=
True
,
no_default_updates
=
no_default_updates
)
in_variables
=
[
input
.
variable
for
input
in
inputs
]
output_vars
=
rebuild_collect_shared
(
outputs
,
in_variables
,
replace
=
givens
,
updates
=
updates
,
rebuild_strict
=
True
,
copy_inputs_over
=
True
,
no_default_updates
=
no_default_updates
)
# extracting the arguments
input_variables
,
cloned_outputs
,
other_stuff
=
output_vars
clone_d
,
update_d
,
update_expr
,
shared_inputs
=
other_stuff
...
...
@@ -431,14 +429,13 @@ def pfunc(params, outputs=None, mode=None, updates=[], givens=[],
#value will be stored in the resulting functions' defaults list
#but since the value of shared variables never needs to be refed, it is not needed
if
sv
in
update_d
:
si
=
In
(
variable
=
sv
,
value
=
sv
.
container
,
mutable
=
True
,
borrow
=
True
,
update
=
update_d
[
sv
],
shared
=
True
)
si
=
In
(
variable
=
sv
,
value
=
sv
.
container
,
mutable
=
True
,
borrow
=
True
,
update
=
update_d
[
sv
],
shared
=
True
)
else
:
si
=
In
(
variable
=
sv
,
value
=
sv
.
container
,
mutable
=
False
,
borrow
=
True
,
shared
=
True
)
si
=
In
(
variable
=
sv
,
value
=
sv
.
container
,
mutable
=
False
,
borrow
=
True
,
shared
=
True
)
inputs
.
append
(
si
)
return
orig_function
(
inputs
,
cloned_outputs
,
mode
,
accept_inplace
=
accept_inplace
,
name
=
name
,
profile
=
profile
)
...
...
@@ -449,7 +446,7 @@ def _pfunc_param_to_in(param, strict=False, allow_downcast=None):
#if isinstance(param, Value):
#return In(variable=param)
#raise NotImplementedError()
if
isinstance
(
param
,
Variable
):
#
N.B. includes Value and SharedVariable
if
isinstance
(
param
,
Variable
):
#
N.B. includes Value and SharedVariable
return
In
(
variable
=
param
,
strict
=
strict
,
allow_downcast
=
allow_downcast
)
elif
isinstance
(
param
,
Param
):
return
In
(
...
...
@@ -458,9 +455,9 @@ def _pfunc_param_to_in(param, strict=False, allow_downcast=None):
value
=
param
.
default
,
mutable
=
param
.
mutable
,
strict
=
param
.
strict
,
borrow
=
param
.
borrow
,
borrow
=
param
.
borrow
,
allow_downcast
=
param
.
allow_downcast
,
implicit
=
param
.
implicit
)
implicit
=
param
.
implicit
)
raise
TypeError
(
'Unknown parameter type:
%
s'
%
type
(
param
))
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论