Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
cfecf720
提交
cfecf720
authored
1月 21, 2010
作者:
Pascal Lamblin
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Prepare for default updates of shared variable. Keyword "no_default_updates" in…
Prepare for default updates of shared variable. Keyword "no_default_updates" in theano.function added.
上级
1997674d
隐藏空白字符变更
内嵌
并排
正在显示
3 个修改的文件
包含
188 行增加
和
75 行删除
+188
-75
function.py
theano/compile/function.py
+9
-2
pfunc.py
theano/compile/pfunc.py
+174
-73
sharedvalue.py
theano/compile/sharedvalue.py
+5
-0
没有找到文件。
theano/compile/function.py
浏览文件 @
cfecf720
...
@@ -10,7 +10,8 @@ from function_module import orig_function
...
@@ -10,7 +10,8 @@ from function_module import orig_function
from
pfunc
import
pfunc
from
pfunc
import
pfunc
from
numpy
import
any
#for to work in python 2.4
from
numpy
import
any
#for to work in python 2.4
def
function
(
inputs
,
outputs
=
None
,
mode
=
None
,
updates
=
[],
givens
=
[],
accept_inplace
=
False
,
name
=
None
):
def
function
(
inputs
,
outputs
=
None
,
mode
=
None
,
updates
=
[],
givens
=
[],
no_default_updates
=
False
,
accept_inplace
=
False
,
name
=
None
):
"""
"""
Return a callable object that will calculate `outputs` from `inputs`.
Return a callable object that will calculate `outputs` from `inputs`.
...
@@ -31,7 +32,12 @@ def function(inputs, outputs=None, mode=None, updates=[], givens=[], accept_inpl
...
@@ -31,7 +32,12 @@ def function(inputs, outputs=None, mode=None, updates=[], givens=[], accept_inpl
and Var2 in each pair must have the same Type.
and Var2 in each pair must have the same Type.
:param givens: specific substitutions to make in the computation graph (Var2 replaces
:param givens: specific substitutions to make in the computation graph (Var2 replaces
Var1).
Var1).
:type no_default_updates: either bool or list of Variables
:param no_default_updates: if True, do not perform any automatic update on Variables.
If False (default), perform them all. Else, perform automatic updates on all Variables
that are neither in "updates" nor in "no_default_updates".
:param name: an optional name for this function. The profile mode will print the time spent in this function.
:param name: an optional name for this function. The profile mode will print the time spent in this function.
...
@@ -65,4 +71,5 @@ def function(inputs, outputs=None, mode=None, updates=[], givens=[], accept_inpl
...
@@ -65,4 +71,5 @@ def function(inputs, outputs=None, mode=None, updates=[], givens=[], accept_inpl
mode
=
mode
,
mode
=
mode
,
updates
=
updates
,
updates
=
updates
,
givens
=
givens
,
givens
=
givens
,
no_default_updates
=
no_default_updates
,
accept_inplace
=
accept_inplace
,
name
=
name
)
accept_inplace
=
accept_inplace
,
name
=
name
)
theano/compile/pfunc.py
浏览文件 @
cfecf720
...
@@ -33,7 +33,8 @@ class Param(object):
...
@@ -33,7 +33,8 @@ class Param(object):
self
.
strict
=
strict
self
.
strict
=
strict
self
.
implicit
=
implicit
self
.
implicit
=
implicit
def
pfunc
(
params
,
outputs
=
None
,
mode
=
None
,
updates
=
[],
givens
=
[],
accept_inplace
=
False
,
name
=
None
):
def
pfunc
(
params
,
outputs
=
None
,
mode
=
None
,
updates
=
[],
givens
=
[],
no_default_updates
=
False
,
accept_inplace
=
False
,
name
=
None
):
"""Function-constructor for graphs with shared variables.
"""Function-constructor for graphs with shared variables.
:type params: list of either Variable or Param instances.
:type params: list of either Variable or Param instances.
...
@@ -53,7 +54,12 @@ def pfunc(params, outputs=None, mode=None, updates=[], givens=[], accept_inplace
...
@@ -53,7 +54,12 @@ def pfunc(params, outputs=None, mode=None, updates=[], givens=[], accept_inplace
and Var2 in each pair must have the same Type.
and Var2 in each pair must have the same Type.
:param givens: specific substitutions to make in the computation graph (Var2 replaces
:param givens: specific substitutions to make in the computation graph (Var2 replaces
Var1).
Var1).
:type no_default_updates: either bool or list of Variables
:param no_default_updates: if True, do not perform any automatic update on Variables.
If False (default), perform them all. Else, perform automatic updates on all Variables
that are neither in "updates" nor in "no_default_updates".
:param name: an optional name for this fct. If used, the profile mode will print the time spent in this fct.
:param name: an optional name for this fct. If used, the profile mode will print the time spent in this fct.
...
@@ -86,11 +92,61 @@ def pfunc(params, outputs=None, mode=None, updates=[], givens=[], accept_inplace
...
@@ -86,11 +92,61 @@ def pfunc(params, outputs=None, mode=None, updates=[], givens=[], accept_inplace
if
not
isinstance
(
params
,(
list
,
tuple
)):
if
not
isinstance
(
params
,(
list
,
tuple
)):
raise
Exception
(
"in pfunc() the first argument must be a list or a tuple"
)
raise
Exception
(
"in pfunc() the first argument must be a list or a tuple"
)
# initialize the clone_d mapping with the `givens` argument
clone_d
=
{}
clone_d
=
{}
def
v_clone
(
v
):
# Updates as list and dictionary.
return
_v_clone
(
v
,
clone_d
)
# They will also store the 'default_update' expressions applicable.
# The dictionary is used to look up the existence of the keys, and to store
# the final (cloned) update expressions.
# The list of pairs is used to iterate in a consistent order while adding
# new pairs.
update_d
=
{}
update_expr
=
[]
# list of shared inputs that are used as inputs of the graph
shared_inputs
=
[]
def
clone_v_get_shared_updates
(
v
):
'''Clone a variable and its inputs, until all are in clone_d.
Also appends all shared variables met along the way to shared_inputs,
and their default_update (if applicable) to update_d and update_expr.
'''
assert
v
is
not
None
if
v
.
owner
:
clone_a
(
v
.
owner
)
elif
isinstance
(
v
,
SharedVariable
):
if
v
not
in
shared_inputs
:
shared_inputs
.
append
(
v
)
if
hasattr
(
v
,
'default_update'
):
# Check that v should not be excluded from the default updates list
if
no_default_updates
is
False
or
\
(
isinstance
(
no_default_updates
,
list
)
and
\
v
not
in
no_default_updates
):
# Do not use default_update if a "real" update was provided
if
v
not
in
update_d
:
v_update
=
v
.
filter_update
(
v
.
default_update
)
if
v_update
.
type
!=
v
.
type
:
raise
TypeError
(
'an update must have the same type as the original shared variable'
,
(
v
,
v
.
type
,
v_update
,
v_update
.
type
))
update_d
[
v
]
=
v_update
update_expr
.
append
((
v
,
v_update
))
return
clone_d
.
setdefault
(
v
,
v
)
def
clone_a
(
a
):
if
a
is
None
:
return
None
if
a
not
in
clone_d
:
for
i
in
a
.
inputs
:
clone_v_get_shared_updates
(
i
)
clone_d
[
a
]
=
a
.
clone_with_new_inputs
([
clone_d
[
i
]
for
i
in
a
.
inputs
])
for
old_o
,
new_o
in
zip
(
a
.
outputs
,
clone_d
[
a
]
.
outputs
):
clone_d
.
setdefault
(
old_o
,
new_o
)
return
clone_d
[
a
]
#def v_clone(v):
# return _v_clone(v, clone_d)
# initialize the clone_d mapping with the `givens` argument
try
:
try
:
givens
=
givens
.
items
()
# converts a dictionary to the sort of list that we want.
givens
=
givens
.
items
()
# converts a dictionary to the sort of list that we want.
except
:
except
:
...
@@ -101,11 +157,9 @@ def pfunc(params, outputs=None, mode=None, updates=[], givens=[], accept_inplace
...
@@ -101,11 +157,9 @@ def pfunc(params, outputs=None, mode=None, updates=[], givens=[], accept_inplace
if
not
isinstance
(
v_repl
,
Variable
):
if
not
isinstance
(
v_repl
,
Variable
):
v_repl
=
shared
(
v_repl
)
v_repl
=
shared
(
v_repl
)
assert
v_orig
not
in
clone_d
assert
v_orig
not
in
clone_d
clone_d
[
v_orig
]
=
v_clone
(
v_repl
)
clone_d
[
v_orig
]
=
clone_v_get_shared_updates
(
v_repl
)
# transform params into theano.compile.In objects.
# transform params into theano.compile.In objects.
#
# call theano.function
inputs
=
[
_pfunc_param_to_in
(
p
)
for
p
in
params
]
inputs
=
[
_pfunc_param_to_in
(
p
)
for
p
in
params
]
#Switch inputs to cloned variables
#Switch inputs to cloned variables
...
@@ -113,101 +167,148 @@ def pfunc(params, outputs=None, mode=None, updates=[], givens=[], accept_inplace
...
@@ -113,101 +167,148 @@ def pfunc(params, outputs=None, mode=None, updates=[], givens=[], accept_inplace
for
i
,
iv
in
zip
(
inputs
,
input_variables
):
for
i
,
iv
in
zip
(
inputs
,
input_variables
):
i
.
variable
=
iv
i
.
variable
=
iv
set_of_param_variables
=
set
(
input_variables
)
#
set_of_param_variables = set(input_variables)
# It was decided, as a first step, to prevent shared variables from being
# It was decided, as a first step, to prevent shared variables from being
# used as function inputs. Although it is technically possible, it is also
# used as function inputs. Although it is technically possible, it is also
# potentially ambiguous and dangerous. This restriction may be revisited in
# potentially ambiguous and dangerous. This restriction may be revisited in
# the future if there is a need for such a feature.
# the future if there is a need for such a feature.
if
numpy
.
any
([
isinstance
(
v
,
SharedVariable
)
for
v
in
set_of_param
_variables
]):
if
numpy
.
any
([
isinstance
(
v
,
SharedVariable
)
for
v
in
input
_variables
]):
raise
TypeError
(
'Cannot use a shared variable (
%
s) as explicit input '
raise
TypeError
(
'Cannot use a shared variable (
%
s) as explicit input '
%
v
)
%
v
)
# Fill update_d and update_expr with provided updates
for
(
store_into
,
update_val
)
in
iter_over_pairs
(
updates
):
if
not
isinstance
(
store_into
,
SharedVariable
):
raise
TypeError
(
'update target must be a SharedVariable'
,
store_into
)
if
store_into
in
update_d
:
raise
ValueError
(
'this shared variable already has an update expression'
,
(
store_into
,
update_d
[
store_into
]))
update_val
=
store_into
.
filter_update
(
update_val
)
if
update_val
.
type
!=
store_into
.
type
:
raise
TypeError
(
'an update must have the same type as the original shared variable'
,
(
store_into
,
store_into
.
type
,
update_val
,
update_val
.
type
))
update_d
[
store_into
]
=
update_val
update_expr
.
append
((
store_into
,
update_val
))
# computed_list is a list of output variables (which will be extended later)
# computed_list is a list of output variables (which will be extended later)
computed_list
=
[]
#computed_list = []
# Elements of "outputs" are here cloned to "cloned_outputs"
if
isinstance
(
outputs
,
list
):
if
isinstance
(
outputs
,
list
):
cloned_outputs
=
[]
cloned_outputs
=
[]
for
v
in
outputs
:
for
v
in
outputs
:
if
isinstance
(
v
,
Variable
):
if
isinstance
(
v
,
Variable
):
cloned_v
=
v_clone
(
v
)
cloned_v
=
clone_v_get_shared_updates
(
v
)
cloned_outputs
.
append
(
cloned_v
)
cloned_outputs
.
append
(
cloned_v
)
elif
isinstance
(
v
,
Out
):
elif
isinstance
(
v
,
Out
):
cloned_v
=
v_clone
(
v
.
variable
)
cloned_v
=
clone_v_get_shared_updates
(
v
.
variable
)
cloned_outputs
.
append
(
Out
(
cloned_v
,
borrow
=
v
.
borrow
))
cloned_outputs
.
append
(
Out
(
cloned_v
,
borrow
=
v
.
borrow
))
else
:
else
:
raise
TypeError
(
'outputs must be theano Variable or Out instances'
,
v
)
raise
TypeError
(
'outputs must be theano Variable or Out instances'
,
v
)
computed_list
.
append
(
cloned_v
)
#
computed_list.append(cloned_v)
else
:
else
:
if
isinstance
(
outputs
,
Variable
):
if
isinstance
(
outputs
,
Variable
):
cloned_v
=
v_clone
(
outputs
)
cloned_v
=
clone_v_get_shared_updates
(
outputs
)
cloned_outputs
=
cloned_v
cloned_outputs
=
cloned_v
computed_list
.
append
(
cloned_v
)
#
computed_list.append(cloned_v)
elif
isinstance
(
outputs
,
Out
):
elif
isinstance
(
outputs
,
Out
):
cloned_v
=
v_clone
(
outputs
.
variable
)
cloned_v
=
clone_v_get_shared_updates
(
outputs
.
variable
)
cloned_outputs
=
Out
(
cloned_v
,
borrow
=
outputs
.
borrow
)
cloned_outputs
=
Out
(
cloned_v
,
borrow
=
outputs
.
borrow
)
computed_list
.
append
(
cloned_v
)
#
computed_list.append(cloned_v)
elif
outputs
is
None
:
elif
outputs
is
None
:
cloned_outputs
=
[]
# TODO: return None
cloned_outputs
=
[]
# TODO: return None
else
:
else
:
raise
TypeError
(
'output must be a theano Variable or Out instance (or list of them)'
,
outputs
)
raise
TypeError
(
'output must be a theano Variable or Out instance (or list of them)'
,
outputs
)
# Add update values as quantities that must be computed.
# Iterate over update_expr, cloning its elements, and updating
# Here, we
# shared_inputs, update_d and update_expr from the SharedVariables
# - extend the computed_list
# we discover.
# - replace some update expressions (but update keys remain)
# If the variable to be updated is a shared variable not already
new_updates
=
{}
# in shared_inputs, add it.
for
(
store_into
,
update_val
)
in
iter_over_pairs
(
updates
):
# Note: we extend update_expr while iterating over it.
if
not
isinstance
(
store_into
,
SharedVariable
):
i
=
0
raise
TypeError
(
'update target must be a SharedVariable'
,
store_into
)
while
i
<
len
(
update_expr
):
if
store_into
in
new_updates
:
v
,
v_update
=
update_expr
[
i
]
raise
ValueError
(
'this shared variable already has an update expression'
,
cloned_v_update
=
clone_v_get_shared_updates
(
v_update
)
(
store_into
,
new_updates
[
store_into
]))
update_d
[
v
]
=
cloned_v_update
update_val
=
v_clone
(
store_into
.
filter_update
(
update_val
))
if
isinstance
(
v
,
SharedVariable
)
and
v
not
in
shared_inputs
:
if
update_val
.
type
!=
store_into
.
type
:
shared_inputs
.
append
(
v
)
raise
TypeError
(
'an update must have the same type as the original shared variable'
,
i
+=
1
(
store_into
,
store_into
.
type
,
update_val
,
update_val
.
type
))
#updates = update_d #?
computed_list
.
append
(
update_val
)
for
sv
in
shared_inputs
:
new_updates
[
store_into
]
=
update_val
if
sv
in
update_d
:
updates
=
new_updates
si
=
In
(
variable
=
sv
,
value
=
sv
.
container
,
mutable
=
True
,
update
=
update_d
[
sv
])
# Obtain all inputs we need to compute what we want.
graph_inputs
=
graph
.
inputs
(
computed_list
,
blockers
=
set_of_param_variables
)
shared_inputs
=
[
i
for
i
in
graph_inputs
if
isinstance
(
i
,
SharedVariable
)]
# Add shared variables (from shared_inputs) that were not already present in the list of
# params.
inputs
+=
[
In
(
variable
=
si
,
value
=
si
.
container
,
mutable
=
False
)
for
si
in
shared_inputs
if
si
not
in
set_of_param_variables
]
del
shared_inputs
# Iterate over the updates, which are either pairs
# (shared_var, expressionvariable), or a similar dictionary.
# For each shared_variable, find the In instance that we created for it in the inputs list.
# Give that In instance (in_sv) an update expression.
#
# I think we usually want to set these Inputs to be mutable,
# ... are there exceptions?
for
(
sv
,
new_val
)
in
iter_over_pairs
(
updates
):
in_sv
=
None
for
in_sv_i
in
inputs
:
if
in_sv_i
.
variable
is
sv
:
assert
in_sv
is
None
in_sv
=
in_sv_i
if
in_sv
is
None
:
# This variable was not used anywhere and thus is not in the input
# list yet.
inputs
.
append
(
In
(
variable
=
sv
,
value
=
sv
.
container
,
mutable
=
True
,
update
=
new_val
))
else
:
else
:
in_sv
.
update
=
new_val
si
=
In
(
variable
=
sv
,
value
=
sv
.
container
,
mutable
=
False
)
in_sv
.
mutable
=
True
inputs
.
append
(
si
)
return
orig_function
(
inputs
,
cloned_outputs
,
mode
,
accept_inplace
=
accept_inplace
,
name
=
name
)
if
0
:
# Add update values as quantities that must be computed.
# Here, we
# - extend the computed_list
# - replace some update expressions (but update keys remain)
new_updates
=
{}
for
(
store_into
,
update_val
)
in
iter_over_pairs
(
updates
):
if
not
isinstance
(
store_into
,
SharedVariable
):
raise
TypeError
(
'update target must be a SharedVariable'
,
store_into
)
if
store_into
in
new_updates
:
raise
ValueError
(
'this shared variable already has an update expression'
,
(
store_into
,
new_updates
[
store_into
]))
update_val
=
v_clone
(
store_into
.
filter_update
(
update_val
))
if
update_val
.
type
!=
store_into
.
type
:
raise
TypeError
(
'an update must have the same type as the original shared variable'
,
(
store_into
,
store_into
.
type
,
update_val
,
update_val
.
type
))
computed_list
.
append
(
update_val
)
new_updates
[
store_into
]
=
update_val
updates
=
new_updates
# Obtain all inputs we need to compute what we want.
graph_inputs
=
graph
.
inputs
(
computed_list
,
blockers
=
set_of_param_variables
)
shared_inputs
=
[
i
for
i
in
graph_inputs
if
isinstance
(
i
,
SharedVariable
)]
# Add shared variables (from shared_inputs) that were not already present in the list of
# params.
inputs
+=
[
In
(
variable
=
si
,
value
=
si
.
container
,
mutable
=
False
)
for
si
in
shared_inputs
if
si
not
in
set_of_param_variables
]
del
shared_inputs
# Iterate over the updates, which are either pairs
# (shared_var, expressionvariable), or a similar dictionary.
# For each shared_variable, find the In instance that we created for it in the inputs list.
# Give that In instance (in_sv) an update expression.
#
# I think we usually want to set these Inputs to be mutable,
# ... are there exceptions?
for
(
sv
,
new_val
)
in
iter_over_pairs
(
updates
):
in_sv
=
None
for
in_sv_i
in
inputs
:
if
in_sv_i
.
variable
is
sv
:
assert
in_sv
is
None
in_sv
=
in_sv_i
if
in_sv
is
None
:
# This variable was not used anywhere and thus is not in the input
# list yet.
inputs
.
append
(
In
(
variable
=
sv
,
value
=
sv
.
container
,
mutable
=
True
,
update
=
new_val
))
else
:
in_sv
.
update
=
new_val
in_sv
.
mutable
=
True
return
orig_function
(
inputs
,
cloned_outputs
,
mode
,
accept_inplace
=
accept_inplace
,
name
=
name
)
return
orig_function
(
inputs
,
cloned_outputs
,
mode
,
accept_inplace
=
accept_inplace
,
name
=
name
)
def
_pfunc_param_to_in
(
param
):
def
_pfunc_param_to_in
(
param
):
if
isinstance
(
param
,
Constant
):
if
isinstance
(
param
,
Constant
):
...
...
theano/compile/sharedvalue.py
浏览文件 @
cfecf720
...
@@ -28,6 +28,11 @@ class SharedVariable(Variable):
...
@@ -28,6 +28,11 @@ class SharedVariable(Variable):
:type: `Container`
:type: `Container`
"""
"""
# default_update
# If this member is present, its value will be used as the "update" for
# this Variable, unless another update value has been passed to "function",
# or the "no_default_updates" list passed to "function" contains it.
def
__init__
(
self
,
name
,
type
,
value
,
strict
,
container
=
None
):
def
__init__
(
self
,
name
,
type
,
value
,
strict
,
container
=
None
):
"""
"""
:param name: The name for this variable (see `Variable`).
:param name: The name for this variable (see `Variable`).
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论