Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
9a2280b8
提交
9a2280b8
authored
8月 08, 2021
作者:
Brandon T. Willard
提交者:
Brandon T. Willard
9月 15, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Move Scan helper methods to ScanMethodsMixin
上级
e85c7fd0
隐藏空白字符变更
内嵌
并排
正在显示
6 个修改的文件
包含
455 行增加
和
454 行删除
+455
-454
printing.py
aesara/printing.py
+3
-1
op.py
aesara/scan/op.py
+439
-442
utils.py
aesara/scan/utils.py
+2
-2
scan.rst
doc/extending/scan.rst
+1
-1
test_basic.py
tests/scan/test_basic.py
+10
-7
test_utils.py
tests/scan/test_utils.py
+0
-1
没有找到文件。
aesara/printing.py
浏览文件 @
9a2280b8
...
@@ -237,7 +237,9 @@ N.B.:
...
@@ -237,7 +237,9 @@ N.B.:
outer_inputs
=
s
.
owner
.
inputs
outer_inputs
=
s
.
owner
.
inputs
inner_to_outer_inputs
=
{
inner_to_outer_inputs
=
{
inner_inputs
[
i
]:
outer_inputs
[
o
]
inner_inputs
[
i
]:
outer_inputs
[
o
]
for
i
,
o
in
s
.
owner
.
op
.
var_mappings
[
"outer_inp_from_inner_inp"
]
.
items
()
for
i
,
o
in
s
.
owner
.
op
.
get_oinp_iinp_iout_oout_mappings
()[
"outer_inp_from_inner_inp"
]
.
items
()
}
}
print
(
""
,
file
=
_file
)
print
(
""
,
file
=
_file
)
...
...
aesara/scan/op.py
浏览文件 @
9a2280b8
...
@@ -112,7 +112,383 @@ class ScanInfo:
...
@@ -112,7 +112,383 @@ class ScanInfo:
TensorConstructorType
=
Callable
[[
List
[
bool
],
Union
[
str
,
np
.
generic
]],
TensorType
]
TensorConstructorType
=
Callable
[[
List
[
bool
],
Union
[
str
,
np
.
generic
]],
TensorType
]
class
Scan
(
Op
):
class
ScanMethodsMixin
:
def
inner_seqs
(
self
,
list_inputs
):
# Given the list of inner inputs this function grabs those
# corresponding to sequences
return
list_inputs
[:
self
.
n_seqs
]
def
outer_seqs
(
self
,
list_inputs
):
if
isinstance
(
list_inputs
,
Apply
):
list_inputs
=
list_inputs
.
inputs
# Given the list of outer inputs this function grabs those
# corresponding to sequences
return
list_inputs
[
1
:
1
+
self
.
n_seqs
]
def
inner_mitmot
(
self
,
list_inputs
):
n_taps
=
sum
(
len
(
x
)
for
x
in
self
.
tap_array
[:
self
.
n_mit_mot
])
return
list_inputs
[
self
.
n_seqs
:
self
.
n_seqs
+
n_taps
]
def
outer_mitmot
(
self
,
list_inputs
):
if
isinstance
(
list_inputs
,
Apply
):
list_inputs
=
list_inputs
.
inputs
return
list_inputs
[
1
+
self
.
n_seqs
:
1
+
self
.
n_seqs
+
self
.
n_mit_mot
]
def
inner_mitmot_outs
(
self
,
list_outputs
):
n_taps
=
sum
(
len
(
x
)
for
x
in
self
.
mit_mot_out_slices
)
return
list_outputs
[:
n_taps
]
def
outer_mitmot_outs
(
self
,
list_outputs
):
if
isinstance
(
list_outputs
,
Apply
):
list_outputs
=
list_outputs
.
outputs
return
list_outputs
[:
self
.
n_mit_mot
]
def
mitmot_taps
(
self
):
return
self
.
tap_array
[:
self
.
n_mit_mot
]
def
mitmot_out_taps
(
self
):
return
self
.
mit_mot_out_slices
[:
self
.
n_mit_mot
]
def
inner_mitsot
(
self
,
list_inputs
):
n_mitmot_taps
=
sum
(
len
(
x
)
for
x
in
self
.
tap_array
[:
self
.
n_mit_mot
])
ntaps_upto_sit_sot
=
sum
(
len
(
x
)
for
x
in
self
.
tap_array
[:
(
self
.
n_mit_mot
+
self
.
n_mit_sot
)]
)
return
list_inputs
[
self
.
n_seqs
+
n_mitmot_taps
:
self
.
n_seqs
+
ntaps_upto_sit_sot
]
def
outer_mitsot
(
self
,
list_inputs
):
if
isinstance
(
list_inputs
,
Apply
):
list_inputs
=
list_inputs
.
inputs
offset
=
1
+
self
.
n_seqs
+
self
.
n_mit_mot
return
list_inputs
[
offset
:
offset
+
self
.
n_mit_sot
]
def
inner_mitsot_outs
(
self
,
list_outputs
):
n_taps
=
sum
(
len
(
x
)
for
x
in
self
.
mit_mot_out_slices
)
return
list_outputs
[
n_taps
:
n_taps
+
self
.
n_mit_sot
]
def
outer_mitsot_outs
(
self
,
list_outputs
):
if
isinstance
(
list_outputs
,
Apply
):
list_outputs
=
list_outputs
.
outputs
return
list_outputs
[
self
.
n_mit_mot
:
self
.
n_mit_mot
+
self
.
n_mit_sot
]
def
mitsot_taps
(
self
):
return
self
.
tap_array
[
self
.
n_mit_mot
:
self
.
n_mit_mot
+
self
.
n_mit_sot
]
def
inner_sitsot
(
self
,
list_inputs
):
n_taps_upto_sit_sot
=
sum
(
len
(
x
)
for
x
in
self
.
tap_array
[:
(
self
.
n_mit_mot
+
self
.
n_mit_sot
)]
)
offset
=
self
.
n_seqs
+
n_taps_upto_sit_sot
return
list_inputs
[
offset
:
offset
+
self
.
n_sit_sot
]
def
outer_sitsot
(
self
,
list_inputs
):
if
isinstance
(
list_inputs
,
Apply
):
list_inputs
=
list_inputs
.
inputs
offset
=
1
+
self
.
n_seqs
+
self
.
n_mit_mot
+
self
.
n_mit_sot
return
list_inputs
[
offset
:
offset
+
self
.
n_sit_sot
]
def
inner_sitsot_outs
(
self
,
list_outputs
):
n_taps
=
sum
(
len
(
x
)
for
x
in
self
.
mit_mot_out_slices
)
offset
=
self
.
n_mit_sot
+
n_taps
return
list_outputs
[
offset
:
offset
+
self
.
n_sit_sot
]
def
outer_sitsot_outs
(
self
,
list_outputs
):
if
isinstance
(
list_outputs
,
Apply
):
list_outputs
=
list_outputs
.
outputs
offset
=
self
.
n_mit_mot
+
self
.
n_mit_sot
return
list_outputs
[
offset
:
offset
+
self
.
n_sit_sot
]
def
outer_nitsot
(
self
,
list_inputs
):
if
isinstance
(
list_inputs
,
Apply
):
list_inputs
=
list_inputs
.
inputs
offset
=
(
1
+
self
.
n_seqs
+
self
.
n_mit_mot
+
self
.
n_mit_sot
+
self
.
n_sit_sot
+
self
.
n_shared_outs
)
return
list_inputs
[
offset
:
offset
+
self
.
n_nit_sot
]
def
inner_nitsot_outs
(
self
,
list_outputs
):
n_taps
=
sum
(
len
(
x
)
for
x
in
self
.
mit_mot_out_slices
)
offset
=
self
.
n_mit_sot
+
n_taps
+
self
.
n_sit_sot
return
list_outputs
[
offset
:
offset
+
self
.
n_nit_sot
]
def
outer_nitsot_outs
(
self
,
list_outputs
):
if
isinstance
(
list_outputs
,
Apply
):
list_outputs
=
list_outputs
.
outputs
offset
=
self
.
n_mit_mot
+
self
.
n_mit_sot
+
self
.
n_sit_sot
return
list_outputs
[
offset
:
offset
+
self
.
n_nit_sot
]
def
inner_shared
(
self
,
list_inputs
):
n_taps_upto_sit_sot
=
sum
(
len
(
x
)
for
x
in
self
.
tap_array
[:
(
self
.
n_mit_mot
+
self
.
n_mit_sot
)]
)
offset
=
self
.
n_seqs
+
n_taps_upto_sit_sot
+
self
.
n_sit_sot
return
list_inputs
[
offset
:
offset
+
self
.
n_shared_outs
]
def
outer_shared
(
self
,
list_inputs
):
if
isinstance
(
list_inputs
,
Apply
):
list_inputs
=
list_inputs
.
inputs
offset
=
1
+
self
.
n_seqs
+
self
.
n_mit_mot
+
self
.
n_mit_sot
+
self
.
n_sit_sot
return
list_inputs
[
offset
:
offset
+
self
.
n_shared_outs
]
def
inner_shared_outs
(
self
,
list_outputs
):
n_taps
=
sum
(
len
(
x
)
for
x
in
self
.
mit_mot_out_slices
)
offset
=
self
.
n_mit_sot
+
n_taps
+
self
.
n_sit_sot
+
self
.
n_nit_sot
return
list_outputs
[
offset
:
offset
+
self
.
n_shared_outs
]
def
outer_shared_outs
(
self
,
list_outputs
):
if
isinstance
(
list_outputs
,
Apply
):
list_outputs
=
list_outputs
.
outputs
offset
=
self
.
n_mit_mot
+
self
.
n_mit_sot
+
self
.
n_sit_sot
+
self
.
n_nit_sot
return
list_outputs
[
offset
:
offset
+
self
.
n_shared_outs
]
def
inner_non_seqs
(
self
,
list_inputs
):
n_taps_upto_sit_sot
=
sum
(
len
(
x
)
for
x
in
self
.
tap_array
[:
(
self
.
n_mit_mot
+
self
.
n_mit_sot
)]
)
offset
=
self
.
n_seqs
+
n_taps_upto_sit_sot
+
self
.
n_sit_sot
+
self
.
n_shared_outs
return
list_inputs
[
offset
:]
def
outer_non_seqs
(
self
,
list_inputs
):
if
isinstance
(
list_inputs
,
Apply
):
list_inputs
=
list_inputs
.
inputs
offset
=
(
1
+
self
.
n_seqs
+
self
.
n_mit_mot
+
self
.
n_mit_sot
+
self
.
n_sit_sot
+
self
.
n_nit_sot
+
self
.
n_shared_outs
)
return
list_inputs
[
offset
:]
def
get_oinp_iinp_iout_oout_mappings
(
self
):
"""
Compute and return dictionary mappings between the inputs and
outputs of the inner function and the inputs and outputs of the Scan
node in the outer graph.
The return value is a dictionary in which the keys are the names of
the individual mappings and the values are the mapping dictionaries
themselves. In dictionaries representing mappings to outer variables,
the values are individual integer indices. In dictionaries
representing mappings to inner variables, the values are sequences of
indices because multiple inner variables can be associated with the
same state.
"""
# Lists for outer variables contain individual indices, lists for
# inner variables contain sequences of indices because many inner
# variables can be associated with the same outer variable. The list
# and indices are initialized already containing the data associated
# with the timestep index, the first outer input.
outer_input_indices
=
[
0
]
inner_input_indices
=
[[]]
inner_output_indices
=
[[]]
outer_output_indices
=
[
-
1
]
outer_iidx
=
1
inner_iidx
=
0
inner_oidx
=
0
outer_oidx
=
0
# Handle sequences inputs
for
i
in
range
(
self
.
info
.
n_seqs
):
outer_input_indices
.
append
(
outer_iidx
)
inner_input_indices
.
append
([
inner_iidx
])
inner_output_indices
.
append
([])
outer_output_indices
.
append
(
-
1
)
outer_iidx
+=
1
inner_iidx
+=
1
inner_oidx
+=
0
outer_oidx
+=
0
# Handle mitmots, mitsots and sitsots variables
for
i
in
range
(
len
(
self
.
info
.
tap_array
)):
nb_input_taps
=
len
(
self
.
info
.
tap_array
[
i
])
if
i
<
self
.
n_mit_mot
:
nb_output_taps
=
len
(
self
.
mit_mot_out_slices
[
i
])
else
:
nb_output_taps
=
1
outer_input_indices
.
append
(
outer_iidx
)
inner_input_indices
.
append
(
list
(
range
(
inner_iidx
,
inner_iidx
+
nb_input_taps
))
)
inner_output_indices
.
append
(
list
(
range
(
inner_oidx
,
inner_oidx
+
nb_output_taps
))
)
outer_output_indices
.
append
(
outer_oidx
)
outer_iidx
+=
1
inner_iidx
+=
nb_input_taps
inner_oidx
+=
nb_output_taps
outer_oidx
+=
1
# This is needed because, for outer inputs (and for outer inputs only)
# nitsots come *after* shared variables.
outer_iidx
+=
self
.
info
.
n_shared_outs
# Handle nitsots variables
for
i
in
range
(
self
.
n_nit_sot
):
outer_input_indices
.
append
(
outer_iidx
)
inner_input_indices
.
append
([])
inner_output_indices
.
append
([
inner_oidx
])
outer_output_indices
.
append
(
outer_oidx
)
outer_iidx
+=
1
inner_iidx
+=
0
inner_oidx
+=
1
outer_oidx
+=
1
# This is needed because, for outer inputs (and for outer inputs only)
# nitsots come *after* shared variables.
outer_iidx
-=
self
.
info
.
n_shared_outs
+
self
.
n_nit_sot
# Handle shared states
for
i
in
range
(
self
.
info
.
n_shared_outs
):
outer_input_indices
.
append
(
outer_iidx
)
inner_input_indices
.
append
([
inner_iidx
])
inner_output_indices
.
append
([
inner_oidx
])
outer_output_indices
.
append
(
outer_oidx
)
outer_iidx
+=
1
inner_iidx
+=
1
inner_oidx
+=
1
outer_oidx
+=
1
# This is needed because, for outer inputs (and for outer inputs only)
# nitsots come *after* shared variables.
outer_iidx
+=
self
.
n_nit_sot
# Handle non-sequence inputs
# Note : the number of non-sequence inputs is not stored in self.info
# so it has to be inferred from the number of inner inputs that remain
# to be handled
for
i
in
range
(
len
(
self
.
inputs
)
-
inner_iidx
):
outer_input_indices
.
append
(
outer_iidx
)
inner_input_indices
.
append
([
inner_iidx
])
inner_output_indices
.
append
([])
outer_output_indices
.
append
(
-
1
)
outer_iidx
+=
1
inner_iidx
+=
1
inner_oidx
+=
0
outer_oidx
+=
0
# With the global mapping inferred, the individual mappings
# can be produced
mappings
=
{
"outer_inp_from_outer_out"
:
{},
"inner_inp_from_outer_out"
:
{},
"inner_out_from_outer_out"
:
{},
"inner_inp_from_outer_inp"
:
{},
"inner_out_from_outer_inp"
:
{},
"outer_out_from_outer_inp"
:
{},
"outer_inp_from_inner_inp"
:
{},
"inner_out_from_inner_inp"
:
{},
"outer_out_from_inner_inp"
:
{},
"outer_inp_from_inner_out"
:
{},
"inner_inp_from_inner_out"
:
{},
"outer_out_from_inner_out"
:
{},
}
for
(
oinp
,
iinp
,
iout
,
oout
)
in
zip
(
outer_input_indices
,
inner_input_indices
,
inner_output_indices
,
outer_output_indices
,
):
if
oout
!=
-
1
:
mappings
[
"outer_inp_from_outer_out"
][
oout
]
=
oinp
mappings
[
"inner_inp_from_outer_out"
][
oout
]
=
iinp
mappings
[
"inner_out_from_outer_out"
][
oout
]
=
iout
if
oinp
!=
-
1
:
mappings
[
"inner_inp_from_outer_inp"
][
oinp
]
=
iinp
mappings
[
"inner_out_from_outer_inp"
][
oinp
]
=
iout
mappings
[
"outer_out_from_outer_inp"
][
oinp
]
=
oout
for
idx
in
iinp
:
mappings
[
"outer_inp_from_inner_inp"
][
idx
]
=
oinp
mappings
[
"inner_out_from_inner_inp"
][
idx
]
=
iout
mappings
[
"outer_out_from_inner_inp"
][
idx
]
=
oout
for
idx
in
iout
:
mappings
[
"outer_inp_from_inner_out"
][
idx
]
=
oinp
mappings
[
"inner_inp_from_inner_out"
][
idx
]
=
iinp
mappings
[
"outer_out_from_inner_out"
][
idx
]
=
oout
return
mappings
def
validate_inner_graph
(
self
):
"""
Perform some elementary validations on the inner graph to ensure
that it is coherent.
"""
# For every recurrent output, iterate over the associated inner
# inputs and output and ensure that they have the same dtype
nb_recurr_outputs
=
self
.
n_mit_mot
+
self
.
n_mit_sot
+
self
.
n_sit_sot
var_mappings
=
self
.
get_oinp_iinp_iout_oout_mappings
()
for
outer_oidx
in
range
(
nb_recurr_outputs
):
inner_iidxs
=
var_mappings
[
"inner_inp_from_outer_out"
][
outer_oidx
]
inner_oidxs
=
var_mappings
[
"inner_out_from_outer_out"
][
outer_oidx
]
for
(
inner_iidx
,
inner_oidx
)
in
itertools
.
product
(
inner_iidxs
,
inner_oidxs
):
type_input
=
self
.
inputs
[
inner_iidx
]
.
type
type_output
=
self
.
outputs
[
inner_oidx
]
.
type
if
type_input
!=
type_output
:
raise
TypeError
(
"Inconsistency in the inner graph of "
f
"scan '{self.name}' : an input and an output are "
"associated with the same recurrent state "
"and should have the same type but have "
f
"type '{type_input}' and '{type_output}' respectively."
)
# If scan has the flag 'gpua' set to false (meaning that is shouldn't
# use the gpuarray gpu backend ), ensure that is has no input and no
# output with type GpuArrayType
from
aesara.gpuarray
import
GpuArrayType
if
not
self
.
info
.
gpua
:
for
inp
in
self
.
inputs
:
if
isinstance
(
inp
.
type
,
GpuArrayType
):
raise
TypeError
(
"Inconsistency in the inner graph of "
f
"scan '{self.name}' : one of the inputs to the "
"inner graph is of type GpuArrayType but "
"the attributes of the scan op indicate "
"that it shouldn't be the case"
)
for
out
in
self
.
outputs
:
if
isinstance
(
out
.
type
,
GpuArrayType
):
raise
TypeError
(
"Inconsistency in the inner graph of "
f
"scan '{self.name}' : one of the outputs to the "
"inner graph is of type GpuArrayType but "
"the attributes of the scan op indicate "
"that it shouldn't be the case"
)
class
Scan
(
Op
,
ScanMethodsMixin
):
def
__init__
(
def
__init__
(
self
,
self
,
inputs
:
List
[
Variable
],
inputs
:
List
[
Variable
],
...
@@ -242,73 +618,9 @@ class Scan(Op):
...
@@ -242,73 +618,9 @@ class Scan(Op):
)
)
self
.
_hash_inner_graph
=
hash
(
self
.
_cmodule_key
)
self
.
_hash_inner_graph
=
hash
(
self
.
_cmodule_key
)
# Compute mappings between outer inputs, outer outputs, inner
# inputs and inner outputs to determine with variables are associated
# with the same states.
self
.
var_mappings
=
self
.
get_oinp_iinp_iout_oout_mappings
()
def
validate_inner_graph
(
self
):
"""
Perform some elementary validations on the inner graph to ensure
that it is coherent.
"""
# For every recurrent output, iterate over the associated inner
# inputs and output and ensure that they have the same dtype
nb_recurr_outputs
=
self
.
n_mit_mot
+
self
.
n_mit_sot
+
self
.
n_sit_sot
for
outer_oidx
in
range
(
nb_recurr_outputs
):
inner_iidxs
=
self
.
var_mappings
[
"inner_inp_from_outer_out"
][
outer_oidx
]
inner_oidxs
=
self
.
var_mappings
[
"inner_out_from_outer_out"
][
outer_oidx
]
for
(
inner_iidx
,
inner_oidx
)
in
itertools
.
product
(
inner_iidxs
,
inner_oidxs
):
type_input
=
self
.
inputs
[
inner_iidx
]
.
type
type_output
=
self
.
outputs
[
inner_oidx
]
.
type
if
type_input
!=
type_output
:
raise
TypeError
(
"Inconsistency in the inner graph of "
f
"scan '{self.name}' : an input and an output are "
"associated with the same recurrent state "
"and should have the same type but have "
f
"type '{type_input}' and '{type_output}' respectively."
)
# If scan has the flag 'gpua' set to false (meaning that is shouldn't
# use the gpuarray gpu backend ), ensure that is has no input and no
# output with type GpuArrayType
from
aesara.gpuarray
import
GpuArrayType
if
not
self
.
info
.
gpua
:
for
inp
in
self
.
inputs
:
if
isinstance
(
inp
.
type
,
GpuArrayType
):
raise
TypeError
(
"Inconsistency in the inner graph of "
f
"scan '{self.name}' : one of the inputs to the "
"inner graph is of type GpuArrayType but "
"the attributes of the scan op indicate "
"that it shouldn't be the case"
)
for
out
in
self
.
outputs
:
if
isinstance
(
out
.
type
,
GpuArrayType
):
raise
TypeError
(
"Inconsistency in the inner graph of "
f
"scan '{self.name}' : one of the outputs to the "
"inner graph is of type GpuArrayType but "
"the attributes of the scan op indicate "
"that it shouldn't be the case"
)
def
__setstate__
(
self
,
d
):
def
__setstate__
(
self
,
d
):
self
.
__dict__
.
update
(
d
)
self
.
__dict__
.
update
(
d
)
if
not
hasattr
(
self
,
"var_mappings"
):
# Generate the mappings between inner and outer inputs and outputs
# if they haven't already been generated.
self
.
var_mappings
=
self
.
get_oinp_iinp_iout_oout_mappings
()
if
hasattr
(
self
,
"fn"
):
if
hasattr
(
self
,
"fn"
):
if
not
hasattr
(
self
,
"thunk_mit_mot_out_slices"
):
if
not
hasattr
(
self
,
"thunk_mit_mot_out_slices"
):
# The thunk has been compiled before mit_mot preallocation
# The thunk has been compiled before mit_mot preallocation
...
@@ -1010,222 +1322,66 @@ class Scan(Op):
...
@@ -1010,222 +1322,66 @@ class Scan(Op):
if
self
.
destroy_map
:
if
self
.
destroy_map
:
cython_destroy_map
=
[
cython_destroy_map
=
[
x
in
self
.
destroy_map
for
x
in
range
(
len
(
node
.
outputs
))
x
in
self
.
destroy_map
for
x
in
range
(
len
(
node
.
outputs
))
]
]
else
:
else
:
cython_destroy_map
=
[
0
for
x
in
range
(
len
(
node
.
outputs
))]
cython_destroy_map
=
[
0
for
x
in
range
(
len
(
node
.
outputs
))]
cython_destroy_map
=
np
.
asarray
(
cython_destroy_map
,
dtype
=
"int32"
)
cython_destroy_map
=
np
.
asarray
(
cython_destroy_map
,
dtype
=
"int32"
)
from
.
import
scan_perform_ext
from
.
import
scan_perform_ext
def
p
(
node
,
args
,
outs
):
return
scan_perform_ext
.
perform
(
self
.
n_shared_outs
,
self
.
n_mit_mot_outs
,
self
.
n_seqs
,
self
.
n_mit_mot
,
self
.
n_mit_sot
,
self
.
n_sit_sot
,
self
.
n_nit_sot
,
args
[
0
],
self
.
as_while
,
cython_mintaps
,
cython_tap_array
,
cython_tap_array_len
,
cython_vector_seqs
,
cython_vector_outs
,
cython_mit_mot_out_slices
,
cython_mit_mot_out_nslices
,
cython_mitmots_preallocated
,
cython_inps_is_tensor
,
cython_outs_is_tensor
,
self
.
fn
.
fn
,
self
.
fn
,
cython_destroy_map
,
args
,
outs
,
self
,
node
,
)
except
(
ImportError
,
MissingGXX
):
p
=
self
.
perform
# default arguments are stored in the closure of `rval`
# Big ugly hack since we can't get the real value of allow_gc
# for the englobing function.
allow_gc
=
config
.
allow_gc
and
not
self
.
allow_gc
def
rval
(
p
=
p
,
i
=
node_input_storage
,
o
=
node_output_storage
,
n
=
node
,
allow_gc
=
allow_gc
):
r
=
p
(
n
,
[
x
[
0
]
for
x
in
i
],
o
)
for
o
in
node
.
outputs
:
compute_map
[
o
][
0
]
=
True
if
allow_gc
:
self
.
fn
.
free
()
return
r
rval
.
inputs
=
node_input_storage
rval
.
outputs
=
node_output_storage
rval
.
perform
=
p
rval
.
lazy
=
False
return
rval
def
inner_seqs
(
self
,
list_inputs
):
# Given the list of inner inputs this function grabs those
# corresponding to sequences
return
list_inputs
[:
self
.
n_seqs
]
def
outer_seqs
(
self
,
list_inputs
):
if
isinstance
(
list_inputs
,
Apply
):
list_inputs
=
list_inputs
.
inputs
# Given the list of outer inputs this function grabs those
# corresponding to sequences
return
list_inputs
[
1
:
1
+
self
.
n_seqs
]
def
inner_mitmot
(
self
,
list_inputs
):
n_taps
=
sum
(
len
(
x
)
for
x
in
self
.
tap_array
[:
self
.
n_mit_mot
])
return
list_inputs
[
self
.
n_seqs
:
self
.
n_seqs
+
n_taps
]
def
outer_mitmot
(
self
,
list_inputs
):
if
isinstance
(
list_inputs
,
Apply
):
list_inputs
=
list_inputs
.
inputs
return
list_inputs
[
1
+
self
.
n_seqs
:
1
+
self
.
n_seqs
+
self
.
n_mit_mot
]
def
inner_mitmot_outs
(
self
,
list_outputs
):
n_taps
=
sum
(
len
(
x
)
for
x
in
self
.
mit_mot_out_slices
)
return
list_outputs
[:
n_taps
]
def
outer_mitmot_outs
(
self
,
list_outputs
):
if
isinstance
(
list_outputs
,
Apply
):
list_outputs
=
list_outputs
.
outputs
return
list_outputs
[:
self
.
n_mit_mot
]
def
mitmot_taps
(
self
):
return
self
.
tap_array
[:
self
.
n_mit_mot
]
def
mitmot_out_taps
(
self
):
return
self
.
mit_mot_out_slices
[:
self
.
n_mit_mot
]
def
inner_mitsot
(
self
,
list_inputs
):
n_mitmot_taps
=
sum
(
len
(
x
)
for
x
in
self
.
tap_array
[:
self
.
n_mit_mot
])
ntaps_upto_sit_sot
=
sum
(
len
(
x
)
for
x
in
self
.
tap_array
[:
(
self
.
n_mit_mot
+
self
.
n_mit_sot
)]
)
return
list_inputs
[
self
.
n_seqs
+
n_mitmot_taps
:
self
.
n_seqs
+
ntaps_upto_sit_sot
]
def
outer_mitsot
(
self
,
list_inputs
):
if
isinstance
(
list_inputs
,
Apply
):
list_inputs
=
list_inputs
.
inputs
offset
=
1
+
self
.
n_seqs
+
self
.
n_mit_mot
return
list_inputs
[
offset
:
offset
+
self
.
n_mit_sot
]
def
inner_mitsot_outs
(
self
,
list_outputs
):
n_taps
=
sum
(
len
(
x
)
for
x
in
self
.
mit_mot_out_slices
)
return
list_outputs
[
n_taps
:
n_taps
+
self
.
n_mit_sot
]
def
outer_mitsot_outs
(
self
,
list_outputs
):
if
isinstance
(
list_outputs
,
Apply
):
list_outputs
=
list_outputs
.
outputs
return
list_outputs
[
self
.
n_mit_mot
:
self
.
n_mit_mot
+
self
.
n_mit_sot
]
def
mitsot_taps
(
self
):
return
self
.
tap_array
[
self
.
n_mit_mot
:
self
.
n_mit_mot
+
self
.
n_mit_sot
]
def
inner_sitsot
(
self
,
list_inputs
):
n_taps_upto_sit_sot
=
sum
(
len
(
x
)
for
x
in
self
.
tap_array
[:
(
self
.
n_mit_mot
+
self
.
n_mit_sot
)]
)
offset
=
self
.
n_seqs
+
n_taps_upto_sit_sot
return
list_inputs
[
offset
:
offset
+
self
.
n_sit_sot
]
def
outer_sitsot
(
self
,
list_inputs
):
if
isinstance
(
list_inputs
,
Apply
):
list_inputs
=
list_inputs
.
inputs
offset
=
1
+
self
.
n_seqs
+
self
.
n_mit_mot
+
self
.
n_mit_sot
return
list_inputs
[
offset
:
offset
+
self
.
n_sit_sot
]
def
inner_sitsot_outs
(
self
,
list_outputs
):
n_taps
=
sum
(
len
(
x
)
for
x
in
self
.
mit_mot_out_slices
)
offset
=
self
.
n_mit_sot
+
n_taps
return
list_outputs
[
offset
:
offset
+
self
.
n_sit_sot
]
def
outer_sitsot_outs
(
self
,
list_outputs
):
if
isinstance
(
list_outputs
,
Apply
):
list_outputs
=
list_outputs
.
outputs
offset
=
self
.
n_mit_mot
+
self
.
n_mit_sot
return
list_outputs
[
offset
:
offset
+
self
.
n_sit_sot
]
def
outer_nitsot
(
self
,
list_inputs
):
if
isinstance
(
list_inputs
,
Apply
):
list_inputs
=
list_inputs
.
inputs
offset
=
(
1
+
self
.
n_seqs
+
self
.
n_mit_mot
+
self
.
n_mit_sot
+
self
.
n_sit_sot
+
self
.
n_shared_outs
)
return
list_inputs
[
offset
:
offset
+
self
.
n_nit_sot
]
def
inner_nitsot_outs
(
self
,
list_outputs
):
n_taps
=
sum
(
len
(
x
)
for
x
in
self
.
mit_mot_out_slices
)
offset
=
self
.
n_mit_sot
+
n_taps
+
self
.
n_sit_sot
return
list_outputs
[
offset
:
offset
+
self
.
n_nit_sot
]
def
outer_nitsot_outs
(
self
,
list_outputs
):
if
isinstance
(
list_outputs
,
Apply
):
list_outputs
=
list_outputs
.
outputs
offset
=
self
.
n_mit_mot
+
self
.
n_mit_sot
+
self
.
n_sit_sot
return
list_outputs
[
offset
:
offset
+
self
.
n_nit_sot
]
def
inner_shared
(
self
,
list_inputs
):
def
p
(
node
,
args
,
outs
):
n_taps_upto_sit_sot
=
sum
(
return
scan_perform_ext
.
perform
(
len
(
x
)
for
x
in
self
.
tap_array
[:
(
self
.
n_mit_mot
+
self
.
n_mit_sot
)]
self
.
n_shared_outs
,
)
self
.
n_mit_mot_outs
,
offset
=
self
.
n_seqs
+
n_taps_upto_sit_sot
+
self
.
n_sit_sot
self
.
n_seqs
,
return
list_inputs
[
offset
:
offset
+
self
.
n_shared_outs
]
self
.
n_mit_mot
,
self
.
n_mit_sot
,
self
.
n_sit_sot
,
self
.
n_nit_sot
,
args
[
0
],
self
.
as_while
,
cython_mintaps
,
cython_tap_array
,
cython_tap_array_len
,
cython_vector_seqs
,
cython_vector_outs
,
cython_mit_mot_out_slices
,
cython_mit_mot_out_nslices
,
cython_mitmots_preallocated
,
cython_inps_is_tensor
,
cython_outs_is_tensor
,
self
.
fn
.
fn
,
self
.
fn
,
cython_destroy_map
,
args
,
outs
,
self
,
node
,
)
def
outer_shared
(
self
,
list_inputs
):
except
(
ImportError
,
MissingGXX
):
if
isinstance
(
list_inputs
,
Apply
):
p
=
self
.
perform
list_inputs
=
list_inputs
.
inputs
offset
=
1
+
self
.
n_seqs
+
self
.
n_mit_mot
+
self
.
n_mit_sot
+
self
.
n_sit_sot
return
list_inputs
[
offset
:
offset
+
self
.
n_shared_outs
]
def
inner_shared_outs
(
self
,
list_outputs
):
# default arguments are stored in the closure of `rval`
n_taps
=
sum
(
len
(
x
)
for
x
in
self
.
mit_mot_out_slices
)
offset
=
self
.
n_mit_sot
+
n_taps
+
self
.
n_sit_sot
+
self
.
n_nit_sot
return
list_outputs
[
offset
:
offset
+
self
.
n_shared_outs
]
def
outer_shared_outs
(
self
,
list_outputs
):
# Big ugly hack since we can't get the real value of allow_gc
if
isinstance
(
list_outputs
,
Apply
):
# for the englobing function.
list_outputs
=
list_outputs
.
outputs
allow_gc
=
config
.
allow_gc
and
not
self
.
allow_gc
offset
=
self
.
n_mit_mot
+
self
.
n_mit_sot
+
self
.
n_sit_sot
+
self
.
n_nit_sot
return
list_outputs
[
offset
:
offset
+
self
.
n_shared_outs
]
def
inner_non_seqs
(
self
,
list_inputs
):
def
rval
(
n_taps_upto_sit_sot
=
sum
(
p
=
p
,
i
=
node_input_storage
,
o
=
node_output_storage
,
n
=
node
,
allow_gc
=
allow_gc
len
(
x
)
for
x
in
self
.
tap_array
[:
(
self
.
n_mit_mot
+
self
.
n_mit_sot
)]
):
)
r
=
p
(
n
,
[
x
[
0
]
for
x
in
i
],
o
)
offset
=
self
.
n_seqs
+
n_taps_upto_sit_sot
+
self
.
n_sit_sot
+
self
.
n_shared_outs
for
o
in
node
.
outputs
:
return
list_inputs
[
offset
:]
compute_map
[
o
][
0
]
=
True
if
allow_gc
:
self
.
fn
.
free
()
return
r
def
outer_non_seqs
(
self
,
list_inputs
):
rval
.
inputs
=
node_input_storage
if
isinstance
(
list_inputs
,
Apply
):
rval
.
outputs
=
node_output_storage
list_inputs
=
list_inputs
.
inputs
rval
.
perform
=
p
offset
=
(
rval
.
lazy
=
False
1
return
rval
+
self
.
n_seqs
+
self
.
n_mit_mot
+
self
.
n_mit_sot
+
self
.
n_sit_sot
+
self
.
n_nit_sot
+
self
.
n_shared_outs
)
return
list_inputs
[
offset
:]
def
perform
(
self
,
node
,
inputs
,
output_storage
,
params
=
None
):
def
perform
(
self
,
node
,
inputs
,
output_storage
,
params
=
None
):
"""Compute the scan operation in Python.
"""Compute the scan operation in Python.
...
@@ -1885,11 +2041,13 @@ class Scan(Op):
...
@@ -1885,11 +2041,13 @@ class Scan(Op):
# over every possible pairing of their corresponding inner inputs
# over every possible pairing of their corresponding inner inputs
# and inner outputs and, if one such pair of inner variables is
# and inner outputs and, if one such pair of inner variables is
# connected than the pair of outer variables is connected.
# connected than the pair of outer variables is connected.
var_mappings
=
self
.
get_oinp_iinp_iout_oout_mappings
()
for
outer_oidx
in
range
(
len
(
node
.
outputs
)):
for
outer_oidx
in
range
(
len
(
node
.
outputs
)):
inner_oidxs
=
self
.
var_mappings
[
"inner_out_from_outer_out"
][
outer_oidx
]
inner_oidxs
=
var_mappings
[
"inner_out_from_outer_out"
][
outer_oidx
]
for
outer_iidx
in
range
(
len
(
node
.
inputs
)):
for
outer_iidx
in
range
(
len
(
node
.
inputs
)):
inner_iidxs
=
self
.
var_mappings
[
"inner_inp_from_outer_inp"
][
outer_iidx
]
inner_iidxs
=
var_mappings
[
"inner_inp_from_outer_inp"
][
outer_iidx
]
for
inner_oidx
in
inner_oidxs
:
for
inner_oidx
in
inner_oidxs
:
for
inner_iidx
in
inner_iidxs
:
for
inner_iidx
in
inner_iidxs
:
...
@@ -1913,7 +2071,7 @@ class Scan(Op):
...
@@ -1913,7 +2071,7 @@ class Scan(Op):
# Get the idx of the outer input corresponding to that
# Get the idx of the outer input corresponding to that
# outer output
# outer output
j_inp_idx
=
self
.
var_mappings
[
"outer_inp_from_outer_out"
][
jidx
]
j_inp_idx
=
var_mappings
[
"outer_inp_from_outer_out"
][
jidx
]
if
j_inp_idx
!=
-
1
:
if
j_inp_idx
!=
-
1
:
if
connection_pattern
[
j_inp_idx
][
iidx
]
is
True
:
if
connection_pattern
[
j_inp_idx
][
iidx
]
is
True
:
...
@@ -1924,168 +2082,6 @@ class Scan(Op):
...
@@ -1924,168 +2082,6 @@ class Scan(Op):
node
.
tag
.
connection_pattern
=
connection_pattern
node
.
tag
.
connection_pattern
=
connection_pattern
return
connection_pattern
return
connection_pattern
def
get_oinp_iinp_iout_oout_mappings
(
self
):
"""
Compute and return dictionary mappings between the inputs and
outputs of the inner function and the inputs and outputs of the Scan
node in the outer graph.
The return value is a dictionary in which the keys are the names of
the individual mappings and the values are the mapping dictionaries
themselves. In dictionaries representing mappings to outer variables,
the values are individual integer indices. In dictionaries
representing mappings to inner variables, the values are sequences of
indices because multiple inner variables can be associated with the
same state.
"""
# Lists for outer variables contain individual indices, lists for
# inner variables contain sequences of indices because many inner
# variables can be associated with the same outer variable. The list
# and indices are initialized already containing the data associated
# with the timestep index, the first outer input.
outer_input_indices
=
[
0
]
inner_input_indices
=
[[]]
inner_output_indices
=
[[]]
outer_output_indices
=
[
-
1
]
outer_iidx
=
1
inner_iidx
=
0
inner_oidx
=
0
outer_oidx
=
0
# Handle sequences inputs
for
i
in
range
(
self
.
info
.
n_seqs
):
outer_input_indices
.
append
(
outer_iidx
)
inner_input_indices
.
append
([
inner_iidx
])
inner_output_indices
.
append
([])
outer_output_indices
.
append
(
-
1
)
outer_iidx
+=
1
inner_iidx
+=
1
inner_oidx
+=
0
outer_oidx
+=
0
# Handle mitmots, mitsots and sitsots variables
for
i
in
range
(
len
(
self
.
info
.
tap_array
)):
nb_input_taps
=
len
(
self
.
info
.
tap_array
[
i
])
if
i
<
self
.
n_mit_mot
:
nb_output_taps
=
len
(
self
.
mit_mot_out_slices
[
i
])
else
:
nb_output_taps
=
1
outer_input_indices
.
append
(
outer_iidx
)
inner_input_indices
.
append
(
list
(
range
(
inner_iidx
,
inner_iidx
+
nb_input_taps
))
)
inner_output_indices
.
append
(
list
(
range
(
inner_oidx
,
inner_oidx
+
nb_output_taps
))
)
outer_output_indices
.
append
(
outer_oidx
)
outer_iidx
+=
1
inner_iidx
+=
nb_input_taps
inner_oidx
+=
nb_output_taps
outer_oidx
+=
1
# This is needed because, for outer inputs (and for outer inputs only)
# nitsots come *after* shared variables.
outer_iidx
+=
self
.
info
.
n_shared_outs
# Handle nitsots variables
for
i
in
range
(
self
.
n_nit_sot
):
outer_input_indices
.
append
(
outer_iidx
)
inner_input_indices
.
append
([])
inner_output_indices
.
append
([
inner_oidx
])
outer_output_indices
.
append
(
outer_oidx
)
outer_iidx
+=
1
inner_iidx
+=
0
inner_oidx
+=
1
outer_oidx
+=
1
# This is needed because, for outer inputs (and for outer inputs only)
# nitsots come *after* shared variables.
outer_iidx
-=
self
.
info
.
n_shared_outs
+
self
.
n_nit_sot
# Handle shared states
for
i
in
range
(
self
.
info
.
n_shared_outs
):
outer_input_indices
.
append
(
outer_iidx
)
inner_input_indices
.
append
([
inner_iidx
])
inner_output_indices
.
append
([
inner_oidx
])
outer_output_indices
.
append
(
outer_oidx
)
outer_iidx
+=
1
inner_iidx
+=
1
inner_oidx
+=
1
outer_oidx
+=
1
# This is needed because, for outer inputs (and for outer inputs only)
# nitsots come *after* shared variables.
outer_iidx
+=
self
.
n_nit_sot
# Handle non-sequence inputs
# Note : the number of non-sequence inputs is not stored in self.info
# so it has to be inferred from the number of inner inputs that remain
# to be handled
for
i
in
range
(
len
(
self
.
inputs
)
-
inner_iidx
):
outer_input_indices
.
append
(
outer_iidx
)
inner_input_indices
.
append
([
inner_iidx
])
inner_output_indices
.
append
([])
outer_output_indices
.
append
(
-
1
)
outer_iidx
+=
1
inner_iidx
+=
1
inner_oidx
+=
0
outer_oidx
+=
0
# With the global mapping inferred, the individual mappings
# can be produced
mappings
=
{
"outer_inp_from_outer_out"
:
{},
"inner_inp_from_outer_out"
:
{},
"inner_out_from_outer_out"
:
{},
"inner_inp_from_outer_inp"
:
{},
"inner_out_from_outer_inp"
:
{},
"outer_out_from_outer_inp"
:
{},
"outer_inp_from_inner_inp"
:
{},
"inner_out_from_inner_inp"
:
{},
"outer_out_from_inner_inp"
:
{},
"outer_inp_from_inner_out"
:
{},
"inner_inp_from_inner_out"
:
{},
"outer_out_from_inner_out"
:
{},
}
for
(
oinp
,
iinp
,
iout
,
oout
)
in
zip
(
outer_input_indices
,
inner_input_indices
,
inner_output_indices
,
outer_output_indices
,
):
if
oout
!=
-
1
:
mappings
[
"outer_inp_from_outer_out"
][
oout
]
=
oinp
mappings
[
"inner_inp_from_outer_out"
][
oout
]
=
iinp
mappings
[
"inner_out_from_outer_out"
][
oout
]
=
iout
if
oinp
!=
-
1
:
mappings
[
"inner_inp_from_outer_inp"
][
oinp
]
=
iinp
mappings
[
"inner_out_from_outer_inp"
][
oinp
]
=
iout
mappings
[
"outer_out_from_outer_inp"
][
oinp
]
=
oout
for
idx
in
iinp
:
mappings
[
"outer_inp_from_inner_inp"
][
idx
]
=
oinp
mappings
[
"inner_out_from_inner_inp"
][
idx
]
=
iout
mappings
[
"outer_out_from_inner_inp"
][
idx
]
=
oout
for
idx
in
iout
:
mappings
[
"outer_inp_from_inner_out"
][
idx
]
=
oinp
mappings
[
"inner_inp_from_inner_out"
][
idx
]
=
iinp
mappings
[
"outer_out_from_inner_out"
][
idx
]
=
oout
return
mappings
def
L_op
(
self
,
inputs
,
outs
,
dC_douts
):
def
L_op
(
self
,
inputs
,
outs
,
dC_douts
):
if
not
isinstance
(
outs
,
(
list
,
tuple
)):
if
not
isinstance
(
outs
,
(
list
,
tuple
)):
outs
=
[
outs
]
outs
=
[
outs
]
...
@@ -2217,6 +2213,7 @@ class Scan(Op):
...
@@ -2217,6 +2213,7 @@ class Scan(Op):
rval
=
[
gmp
.
get
(
p
,
None
)
for
p
in
diff_inputs
]
rval
=
[
gmp
.
get
(
p
,
None
)
for
p
in
diff_inputs
]
return
rval
return
rval
var_mappings
=
self
.
get_oinp_iinp_iout_oout_mappings
()
dC_dinps_t
=
[
None
for
inp
in
diff_inputs
]
dC_dinps_t
=
[
None
for
inp
in
diff_inputs
]
disconnected_dC_dinps_t
=
[
True
for
inp
in
diff_inputs
]
disconnected_dC_dinps_t
=
[
True
for
inp
in
diff_inputs
]
dC_dXts
=
[]
dC_dXts
=
[]
...
@@ -2254,7 +2251,7 @@ class Scan(Op):
...
@@ -2254,7 +2251,7 @@ class Scan(Op):
if
inp
in
graph_inputs
([
Xt
]):
if
inp
in
graph_inputs
([
Xt
]):
# Get the index of the outer output to which
# Get the index of the outer output to which
# the state variable 'inp' corresponds.
# the state variable 'inp' corresponds.
outer_oidx
=
self
.
var_mappings
[
"outer_out_from_inner_inp"
][
outer_oidx
=
var_mappings
[
"outer_out_from_inner_inp"
][
self
.
n_seqs
+
pos
self
.
n_seqs
+
pos
]
]
...
@@ -2307,7 +2304,7 @@ class Scan(Op):
...
@@ -2307,7 +2304,7 @@ class Scan(Op):
# Get the index of the first inner input corresponding to the
# Get the index of the first inner input corresponding to the
# pos-th inner input state
# pos-th inner input state
idxs
=
self
.
var_mappings
[
"inner_out_from_inner_inp"
][
self
.
n_seqs
+
pos
]
idxs
=
var_mappings
[
"inner_out_from_inner_inp"
][
self
.
n_seqs
+
pos
]
# Check if the pos-th input is associated with one of the
# Check if the pos-th input is associated with one of the
# recurrent states
# recurrent states
...
...
aesara/scan/utils.py
浏览文件 @
9a2280b8
...
@@ -1069,9 +1069,9 @@ class ScanArgs:
...
@@ -1069,9 +1069,9 @@ class ScanArgs:
@property
@property
def
var_mappings
(
self
):
def
var_mappings
(
self
):
from
aesara.scan.op
import
Scan
from
aesara.scan.op
import
Scan
MethodsMixin
return
Scan
.
get_oinp_iinp_iout_oout_mappings
(
self
)
return
Scan
MethodsMixin
.
get_oinp_iinp_iout_oout_mappings
(
self
)
@property
@property
def
field_names
(
self
):
def
field_names
(
self
):
...
...
doc/extending/scan.rst
浏览文件 @
9a2280b8
...
@@ -299,7 +299,7 @@ If the goal is to navigate between variables that are associated with the same
...
@@ -299,7 +299,7 @@ If the goal is to navigate between variables that are associated with the same
states (ex : going from an outer sequence input to the corresponding inner
states (ex : going from an outer sequence input to the corresponding inner
sequence input, going from an inner output associated with a recurrent state
sequence input, going from an inner output associated with a recurrent state
to the inner input(s) associated with that same recurrent state, etc.), then
to the inner input(s) associated with that same recurrent state, etc.), then
the `
`var_mappings`` attribute of the scan op
can be used.
the ``get_oinp_iinp_iout_oout_mappings`` method of the ``Scan`` ``Op``
can be used.
This method returns a dictionary with 12 {key/value} pairs. The keys are listed
This method returns a dictionary with 12 {key/value} pairs. The keys are listed
below :
below :
...
...
tests/scan/test_basic.py
浏览文件 @
9a2280b8
...
@@ -700,11 +700,12 @@ class TestScan:
...
@@ -700,11 +700,12 @@ class TestScan:
# outer_inp_from_inner_inp produce the correct results
# outer_inp_from_inner_inp produce the correct results
scan_node
=
a
.
owner
.
inputs
[
0
]
.
owner
scan_node
=
a
.
owner
.
inputs
[
0
]
.
owner
result
=
scan_node
.
op
.
var_mappings
[
"outer_inp_from_outer_out"
]
var_mappings
=
scan_node
.
op
.
get_oinp_iinp_iout_oout_mappings
()
result
=
var_mappings
[
"outer_inp_from_outer_out"
]
expected_result
=
{
0
:
1
,
1
:
2
}
expected_result
=
{
0
:
1
,
1
:
2
}
assert
result
==
expected_result
assert
result
==
expected_result
result
=
scan_node
.
op
.
var_mappings
[
"outer_inp_from_inner_inp"
]
result
=
var_mappings
[
"outer_inp_from_inner_inp"
]
expected_result
=
{
0
:
1
,
1
:
1
,
2
:
2
,
3
:
2
}
expected_result
=
{
0
:
1
,
1
:
1
,
2
:
2
,
3
:
2
}
assert
result
==
expected_result
assert
result
==
expected_result
...
@@ -733,11 +734,12 @@ class TestScan:
...
@@ -733,11 +734,12 @@ class TestScan:
# outer_inp_from_inner_inp produce the correct results
# outer_inp_from_inner_inp produce the correct results
scan_node
=
out
.
owner
.
inputs
[
0
]
.
owner
scan_node
=
out
.
owner
.
inputs
[
0
]
.
owner
result
=
scan_node
.
op
.
var_mappings
[
"outer_inp_from_outer_out"
]
var_mappings
=
scan_node
.
op
.
get_oinp_iinp_iout_oout_mappings
()
result
=
var_mappings
[
"outer_inp_from_outer_out"
]
expected_result
=
{
0
:
2
}
expected_result
=
{
0
:
2
}
assert
result
==
expected_result
assert
result
==
expected_result
result
=
scan_node
.
op
.
var_mappings
[
"outer_inp_from_inner_inp"
]
result
=
var_mappings
[
"outer_inp_from_inner_inp"
]
expected_result
=
{
0
:
1
,
1
:
2
,
2
:
2
}
expected_result
=
{
0
:
1
,
1
:
2
,
2
:
2
}
assert
result
==
expected_result
assert
result
==
expected_result
...
@@ -1685,11 +1687,12 @@ class TestScan:
...
@@ -1685,11 +1687,12 @@ class TestScan:
# outer_inp_from_inner_inp produce the correct results
# outer_inp_from_inner_inp produce the correct results
scan_node
=
list
(
updates
.
values
())[
0
]
.
owner
scan_node
=
list
(
updates
.
values
())[
0
]
.
owner
result
=
scan_node
.
op
.
var_mappings
[
"outer_inp_from_outer_out"
]
var_mappings
=
scan_node
.
op
.
get_oinp_iinp_iout_oout_mappings
()
result
=
var_mappings
[
"outer_inp_from_outer_out"
]
expected_result
=
{
0
:
3
,
1
:
5
,
2
:
4
}
expected_result
=
{
0
:
3
,
1
:
5
,
2
:
4
}
assert
result
==
expected_result
assert
result
==
expected_result
result
=
scan_node
.
op
.
var_mappings
[
"outer_inp_from_inner_inp"
]
result
=
var_mappings
[
"outer_inp_from_inner_inp"
]
expected_result
=
{
0
:
1
,
1
:
2
,
2
:
3
,
3
:
4
,
4
:
6
}
expected_result
=
{
0
:
1
,
1
:
2
,
2
:
3
,
3
:
4
,
4
:
6
}
assert
result
==
expected_result
assert
result
==
expected_result
...
@@ -3491,7 +3494,7 @@ class TestScan:
...
@@ -3491,7 +3494,7 @@ class TestScan:
# Compare the mappings with the expected values
# Compare the mappings with the expected values
scan_node
=
scan_outputs
[
0
]
.
owner
.
inputs
[
0
]
.
owner
scan_node
=
scan_outputs
[
0
]
.
owner
.
inputs
[
0
]
.
owner
mappings
=
scan_node
.
op
.
var_mappings
mappings
=
scan_node
.
op
.
get_oinp_iinp_iout_oout_mappings
()
assert
mappings
[
"inner_inp_from_outer_inp"
]
==
{
assert
mappings
[
"inner_inp_from_outer_inp"
]
==
{
0
:
[],
0
:
[],
...
...
tests/scan/test_utils.py
浏览文件 @
9a2280b8
...
@@ -253,7 +253,6 @@ def test_ScanArgs():
...
@@ -253,7 +253,6 @@ def test_ScanArgs():
# here we make sure it doesn't (and that all the inputs are the same)
# here we make sure it doesn't (and that all the inputs are the same)
assert
scan_args
.
inputs
==
scan_op
.
inputs
assert
scan_args
.
inputs
==
scan_op
.
inputs
assert
scan_args
.
info
==
scan_op
.
info
assert
scan_args
.
info
==
scan_op
.
info
assert
scan_args
.
var_mappings
==
scan_op
.
var_mappings
# Check that `ScanArgs.find_among_fields` works
# Check that `ScanArgs.find_among_fields` works
test_v
=
scan_op
.
inner_seqs
(
scan_op
.
inputs
)[
1
]
test_v
=
scan_op
.
inner_seqs
(
scan_op
.
inputs
)[
1
]
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论