Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
6c3d8e26
提交
6c3d8e26
authored
6月 15, 2015
作者:
abergeron
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #3002 from carriepl/scan_index_error
[CRASH] Scan index error
上级
32bc96d7
7b984d13
隐藏空白字符变更
内嵌
并排
正在显示
2 个修改的文件
包含
290 行增加
和
182 行删除
+290
-182
scan_op.py
theano/scan_module/scan_op.py
+185
-163
test_scan.py
theano/scan_module/tests/test_scan.py
+105
-19
没有找到文件。
theano/scan_module/scan_op.py
浏览文件 @
6c3d8e26
...
@@ -229,6 +229,11 @@ class Scan(PureOp):
...
@@ -229,6 +229,11 @@ class Scan(PureOp):
self
.
_cmodule_key
=
gof
.
CLinker
()
.
cmodule_key_
(
local_fgraph
,
[])
self
.
_cmodule_key
=
gof
.
CLinker
()
.
cmodule_key_
(
local_fgraph
,
[])
self
.
_hash_inner_graph
=
hash
(
self
.
_cmodule_key
)
self
.
_hash_inner_graph
=
hash
(
self
.
_cmodule_key
)
# Compute mappings between outer inputs, outer outputs, inner
# inputs and inner outputs to determine which variables are associated
# with the same states.
self
.
var_mappings
=
self
.
get_oinp_iinp_iout_oout_mappings
()
def
validate_inner_graph
(
self
):
def
validate_inner_graph
(
self
):
""" Perform some elementary validations on the inner graph to ensure
""" Perform some elementary validations on the inner graph to ensure
that it is coherent.
that it is coherent.
...
@@ -237,14 +242,11 @@ class Scan(PureOp):
...
@@ -237,14 +242,11 @@ class Scan(PureOp):
# For every recurrent output, iterate over the associated inner
# For every recurrent output, iterate over the associated inner
# inputs and output and ensure that they have the same dtype
# inputs and output and ensure that they have the same dtype
nb_recurr_outputs
=
self
.
n_mit_mot
+
self
.
n_mit_sot
+
self
.
n_sit_sot
nb_recurr_outputs
=
self
.
n_mit_mot
+
self
.
n_mit_sot
+
self
.
n_sit_sot
outer_iidx_from_outer_oidx
=
self
.
get_outer_iidx_from_outer_oidx_seq
()
for
outer_oidx
in
range
(
nb_recurr_outputs
):
for
outer_oidx
in
range
(
nb_recurr_outputs
):
outer_iidx
=
outer_iidx_from_outer_oidx
[
outer_oidx
]
inner_iidxs
=
self
.
var_mappings
[
'inner_inp_from_outer_out'
][
outer_oidx
]
inner_oidxs
=
self
.
var_mappings
[
'inner_out_from_outer_out'
][
outer_oidx
]
inner_iidxs
=
self
.
get_inner_iidx_from_outer_iidx
(
outer_iidx
)
inner_oidxs
=
self
.
get_inner_oidx_from_outer_oidx
(
outer_oidx
)
for
(
inner_iidx
,
inner_oidx
)
in
itertools
.
product
(
inner_iidxs
,
for
(
inner_iidx
,
inner_oidx
)
in
itertools
.
product
(
inner_iidxs
,
inner_oidxs
):
inner_oidxs
):
...
@@ -303,13 +305,19 @@ class Scan(PureOp):
...
@@ -303,13 +305,19 @@ class Scan(PureOp):
def
__setstate__
(
self
,
d
):
def
__setstate__
(
self
,
d
):
self
.
__dict__
.
update
(
d
)
self
.
__dict__
.
update
(
d
)
self
.
validate_inner_graph
()
if
"allow_gc"
not
in
self
.
__dict__
:
if
"allow_gc"
not
in
self
.
__dict__
:
self
.
allow_gc
=
True
self
.
allow_gc
=
True
self
.
info
[
'allow_gc'
]
=
True
self
.
info
[
'allow_gc'
]
=
True
if
not
hasattr
(
self
,
'gpua'
):
if
not
hasattr
(
self
,
'gpua'
):
self
.
gpua
=
False
self
.
gpua
=
False
self
.
info
[
'gpua'
]
=
False
self
.
info
[
'gpua'
]
=
False
if
not
hasattr
(
self
,
'var_mappings'
):
# Generate the mappings between inner and outer inputs and outputs
# if they haven't already been generated.
self
.
var_mappings
=
self
.
get_oinp_iinp_iout_oout_mappings
()
# Ensure that the graph associated with the inner function is valid.
self
.
validate_inner_graph
()
def
make_node
(
self
,
*
inputs
):
def
make_node
(
self
,
*
inputs
):
"""
"""
...
@@ -1470,66 +1478,6 @@ class Scan(PureOp):
...
@@ -1470,66 +1478,6 @@ class Scan(PureOp):
scan_outs
.
append
((
Shape_i
(
0
)(
o
),)
+
x
[
1
:])
scan_outs
.
append
((
Shape_i
(
0
)(
o
),)
+
x
[
1
:])
return
scan_outs
return
scan_outs
def get_input_pos(self, output_index):
    """Map an inner-output index to the first matching inner-input index.

    For a given ``output_index``, an index in the inner outputs of scan,
    find a corresponding first index in the inner inputs of scan.

    The search starts after the ``self.n_seqs`` sequence inputs and walks
    the mit-mot groups (``self.mitmot_out_taps()`` paired with
    ``self.mitmot_taps()``), then the mit-sot groups
    (``self.mitsot_taps()``), then the ``self.info['n_sit_sot']`` sit-sot
    entries.

    Returns the inner-input index, or ``-1`` when ``output_index`` lies
    beyond the last sit-sot entry.
    """
    ipos = self.n_seqs
    opos = output_index

    # Each mit-mot group owns len(otaps) inner outputs and len(itaps)
    # inner inputs.
    for otaps, itaps in zip(self.mitmot_out_taps(), self.mitmot_taps()):
        if len(otaps) > opos:
            return ipos
        opos -= len(otaps)
        ipos += len(itaps)

    # Each mit-sot group owns one inner output and len(taps) inner inputs.
    for taps in self.mitsot_taps():
        if opos == 0:
            return ipos
        opos -= 1
        ipos += len(taps)

    # Sit-sots are one input per output, so the remaining offset carries
    # over directly.
    if opos < self.info['n_sit_sot']:
        return ipos + opos
    return -1
def get_output_pos(self, input_index):
    """Map an inner-input index to the first matching inner-output index.

    For a given ``input_index``, an index in the inner inputs of scan,
    find a corresponding first index in the inner outputs of scan.

    The walk covers the mit-mot groups (``self.mitmot_out_taps()`` paired
    with ``self.mitmot_taps()``), then the mit-sot groups
    (``self.mitsot_taps()``), then the ``self.info['n_sit_sot']`` sit-sot
    entries. No offset for sequence inputs is applied inside this method.

    Returns the inner-output index, or ``-1`` when ``input_index`` lies
    beyond the last sit-sot entry.
    """
    ipos = input_index
    opos = 0

    # Each mit-mot group owns len(itaps) inner inputs and len(otaps)
    # inner outputs.
    for otaps, itaps in zip(self.mitmot_out_taps(), self.mitmot_taps()):
        if len(itaps) > ipos:
            return opos
        opos += len(otaps)
        ipos -= len(itaps)

    # Each mit-sot group owns len(taps) inner inputs and one inner output.
    for taps in self.mitsot_taps():
        if len(taps) > ipos:
            return opos
        opos += 1
        ipos -= len(taps)

    # Sit-sots are one input per output, so the remaining offset carries
    # over directly.
    if ipos < self.info['n_sit_sot']:
        return ipos + opos
    return -1
def get_output_slice_idx(self, output_index):
    """Map an outer-output index to an inner-output index.

    For an ``output_index``, an index in the outer outputs of scan,
    find a corresponding index in the inner outputs of scan.

    Each mit-mot outer output owns ``len(otaps)`` consecutive inner
    outputs (one per entry of ``self.mitmot_out_taps()``); the index
    returned for it is the first of those. Outer outputs past the
    mit-mots are treated as owning exactly one inner output each.
    """
    ipos = 0
    opos = output_index

    # Iterate the tap groups directly. Wrapping them in zip() would yield
    # 1-tuples whose length is always 1, and the loop must stop on the
    # requested outer output (opos == 0), not on the first non-empty group.
    for otaps in self.mitmot_out_taps():
        if opos == 0:
            return ipos
        opos -= 1
        ipos += len(otaps)

    return ipos + opos
def
inner_connection_pattern
(
self
):
def
inner_connection_pattern
(
self
):
""" Returns the connection pattern of scan's inner function
""" Returns the connection pattern of scan's inner function
"""
"""
...
@@ -1616,10 +1564,10 @@ class Scan(PureOp):
...
@@ -1616,10 +1564,10 @@ class Scan(PureOp):
# and inner outputs and, if one such pair of inner variables is
# and inner outputs and, if one such pair of inner variables is
# connected then the pair of outer variables is connected.
# connected then the pair of outer variables is connected.
for
outer_oidx
in
range
(
len
(
node
.
outputs
)):
for
outer_oidx
in
range
(
len
(
node
.
outputs
)):
inner_oidxs
=
self
.
get_inner_oidx_from_outer_oidx
(
outer_oidx
)
inner_oidxs
=
self
.
var_mappings
[
'inner_out_from_outer_out'
][
outer_oidx
]
for
outer_iidx
in
range
(
len
(
node
.
inputs
)):
for
outer_iidx
in
range
(
len
(
node
.
inputs
)):
inner_iidxs
=
self
.
get_inner_iidx_from_outer_iidx
(
outer_iidx
)
inner_iidxs
=
self
.
var_mappings
[
'inner_inp_from_outer_inp'
][
outer_iidx
]
for
inner_oidx
in
inner_oidxs
:
for
inner_oidx
in
inner_oidxs
:
for
inner_iidx
in
inner_iidxs
:
for
inner_iidx
in
inner_iidxs
:
...
@@ -1636,7 +1584,6 @@ class Scan(PureOp):
...
@@ -1636,7 +1584,6 @@ class Scan(PureOp):
# input to `z_t` then `x` is an input to `z_t`.
# input to `z_t` then `x` is an input to `z_t`.
n_outs
=
len
(
node
.
outputs
)
n_outs
=
len
(
node
.
outputs
)
outer_iidx_from_outer_oidx
=
self
.
get_outer_iidx_from_outer_oidx_seq
()
for
steps
in
xrange
(
n_outs
):
for
steps
in
xrange
(
n_outs
):
for
iidx
in
xrange
(
n_outs
):
for
iidx
in
xrange
(
n_outs
):
...
@@ -1644,7 +1591,7 @@ class Scan(PureOp):
...
@@ -1644,7 +1591,7 @@ class Scan(PureOp):
# Get the idx of the outer input corresponding to that
# Get the idx of the outer input corresponding to that
# outer output
# outer output
j_inp_idx
=
outer_iidx_from_outer_oidx
[
jidx
]
j_inp_idx
=
self
.
var_mappings
[
"outer_inp_from_outer_out"
]
[
jidx
]
if
j_inp_idx
!=
-
1
:
if
j_inp_idx
!=
-
1
:
if
connection_pattern
[
j_inp_idx
][
iidx
]
==
True
:
if
connection_pattern
[
j_inp_idx
][
iidx
]
==
True
:
...
@@ -1655,100 +1602,160 @@ class Scan(PureOp):
...
@@ -1655,100 +1602,160 @@ class Scan(PureOp):
node
.
tag
.
connection_pattern
=
connection_pattern
node
.
tag
.
connection_pattern
=
connection_pattern
return
connection_pattern
return
connection_pattern
# NOTE(review): in this diff hunk the first four helpers appear to be on the
# removed side and get_oinp_iinp_iout_oout_mappings on the added side; their
# tokens were interleaved and are separated here. Confirm against the
# repository history.

def get_inner_oidx_from_outer_oidx(self, outer_oidx):
    """Given the index of an outer output, return the indices of the
    corresponding inner output(s) in a sequence.

    The first ``self.n_mit_mot`` outer outputs each own
    ``len(self.mitmot_out_taps()[p])`` inner outputs; every later outer
    output owns exactly one.
    """
    start = 0
    end = 0
    for p in range(outer_oidx + 1):
        start = end
        if p < self.n_mit_mot:
            end += len(self.mitmot_out_taps()[p])
        else:
            end += 1

    # Materialize the list so the result is a real list on any Python version.
    return list(range(start, end))


def get_inner_iidx_from_outer_iidx(self, outer_oidx):
    """Given the index of an outer input, return the indices of the
    corresponding inner input(s) in a sequence.

    NOTE(review): the parameter is named ``outer_oidx`` but is compared
    against outer *input* indices; the name is kept for compatibility.
    """
    outer_iidx_from_inner_iidx = self.get_outer_iidx_from_inner_iidx_seq()

    # Keep every inner input whose outer input is the requested one.
    return [inner_iidx
            for inner_iidx, outer_iidx in enumerate(outer_iidx_from_inner_iidx)
            if outer_iidx == outer_oidx]


def get_outer_iidx_from_outer_oidx_seq(self):
    """ Return a sequence where the value at the i-th position is the
    index of the outer input corresponding to the i-th outer output

    NOTE: mitmots, mitsots, sitsots and shared outputs have corresponding
    outer inputs but not nitsots; nitsot positions hold -1.
    """
    nb_outer_outputs = (self.n_mit_mot + self.n_mit_sot + self.n_sit_sot +
                        self.n_nit_sot + self.n_shared_outs)
    result = [-1] * nb_outer_outputs

    # Process mitmots, mitsots and sitsots. Outer input 0 and the
    # self.n_seqs sequence inputs come first, hence the initial offset.
    input_offset = 1 + self.n_seqs
    output_offset = 0
    for _ in range(len(self.tap_array)):
        result[output_offset] = input_offset
        input_offset += 1
        output_offset += 1

    # Process shared inputs/outputs. On the output side the nitsots sit
    # between the recurrent states and the shared outputs, so skip them.
    output_offset += self.n_nit_sot
    for _ in range(self.n_shared_outs):
        result[output_offset] = input_offset
        input_offset += 1
        output_offset += 1

    return result


def get_outer_iidx_from_inner_iidx_seq(self):
    """ Return a sequence where the value at the i-th position is the
    index of the outer input corresponding to the i-th inner input
    """
    output = []
    outer_inp_idx = 1  # First outer input is timestep index, skip it

    # Handle sequences inputs
    for _ in range(self.info['n_seqs']):
        output.append(outer_inp_idx)
        outer_inp_idx += 1

    # Handle mitmots, mitsots and sitsots inputs: every tap of a state is
    # a separate inner input that maps back to the same outer input.
    for input_taps in self.info['tap_array']:
        for _ in input_taps:
            output.append(outer_inp_idx)
        outer_inp_idx += 1

    # Handle shared inputs
    for _ in range(self.info['n_shared_outs']):
        output.append(outer_inp_idx)
        outer_inp_idx += 1

    # No inner input corresponds to the outer nitsot inputs but they still
    # need to be counted
    outer_inp_idx += self.info['n_nit_sot']

    # Handle non-sequences inputs: whatever inner inputs remain.
    nb_nonseqs_inputs = len(self.inputs) - len(output)
    for _ in range(nb_nonseqs_inputs):
        output.append(outer_inp_idx)
        outer_inp_idx += 1

    return output


def get_oinp_iinp_iout_oout_mappings(self):
    """ Compute and return dictionary mappings between the inputs and
    outputs of the inner function and the inputs and outputs of the Scan
    node in the outer graph.

    The return value is a dictionary in which the keys are the names of
    the individual mappings and the values are the mapping dictionaries
    themselves. In dictionaries representing mappings to outer variables,
    the values are individual integer indices (-1 when there is no
    corresponding outer output). In dictionaries representing mappings to
    inner variables, the values are lists of indices because multiple
    inner variables can be associated with the same state.
    """
    # Lists for outer variables contain individual indices, lists for
    # inner variables contain sequences of indices because many inner
    # variables can be associated with the same outer variable. The lists
    # and indices are initialized already containing the data associated
    # with the timestep index, the first outer input.
    outer_input_indices = [0]
    inner_input_indices = [[]]
    inner_output_indices = [[]]
    outer_output_indices = [-1]

    outer_iidx = 1
    inner_iidx = 0
    inner_oidx = 0
    outer_oidx = 0

    # Handle sequences inputs: one outer input, one inner input, no output.
    for _ in range(self.info['n_seqs']):
        outer_input_indices.append(outer_iidx)
        inner_input_indices.append([inner_iidx])
        inner_output_indices.append([])
        outer_output_indices.append(-1)

        outer_iidx += 1
        inner_iidx += 1

    # Handle mitmots, mitsots and sitsots variables
    for i, input_taps in enumerate(self.info['tap_array']):
        nb_input_taps = len(input_taps)

        # Only mitmots can have more than one inner output per state.
        if i < self.n_mit_mot:
            nb_output_taps = len(self.mit_mot_out_slices[i])
        else:
            nb_output_taps = 1

        outer_input_indices.append(outer_iidx)
        # list(...) keeps the values real lists on any Python version.
        inner_input_indices.append(
            list(range(inner_iidx, inner_iidx + nb_input_taps)))
        inner_output_indices.append(
            list(range(inner_oidx, inner_oidx + nb_output_taps)))
        outer_output_indices.append(outer_oidx)

        outer_iidx += 1
        inner_iidx += nb_input_taps
        inner_oidx += nb_output_taps
        outer_oidx += 1

    # This is needed because, for outer inputs (and for outer inputs only)
    # nitsots come *after* shared variables.
    outer_iidx += self.info['n_shared_outs']

    # Handle nitsots variables: an outer input and outputs, no inner input.
    for _ in range(self.n_nit_sot):
        outer_input_indices.append(outer_iidx)
        inner_input_indices.append([])
        inner_output_indices.append([inner_oidx])
        outer_output_indices.append(outer_oidx)

        outer_iidx += 1
        inner_oidx += 1
        outer_oidx += 1

    # Rewind the outer input index to the first shared variable, which
    # precedes the nitsots among the outer inputs.
    outer_iidx -= (self.info['n_shared_outs'] + self.n_nit_sot)

    # Handle shared states
    for _ in range(self.info['n_shared_outs']):
        outer_input_indices.append(outer_iidx)
        inner_input_indices.append([inner_iidx])
        inner_output_indices.append([inner_oidx])
        outer_output_indices.append(outer_oidx)

        outer_iidx += 1
        inner_iidx += 1
        inner_oidx += 1
        outer_oidx += 1

    # Skip over the nitsot outer inputs again, which were already handled.
    outer_iidx += self.n_nit_sot

    # Handle non-sequence inputs
    # Note : the number of non-sequence inputs is not stored in self.info
    # so it has to be inferred from the number of inner inputs that remain
    # to be handled
    for _ in range(len(self.inputs) - inner_iidx):
        outer_input_indices.append(outer_iidx)
        inner_input_indices.append([inner_iidx])
        inner_output_indices.append([])
        outer_output_indices.append(-1)

        outer_iidx += 1
        inner_iidx += 1

    # With the global mapping inferred, the individual mappings
    # can be produced
    mappings = {"outer_inp_from_outer_out": {},
                "inner_inp_from_outer_out": {},
                "inner_out_from_outer_out": {},
                "inner_inp_from_outer_inp": {},
                "inner_out_from_outer_inp": {},
                "outer_out_from_outer_inp": {},
                "outer_inp_from_inner_inp": {},
                "inner_out_from_inner_inp": {},
                "outer_out_from_inner_inp": {},
                "outer_inp_from_inner_out": {},
                "inner_inp_from_inner_out": {},
                "outer_out_from_inner_out": {}}

    for (oinp, iinp, iout, oout) in zip(outer_input_indices,
                                        inner_input_indices,
                                        inner_output_indices,
                                        outer_output_indices):

        # -1 marks "no such outer variable"; it must not become a key.
        if oout != -1:
            mappings["outer_inp_from_outer_out"][oout] = oinp
            mappings["inner_inp_from_outer_out"][oout] = iinp
            mappings["inner_out_from_outer_out"][oout] = iout

        if oinp != -1:
            mappings["inner_inp_from_outer_inp"][oinp] = iinp
            mappings["inner_out_from_outer_inp"][oinp] = iout
            mappings["outer_out_from_outer_inp"][oinp] = oout

        for idx in iinp:
            mappings["outer_inp_from_inner_inp"][idx] = oinp
            mappings["inner_out_from_inner_inp"][idx] = iout
            mappings["outer_out_from_inner_inp"][idx] = oout

        for idx in iout:
            mappings["outer_inp_from_inner_out"][idx] = oinp
            mappings["inner_inp_from_inner_out"][idx] = iinp
            mappings["outer_out_from_inner_out"][idx] = oout

    return mappings
# GRAD FUNCTION
# GRAD FUNCTION
def
grad
(
self
,
inputs
,
dC_douts
):
def
grad
(
self
,
inputs
,
dC_douts
):
...
@@ -1896,10 +1903,14 @@ class Scan(PureOp):
...
@@ -1896,10 +1903,14 @@ class Scan(PureOp):
for
pos
,
inp
in
enumerate
(
states
):
for
pos
,
inp
in
enumerate
(
states
):
if
inp
in
theano
.
gof
.
graph
.
inputs
([
Xt
]):
if
inp
in
theano
.
gof
.
graph
.
inputs
([
Xt
]):
oidx
=
self
.
get_output_pos
(
pos
)
# Get the index of the outer output to which
if
not
isinstance
(
dC_douts
[
oidx
]
.
type
,
# the state variable 'inp' corresponds.
outer_oidx
=
self
.
var_mappings
[
'outer_out_from_inner_inp'
][
self
.
n_seqs
+
pos
]
if
not
isinstance
(
dC_douts
[
outer_oidx
]
.
type
,
DisconnectedType
):
DisconnectedType
):
dtypes
.
append
(
dC_douts
[
oidx
]
.
dtype
)
dtypes
.
append
(
dC_douts
[
outer_oidx
]
.
dtype
)
if
dtypes
:
if
dtypes
:
new_dtype
=
theano
.
scalar
.
upcast
(
*
dtypes
)
new_dtype
=
theano
.
scalar
.
upcast
(
*
dtypes
)
else
:
else
:
...
@@ -1943,14 +1954,25 @@ class Scan(PureOp):
...
@@ -1943,14 +1954,25 @@ class Scan(PureOp):
# construct dX_dtm1
# construct dX_dtm1
dC_dXtm1s
=
[]
dC_dXtm1s
=
[]
for
pos
,
x
in
enumerate
(
dC_dinps_t
[
self
.
n_seqs
:]):
for
pos
,
x
in
enumerate
(
dC_dinps_t
[
self
.
n_seqs
:]):
opos
=
self
.
get_output_pos
(
pos
)
if
opos
>=
0
:
# Get the index of the first inner output corresponding to the
# pos-th inner input state
idxs
=
self
.
var_mappings
[
'inner_out_from_inner_inp'
][
self
.
n_seqs
+
pos
]
# Check if the pos-th input is associated with one of the
# recurrent states
x_is_state
=
pos
<
sum
([
len
(
t
)
for
t
in
self
.
tap_array
])
if
x_is_state
and
len
(
idxs
)
>
0
:
opos
=
idxs
[
0
]
dC_dXtm1s
.
append
(
safe_new
(
dC_dXts
[
opos
]))
dC_dXtm1s
.
append
(
safe_new
(
dC_dXts
[
opos
]))
if
hasattr
(
x
,
'dtype'
)
and
x
.
dtype
!=
dC_dXts
[
opos
]
.
dtype
:
if
hasattr
(
x
,
'dtype'
)
and
x
.
dtype
!=
dC_dXts
[
opos
]
.
dtype
:
dC_dinps_t
[
pos
+
self
.
n_seqs
]
=
\
dC_dinps_t
[
pos
+
self
.
n_seqs
]
=
\
x
.
astype
(
dC_dXts
[
opos
]
.
dtype
)
x
.
astype
(
dC_dXts
[
opos
]
.
dtype
)
else
:
else
:
dC_dXtm1s
.
append
(
safe_new
(
x
))
dC_dXtm1s
.
append
(
safe_new
(
x
))
for
dx
,
dC_dXtm1
in
enumerate
(
dC_dXtm1s
):
for
dx
,
dC_dXtm1
in
enumerate
(
dC_dXtm1s
):
if
isinstance
(
dC_dinps_t
[
dx
+
self
.
n_seqs
]
.
type
,
NullType
):
if
isinstance
(
dC_dinps_t
[
dx
+
self
.
n_seqs
]
.
type
,
NullType
):
# The accumulated gradient is undefined
# The accumulated gradient is undefined
...
...
theano/scan_module/tests/test_scan.py
浏览文件 @
6c3d8e26
...
@@ -657,19 +657,18 @@ class T_Scan(unittest.TestCase):
...
@@ -657,19 +657,18 @@ class T_Scan(unittest.TestCase):
tensor
.
grad
(
a
[
-
1
],
a0
)
tensor
.
grad
(
a
[
-
1
],
a0
)
# Also validate that the methods get_outer_iidx_from_outer_oidx_seq
# Also validate that the mappings outer_inp_from_outer_out and
# and get_outer_iidx_from_inner_iidx_seq produce the correct results
# outer_inp_from_inner_inp produce the correct results
scan_node
=
a
.
owner
.
inputs
[
0
]
.
owner
scan_node
=
a
.
owner
.
inputs
[
0
]
.
owner
result
=
scan_node
.
op
.
get_outer_iidx_from_outer_oidx_seq
()
result
=
scan_node
.
op
.
var_mappings
[
'outer_inp_from_outer_out'
]
expected_result
=
[
1
,
2
]
expected_result
=
{
0
:
1
,
1
:
2
}
assert
(
result
==
expected_result
)
assert
(
result
==
expected_result
)
result
=
scan_node
.
op
.
get_outer_iidx_from_inner_iidx_seq
()
result
=
scan_node
.
op
.
var_mappings
[
'outer_inp_from_inner_inp'
]
expected_result
=
[
1
,
1
,
2
,
2
]
expected_result
=
{
0
:
1
,
1
:
1
,
2
:
2
,
3
:
2
}
assert
(
result
==
expected_result
)
assert
(
result
==
expected_result
)
def
test_connection_pattern2
(
self
):
def
test_connection_pattern2
(
self
):
# This tests for a crash in connection_pattern() when a scan node
# This tests for a crash in connection_pattern() when a scan node
# has more than one mitmot (multiple input taps as well as
# has more than one mitmot (multiple input taps as well as
...
@@ -690,18 +689,42 @@ class T_Scan(unittest.TestCase):
...
@@ -690,18 +689,42 @@ class T_Scan(unittest.TestCase):
scan_node
=
g_out
[
0
]
.
owner
.
inputs
[
1
]
.
owner
.
inputs
[
1
]
.
owner
.
inputs
[
0
]
.
owner
scan_node
=
g_out
[
0
]
.
owner
.
inputs
[
1
]
.
owner
.
inputs
[
1
]
.
owner
.
inputs
[
0
]
.
owner
connection_pattern
=
scan_node
.
op
.
connection_pattern
(
scan_node
)
connection_pattern
=
scan_node
.
op
.
connection_pattern
(
scan_node
)
# Also validate that the methods get_outer_iidx_from_outer_oidx_seq
# Also validate that the mappings outer_inp_from_outer_out and
# and get_outer_iidx_from_inner_iidx_seq produce the correct results
# outer_inp_from_inner_inp produce the correct results
scan_node
=
out
.
owner
.
inputs
[
0
]
.
owner
scan_node
=
out
.
owner
.
inputs
[
0
]
.
owner
result
=
scan_node
.
op
.
get_outer_iidx_from_outer_oidx_seq
()
result
=
scan_node
.
op
.
var_mappings
[
'outer_inp_from_outer_out'
]
expected_result
=
[
2
]
expected_result
=
{
0
:
2
}
assert
(
result
==
expected_result
)
assert
(
result
==
expected_result
)
result
=
scan_node
.
op
.
get_outer_iidx_from_inner_iidx_seq
()
result
=
scan_node
.
op
.
var_mappings
[
'outer_inp_from_inner_inp'
]
expected_result
=
[
1
,
2
,
2
]
expected_result
=
{
0
:
1
,
1
:
2
,
2
:
2
}
assert
(
result
==
expected_result
)
assert
(
result
==
expected_result
)
def test_grad_grad_mitsot_sitsot(self):
    # Test for an index error when taking the second derivative
    # through a Scan node with one sitsot and one mitsot.
    # The test has no assertions: it passes when building the graphs
    # below does not raise.

    def inner_fct(mitsot_m2, mitsot_m1, sitsot):
        total = mitsot_m2 + mitsot_m1 + sitsot
        output = total ** 2
        # The same expression is returned as the new value of both states.
        return output, output

    inputs = [tensor.matrix(), tensor.vector()]
    # First state reads taps -2 and -1 (mitsot); the second is a plain
    # initial value (sitsot).
    outputs_info = [dict(initial=inputs[0], taps=[-2, -1]), inputs[1]]

    scan_outputs, updates = theano.scan(fn=inner_fct,
                                        outputs_info=outputs_info,
                                        n_steps=5)

    # Take the gradient of each output wrt its corresponding initial state
    gradients = [theano.grad(scan_outputs[0].sum(), inputs[0]),
                 theano.grad(scan_outputs[1].sum(), inputs[1])]

    # Take the gradient of the sum of gradients wrt the inputs
    sum_of_grads = sum([g.sum() for g in gradients])
    second_gradients = theano.grad(sum_of_grads, inputs[0])
def
test_grad_two_scans
(
self
):
def
test_grad_two_scans
(
self
):
# data input & output
# data input & output
...
@@ -1680,16 +1703,16 @@ class T_Scan(unittest.TestCase):
...
@@ -1680,16 +1703,16 @@ class T_Scan(unittest.TestCase):
analytic_grad
[
max_err_pos
],
analytic_grad
[
max_err_pos
],
num_grad
.
gx
[
max_err_pos
]))
num_grad
.
gx
[
max_err_pos
]))
# Also validate that the methods get_outer_iidx_from_outer_oidx_seq
# Also validate that the mappings outer_inp_from_outer_out and
# and get_outer_iidx_from_inner_iidx_seq produce the correct results
# outer_inp_from_inner_inp produce the correct results
scan_node
=
updates
.
values
()[
0
]
.
owner
scan_node
=
updates
.
values
()[
0
]
.
owner
result
=
scan_node
.
op
.
get_outer_iidx_from_outer_oidx_seq
()
result
=
scan_node
.
op
.
var_mappings
[
'outer_inp_from_outer_out'
]
expected_result
=
[
3
,
-
1
,
4
]
expected_result
=
{
0
:
3
,
1
:
5
,
2
:
4
}
assert
(
result
==
expected_result
)
assert
(
result
==
expected_result
)
result
=
scan_node
.
op
.
get_outer_iidx_from_inner_iidx_seq
()
result
=
scan_node
.
op
.
var_mappings
[
'outer_inp_from_inner_inp'
]
expected_result
=
[
1
,
2
,
3
,
4
,
6
]
expected_result
=
{
0
:
1
,
1
:
2
,
2
:
3
,
3
:
4
,
4
:
6
}
assert
(
result
==
expected_result
)
assert
(
result
==
expected_result
)
def
test_grad_multiple_outs_some_truncate
(
self
):
def
test_grad_multiple_outs_some_truncate
(
self
):
...
@@ -3299,6 +3322,69 @@ class T_Scan(unittest.TestCase):
...
@@ -3299,6 +3322,69 @@ class T_Scan(unittest.TestCase):
if
isinstance
(
x
.
op
,
theano
.
scan_module
.
scan_op
.
Scan
)]
if
isinstance
(
x
.
op
,
theano
.
scan_module
.
scan_op
.
Scan
)]
assert
len
(
lssc
)
==
0
assert
len
(
lssc
)
==
0
def test_oinp_iinp_iout_oout_mappings(self):
    # Test the mappings produced by
    # ScanOp.get_oinp_iinp_iout_oout_mappings()

    rng = theano.tensor.shared_randomstreams.RandomStreams(123)

    def inner_fct(seq, mitsot, sitsot, nitsot, nseq):
        # NOTE(review): the mitsot below has two taps, so the five
        # parameters presumably receive (seq, mitsot[-3], mitsot[-1],
        # sitsot, non_seq) rather than what their names suggest -- confirm.
        random_scalar = rng.uniform((1,))[0]
        total = seq + mitsot + sitsot + nitsot + nseq + random_scalar
        return total, total, total

    # Assemble a scan with one sequence, one mitsot, one sitsot, one nitsot
    # a non-sequence and a random state to test the mappings.
    seq = [tensor.vector()]
    non_seq = [tensor.scalar()]
    outputs_info = [dict(initial=tensor.vector(), taps=[-3, -1]),
                    tensor.scalar(), None]

    scan_outputs, _ = theano.scan(fn=inner_fct, sequences=seq,
                                  outputs_info=outputs_info,
                                  non_sequences=non_seq)

    # Compare the mappings with the expected values
    scan_node = scan_outputs[0].owner.inputs[0].owner
    mappings = scan_node.op.var_mappings

    # Mappings keyed by outer input index
    assert mappings['inner_inp_from_outer_inp'] == {0: [], 1: [0],
                                                    2: [1, 2], 3: [3],
                                                    4: [4], 5: [],
                                                    6: [5]}
    assert mappings['inner_out_from_outer_inp'] == {0: [], 1: [], 2: [0],
                                                    3: [1], 4: [3],
                                                    5: [2], 6: []}
    assert mappings['outer_out_from_outer_inp'] == {0: -1, 1: -1, 2: 0,
                                                    3: 1, 4: 3, 5: 2,
                                                    6: -1}

    # Mappings keyed by inner input index
    assert mappings['outer_inp_from_inner_inp'] == {0: 1, 1: 2, 2: 2,
                                                    3: 3, 4: 4, 5: 6}
    assert mappings['inner_out_from_inner_inp'] == {0: [], 1: [0],
                                                    2: [0], 3: [1],
                                                    4: [3], 5: []}
    assert mappings['outer_out_from_inner_inp'] == {0: -1, 1: 0, 2: 0,
                                                    3: 1, 4: 3, 5: -1}

    # Mappings keyed by inner output index
    assert mappings['outer_inp_from_inner_out'] == {0: 2, 1: 3, 2: 5,
                                                    3: 4}
    assert mappings['inner_inp_from_inner_out'] == {0: [1, 2], 1: [3],
                                                    2: [], 3: [4]}
    assert mappings['outer_out_from_inner_out'] == {0: 0, 1: 1, 2: 2,
                                                    3: 3}

    # Mappings keyed by outer output index
    assert mappings['outer_inp_from_outer_out'] == {0: 2, 1: 3, 2: 5,
                                                    3: 4}
    assert mappings['inner_inp_from_outer_out'] == {0: [1, 2], 1: [3],
                                                    2: [], 3: [4]}
    assert mappings['inner_out_from_outer_out'] == {0: [0], 1: [1],
                                                    2: [2], 3: [3]}
def
test_grad_duplicate_outputs
(
self
):
def
test_grad_duplicate_outputs
(
self
):
# This test validates that taking the gradient of a scan, in which
# This test validates that taking the gradient of a scan, in which
# multiple outputs are the same theano variable, works.
# multiple outputs are the same theano variable, works.
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论