Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
bc1ffeea
提交
bc1ffeea
authored
10月 05, 2009
作者:
James Bergstra
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
added givens parameter to compile.sandbox.pfunc
上级
88f7f401
显示空白字符变更
内嵌
并排
正在显示
2 个修改的文件
包含
107 行增加
和
20 行删除
+107
-20
pfunc.py
theano/compile/sandbox/pfunc.py
+91
-20
test_pfunc.py
theano/compile/sandbox/tests/test_pfunc.py
+16
-0
没有找到文件。
theano/compile/sandbox/pfunc.py
浏览文件 @
bc1ffeea
...
@@ -32,7 +32,7 @@ class Param(object):
...
@@ -32,7 +32,7 @@ class Param(object):
self
.
strict
=
strict
self
.
strict
=
strict
self
.
implicit
=
implicit
self
.
implicit
=
implicit
def
pfunc
(
params
,
outputs
=
None
,
mode
=
None
,
updates
=
[]):
def
pfunc
(
params
,
outputs
=
None
,
mode
=
None
,
updates
=
[]
,
givens
=
[]
):
"""Function-constructor for graphs with shared variables.
"""Function-constructor for graphs with shared variables.
:type params: list of either Variable or Param instances.
:type params: list of either Variable or Param instances.
...
@@ -42,38 +42,72 @@ def pfunc(params, outputs=None, mode=None, updates=[]):
...
@@ -42,38 +42,72 @@ def pfunc(params, outputs=None, mode=None, updates=[]):
:type outputs: list of Variables or Out instances
:type outputs: list of Variables or Out instances
:param outputs: expressions to compute
:param outputs: expressions to compute
:type mode: string or `theano.compile.Mode` instance.
:param mode: compilation mode
:param mode: compilation mode
:type updates: iterable over pairs (shared_variable, new_expression). List, tuple or dict.
:type updates: iterable over pairs (shared_variable, new_expression). List, tuple or dict.
:param updates: update the values for SharedVariable inputs according to these expressions
:param updates: update the values for SharedVariable inputs according to these expressions
:type givens: iterable over pairs (Var1, Var2) of Variables. List, tuple or dict. The Var1
and Var2 in each pair must have the same Type.
:param givens: specific substitutions to make in the computation graph (Var2 replaces
Var1).
:rtype: theano.compile.Function
:rtype: theano.compile.Function
:returns: a callable object that will compute the outputs (given the inputs)
:returns: a callable object that will compute the outputs (given the inputs)
and update the implicit function arguments according to the `updates`.
and update the implicit function arguments according to the `updates`.
:note: Regarding givens: Be careful to make sure that these substitutions are
independent--behaviour when Var1 of one pair appears in the graph leading to Var2 in
another expression is undefined. Replacements specified with givens are different from
optimizations in that Var2 is not expected to be equivalent to Var1.
"""
"""
# Note: in its early design, pfunc was also meant to accept another
#
# parameter, 'givens'. This was a dictionary assigning some specific
# This function works by cloning the graph (except for the inputs), and then shipping it
# values to some of the Variable in the graph, so as to allow the
# off to compile.function
# function to possibly make some optimizations at compile time.
# (There it will be cloned again, unnecessarily, because it doesn't know that we already
# In the end, this feature was not kept, because it was not obvious
# cloned it.)
# how to implement it, nor whether it was really needed.
#
# If one wants to add this feature in the future, it may be easier instead
# First, it clones the replacements named in the givens argument, and points each Var1 to
# to add a new parameter to 'Param' to indicate that some input of the
# the clone of Var2.
# function is taking a specific constant value.
# Then it sets the inputs in the clone dictionary.
# After these steps, we are assuming that the clone dictionary contains all the inputs to
if
not
isinstance
(
outputs
,
list
):
# the computation graph.
computed_list
=
[
outputs
]
#
else
:
# Then it clones the outputs and the update expressions. This rebuilds a computation graph
# Copy list (because it may be extended later).
# from the inputs and the givens.
computed_list
=
[
out
for
out
in
outputs
]
#
# initialize the clone_d mapping with the `givens` argument
clone_d
=
{}
def
v_clone
(
v
):
return
_v_clone
(
v
,
clone_d
)
try
:
givens
=
givens
.
items
()
# converts a dictionary to the sort of list that we want.
except
:
pass
for
v_orig
,
v_repl
in
givens
:
if
not
isinstance
(
v_orig
,
Variable
):
raise
TypeError
(
'given keys must be Variable'
,
v_orig
)
if
not
isinstance
(
v_repl
,
Variable
):
v_repl
=
shared
(
v_repl
)
assert
v_orig
not
in
clone_d
clone_d
[
v_orig
]
=
v_clone
(
v_repl
)
# transform params into theano.compile.In objects.
# transform params into theano.compile.In objects.
#
#
# call theano.function
# call theano.function
inputs
=
[
_pfunc_param_to_in
(
p
)
for
p
in
params
]
inputs
=
[
_pfunc_param_to_in
(
p
)
for
p
in
params
]
set_of_param_variables
=
set
([
i
.
variable
for
i
in
inputs
])
#Switch inputs to cloned variables
input_variables
=
[
clone_d
.
setdefault
(
i
.
variable
,
i
.
variable
)
for
i
in
inputs
]
for
i
,
iv
in
zip
(
inputs
,
input_variables
):
i
.
variable
=
iv
set_of_param_variables
=
set
(
input_variables
)
# It was decided, as a first step, to prevent shared variables from being
# It was decided, as a first step, to prevent shared variables from being
# used as function inputs. Although it is technically possible, it is also
# used as function inputs. Although it is technically possible, it is also
...
@@ -83,11 +117,28 @@ def pfunc(params, outputs=None, mode=None, updates=[]):
...
@@ -83,11 +117,28 @@ def pfunc(params, outputs=None, mode=None, updates=[]):
raise
TypeError
(
'Cannot use a shared variable (
%
s) as explicit input '
raise
TypeError
(
'Cannot use a shared variable (
%
s) as explicit input '
%
v
)
%
v
)
# computed_list is a list of output variables
if
isinstance
(
outputs
,
list
):
for
v
in
outputs
:
if
not
isinstance
(
v
,
Variable
):
raise
TypeError
(
'outputs must be theano Variable instances'
,
v
)
# Copy list (because it may be extended later).
computed_list
=
[
v_clone
(
o
)
for
o
in
outputs
]
cloned_outputs
=
list
(
computed_list
)
else
:
if
not
isinstance
(
outputs
,
Variable
):
raise
TypeError
(
'output must be a theano Variable instance'
,
outputs
)
cloned_outputs
=
v_clone
(
outputs
)
computed_list
=
[
cloned_outputs
]
# Add update values as quantities that must be computed.
# Add update values as quantities that must be computed.
# Here, we
# - extend the computed_list
# - replace some update expressions (but update keys remain)
new_updates
=
{}
new_updates
=
{}
for
(
store_into
,
update_val
)
in
iter_over_pairs
(
updates
):
for
(
store_into
,
update_val
)
in
iter_over_pairs
(
updates
):
assert
isinstance
(
store_into
,
SharedVariable
)
assert
isinstance
(
store_into
,
SharedVariable
)
update_val
=
store_into
.
filter_update
(
update_val
)
update_val
=
v_clone
(
store_into
.
filter_update
(
update_val
)
)
if
update_val
.
type
!=
store_into
.
type
:
if
update_val
.
type
!=
store_into
.
type
:
raise
TypeError
(
'an update must have the same type as the original shared variable'
,
raise
TypeError
(
'an update must have the same type as the original shared variable'
,
(
store_into
,
store_into
.
type
,
(
store_into
,
store_into
.
type
,
...
@@ -98,7 +149,7 @@ def pfunc(params, outputs=None, mode=None, updates=[]):
...
@@ -98,7 +149,7 @@ def pfunc(params, outputs=None, mode=None, updates=[]):
# Obtain all inputs we need to compute what we want.
# Obtain all inputs we need to compute what we want.
graph_inputs
=
graph
.
inputs
(
computed_list
,
graph_inputs
=
graph
.
inputs
(
computed_list
,
blockers
=
set
([
i
.
variable
for
i
in
inputs
])
)
blockers
=
set
_of_param_variables
)
shared_inputs
=
[
i
for
i
in
graph_inputs
if
isinstance
(
i
,
SharedVariable
)]
shared_inputs
=
[
i
for
i
in
graph_inputs
if
isinstance
(
i
,
SharedVariable
)]
...
@@ -131,7 +182,7 @@ def pfunc(params, outputs=None, mode=None, updates=[]):
...
@@ -131,7 +182,7 @@ def pfunc(params, outputs=None, mode=None, updates=[]):
in_sv
.
update
=
new_val
in_sv
.
update
=
new_val
in_sv
.
mutable
=
True
in_sv
.
mutable
=
True
return
function
(
inputs
,
outputs
,
mode
,
accept_inplace
=
False
)
return
function
(
inputs
,
cloned_
outputs
,
mode
,
accept_inplace
=
False
)
def
_pfunc_param_to_in
(
param
):
def
_pfunc_param_to_in
(
param
):
if
isinstance
(
param
,
Constant
):
if
isinstance
(
param
,
Constant
):
...
@@ -168,3 +219,23 @@ def iter_over_pairs(pairs):
...
@@ -168,3 +219,23 @@ def iter_over_pairs(pairs):
return
pairs
.
iteritems
()
return
pairs
.
iteritems
()
else
:
else
:
return
pairs
return
pairs
#TODO: Make these non-recursive so they can deal with larger graphs
def
_a_clone
(
a
,
dct
):
if
a
is
None
:
return
None
if
a
not
in
dct
:
for
i
in
a
.
inputs
:
_v_clone
(
i
,
dct
)
dct
[
a
]
=
a
.
clone_with_new_inputs
([
dct
[
i
]
for
i
in
a
.
inputs
])
for
old_o
,
new_o
in
zip
(
a
.
outputs
,
dct
[
a
]
.
outputs
):
dct
.
setdefault
(
old_o
,
new_o
)
return
dct
[
a
]
def
_v_clone
(
v
,
dct
):
assert
v
is
not
None
if
v
.
owner
:
_a_clone
(
v
.
owner
,
dct
)
return
dct
.
setdefault
(
v
,
v
)
theano/compile/sandbox/tests/test_pfunc.py
浏览文件 @
bc1ffeea
...
@@ -193,6 +193,22 @@ class Test_pfunc(unittest.TestCase):
...
@@ -193,6 +193,22 @@ class Test_pfunc(unittest.TestCase):
inc_by_y
()
inc_by_y
()
self
.
failUnless
(
x
.
value
==
1
)
self
.
failUnless
(
x
.
value
==
1
)
def
test_givens
(
self
):
x
=
shared
(
0
)
assign
=
pfunc
([],
x
,
givens
=
{
x
:
3
})
assert
assign
()
==
3
assert
x
.
value
==
0
y
=
tensor
.
ivector
()
f
=
pfunc
([
y
],
y
*
x
,
givens
=
{
x
:
6
})
assert
numpy
.
all
(
f
([
1
,
1
,
1
])
==
[
6
,
6
,
6
])
assert
x
.
value
==
0
z
=
tensor
.
ivector
()
c
=
z
*
y
f
=
pfunc
([
y
],
c
+
7
,
givens
=
{
z
:
numpy
.
asarray
([
4
,
4
,
4
],
dtype
=
'int32'
)})
assert
numpy
.
all
(
f
([
1
,
1
,
1
])
==
[
11
,
11
,
11
])
assert
x
.
value
==
0
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
theano
.
compile
.
mode
.
default_mode
=
'FAST_COMPILE'
theano
.
compile
.
mode
.
default_mode
=
'FAST_COMPILE'
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论