Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
1d19c375
提交
1d19c375
authored
10月 29, 2025
作者:
Ricardo Vieira
提交者:
Ricardo Vieira
11月 08, 2025
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Allow non-shared untraced SIT-SOT
上级
d10c61ba
隐藏空白字符变更
内嵌
并排
正在显示
8 个修改的文件
包含
362 行增加
和
221 行删除
+362
-221
scan.py
pytensor/link/jax/dispatch/scan.py
+12
-12
scan.py
pytensor/link/numba/dispatch/scan.py
+9
-9
basic.py
pytensor/scan/basic.py
+89
-39
op.py
pytensor/scan/op.py
+170
-126
rewriting.py
pytensor/scan/rewriting.py
+25
-17
utils.py
pytensor/scan/utils.py
+22
-16
test_scan.py
tests/link/numba/test_scan.py
+1
-1
test_basic.py
tests/scan/test_basic.py
+34
-1
没有找到文件。
pytensor/link/jax/dispatch/scan.py
浏览文件 @
1d19c375
...
...
@@ -60,23 +60,23 @@ def jax_funcify_Scan(op: Scan, **kwargs):
mit_mot_init
,
mit_sot_init
,
sit_sot_init
,
op
.
outer_
shared
(
outer_inputs
),
op
.
outer_
untraced_sit_sot
(
outer_inputs
),
op
.
outer_non_seqs
(
outer_inputs
),
)
# JAX `init`
def
jax_args_to_inner_func_args
(
carry
,
x
):
"""Convert JAX scan arguments into format expected by scan_inner_func.
scan(carry, x) -> scan_inner_func(seqs, MIT-SOT, SIT-SOT,
shared
, non_seqs)
scan(carry, x) -> scan_inner_func(seqs, MIT-SOT, SIT-SOT,
untraced SIT-SOT
, non_seqs)
"""
# `carry` contains all inner taps
, shared terms,
and non_seqs
# `carry` contains all inner taps and non_seqs
(
i
,
inner_mit_mot
,
inner_mit_sot
,
inner_sit_sot
,
inner_
shared
,
inner_
untraced_sit_sot
,
inner_non_seqs
,
)
=
carry
...
...
@@ -108,7 +108,7 @@ def jax_funcify_Scan(op: Scan, **kwargs):
*
mit_mot_flatten
,
*
mit_sot_flatten
,
*
inner_sit_sot
,
*
inner_
shared
,
*
inner_
untraced_sit_sot
,
*
inner_non_seqs
,
)
...
...
@@ -118,14 +118,14 @@ def jax_funcify_Scan(op: Scan, **kwargs):
):
"""Convert inner_scan_func outputs into format expected by JAX scan.
old_carry + (MIT-SOT_outs, SIT-SOT_outs, NIT-SOT_outs,
shared
_outs) -> (new_carry, ys)
old_carry + (MIT-SOT_outs, SIT-SOT_outs, NIT-SOT_outs,
untraced_SIT-SOT
_outs) -> (new_carry, ys)
"""
(
i
,
old_mit_mot
,
old_mit_sot
,
_old_sit_sot
,
_old_
shared
,
_old_
untraced_sit_sot
,
inner_non_seqs
,
)
=
old_carry
...
...
@@ -133,7 +133,7 @@ def jax_funcify_Scan(op: Scan, **kwargs):
new_mit_sot_vals
=
op
.
inner_mitsot_outs
(
inner_scan_outs
)
new_sit_sot
=
op
.
inner_sitsot_outs
(
inner_scan_outs
)
new_nit_sot
=
op
.
inner_nitsot_outs
(
inner_scan_outs
)
new_
shared
=
op
.
inner_shared
_outs
(
inner_scan_outs
)
new_
untraced_sit_sot
=
op
.
inner_untraced_sit_sot
_outs
(
inner_scan_outs
)
# New carry for next step
# Update MIT-MOT buffer at positions indicated by output taps
...
...
@@ -150,14 +150,14 @@ def jax_funcify_Scan(op: Scan, **kwargs):
old_mit_sot
,
new_mit_sot_vals
,
strict
=
True
)
]
# For SIT-SOT
, and shared
just pass along the new value
# For SIT-SOT just pass along the new value
# Non-sequences remain unchanged
new_carry
=
(
i
+
1
,
new_mit_mot
,
new_mit_sot
,
new_sit_sot
,
new_
shared
,
new_
untraced_sit_sot
,
inner_non_seqs
,
)
...
...
@@ -183,7 +183,7 @@ def jax_funcify_Scan(op: Scan, **kwargs):
final_mit_mot
,
_final_mit_sot
,
_final_sit_sot
,
final_
shared
,
final_
untraced_sit_sot
,
_final_non_seqs
,
),
traces
,
...
...
@@ -238,7 +238,7 @@ def jax_funcify_Scan(op: Scan, **kwargs):
scan_outs_final
=
[
*
final_mit_mot
,
*
get_partial_traces
(
traces
),
*
final_
shared
,
*
final_
untraced_sit_sot
,
]
if
len
(
scan_outs_final
)
==
1
:
...
...
pytensor/link/numba/dispatch/scan.py
浏览文件 @
1d19c375
...
...
@@ -108,19 +108,19 @@ def numba_funcify_Scan(op: Scan, node, **kwargs):
outer_in_mit_sot_names
=
op
.
outer_mitsot
(
outer_in_names
)
outer_in_sit_sot_names
=
op
.
outer_sitsot
(
outer_in_names
)
outer_in_nit_sot_names
=
op
.
outer_nitsot
(
outer_in_names
)
outer_in_
shared_names
=
op
.
outer_shared
(
outer_in_names
)
outer_in_
untraced_sit_sot_names
=
op
.
outer_untraced_sit_sot
(
outer_in_names
)
outer_in_non_seqs_names
=
op
.
outer_non_seqs
(
outer_in_names
)
# These are all the outer-input names that have produce outputs/have output
# taps (i.e. they have inner-outputs and corresponding outer-outputs).
# Outer-outputs are ordered as follows:
# mit-mot-outputs + mit-sot-outputs + sit-sot-outputs + nit-sots +
shared
-outputs
# mit-mot-outputs + mit-sot-outputs + sit-sot-outputs + nit-sots +
untraced-sit-sot
-outputs
outer_in_outtap_names
=
(
outer_in_mit_mot_names
+
outer_in_mit_sot_names
+
outer_in_sit_sot_names
+
outer_in_nit_sot_names
+
outer_in_
shared
_names
+
outer_in_
untraced_sit_sot
_names
)
# We create distinct variables for/references to the storage arrays for
...
...
@@ -138,8 +138,10 @@ def numba_funcify_Scan(op: Scan, node, **kwargs):
for
outer_in_name
in
outer_in_nit_sot_names
:
outer_in_to_storage_name
[
outer_in_name
]
=
f
"{outer_in_name}_nitsot_storage"
for
outer_in_name
in
outer_in_shared_names
:
outer_in_to_storage_name
[
outer_in_name
]
=
f
"{outer_in_name}_shared_storage"
for
outer_in_name
in
outer_in_untraced_sit_sot_names
:
outer_in_to_storage_name
[
outer_in_name
]
=
(
f
"{outer_in_name}_untraced_sit_sot_storage"
)
outer_output_names
=
list
(
outer_in_to_storage_name
.
values
())
assert
len
(
outer_output_names
)
==
len
(
node
.
outputs
)
...
...
@@ -147,7 +149,7 @@ def numba_funcify_Scan(op: Scan, node, **kwargs):
# Construct the inner-input expressions (e.g. indexed storage expressions)
# Inner-inputs are ordered as follows:
# sequences + mit-mot-inputs + mit-sot-inputs + sit-sot-inputs +
#
shared
-inputs + non-sequences.
#
untraced-sit-sot
-inputs + non-sequences.
temp_scalar_storage_alloc_stmts
:
list
[
str
]
=
[]
inner_in_exprs_scalar
:
list
[
str
]
=
[]
inner_in_exprs
:
list
[
str
]
=
[]
...
...
@@ -204,11 +206,9 @@ def numba_funcify_Scan(op: Scan, node, **kwargs):
# Inner-outputs consist of:
# mit-mot-outputs + mit-sot-outputs + sit-sot-outputs + nit-sots +
#
shared
-outputs [+ while-condition]
#
untraced-sit-sot
-outputs [+ while-condition]
inner_output_names
=
[
f
"inner_out_{i}"
for
i
in
range
(
len
(
op
.
inner_outputs
))]
# inner_out_shared_names = op.inner_shared_outs(inner_output_names)
# The assignment statements that copy inner-outputs into the outer-outputs
# storage
inner_out_to_outer_in_stmts
:
list
[
str
]
=
[]
...
...
pytensor/scan/basic.py
浏览文件 @
1d19c375
import
typing
import
warnings
from
itertools
import
chain
...
...
@@ -11,6 +12,7 @@ from pytensor.graph.basic import Constant, Variable
from
pytensor.graph.op
import
get_test_value
from
pytensor.graph.replace
import
clone_replace
from
pytensor.graph.traversal
import
explicit_graph_inputs
from
pytensor.graph.type
import
HasShape
from
pytensor.graph.utils
import
MissingInputError
,
TestValueError
from
pytensor.scan.op
import
Scan
,
ScanInfo
from
pytensor.scan.utils
import
expand_empty
,
safe_new
,
until
...
...
@@ -22,6 +24,10 @@ from pytensor.tensor.type import TensorType, integer_dtypes
from
pytensor.updates
import
OrderedUpdates
if
typing
.
TYPE_CHECKING
:
from
pytensor.tensor.type
import
TensorVariable
def
get_updates_and_outputs
(
ls
):
"""Recognize and order the updates, outputs, and stopping condition for a `Scan`.
...
...
@@ -469,7 +475,7 @@ def scan(
# Make sure we get rid of numpy arrays or ints or anything like that
# passed as inputs to scan
non_seqs
=
[]
non_seqs
:
list
[
Variable
]
=
[]
for
elem
in
wrap_into_list
(
non_sequences
):
if
not
isinstance
(
elem
,
Variable
):
non_seqs
.
append
(
pt
.
as_tensor_variable
(
elem
))
...
...
@@ -685,10 +691,10 @@ def scan(
# MIT_MOT -- not provided by the user only by the grad function
n_mit_mot
=
0
mit_mot_scan_inputs
=
[]
mit_mot_inner_inputs
=
[]
mit_mot_inner_outputs
=
[]
mit_mot_out_slices
=
[]
mit_mot_scan_inputs
:
list
[
TensorVariable
]
=
[]
mit_mot_inner_inputs
:
list
[
TensorVariable
]
=
[]
mit_mot_inner_outputs
:
list
[
TensorVariable
]
=
[]
mit_mot_out_slices
:
list
[
TensorVariable
]
=
[]
# SIT_SOT -- provided by the user
n_mit_sot
=
0
...
...
@@ -706,6 +712,12 @@ def scan(
sit_sot_inner_outputs
=
[]
sit_sot_rightOrder
=
[]
n_untraced_sit_sot_outs
=
0
untraced_sit_sot_scan_inputs
=
[]
untraced_sit_sot_inner_inputs
=
[]
untraced_sit_sot_inner_outputs
=
[]
untraced_sit_sot_rightOrder
=
[]
# go through outputs picking up time slices as needed
for
i
,
init_out
in
enumerate
(
outs_info
):
# Note that our convention dictates that if an output uses
...
...
@@ -741,17 +753,35 @@ def scan(
# We need now to allocate space for storing the output and copy
# the initial state over. We do this using the expand function
# defined in scan utils
sit_sot_scan_inputs
.
append
(
expand_empty
(
shape_padleft
(
actual_arg
),
actual_n_steps
,
if
isinstance
(
actual_arg
.
type
,
HasShape
):
sit_sot_scan_inputs
.
append
(
expand_empty
(
shape_padleft
(
actual_arg
),
actual_n_steps
,
)
)
)
sit_sot_inner_slices
.
append
(
actual_arg
)
sit_sot_inner_slices
.
append
(
actual_arg
)
sit_sot_inner_inputs
.
append
(
arg
)
sit_sot_rightOrder
.
append
(
i
)
n_sit_sot
+=
1
sit_sot_inner_inputs
.
append
(
arg
)
sit_sot_rightOrder
.
append
(
i
)
n_sit_sot
+=
1
else
:
# Assume variables without shape cannot be stacked (e.g., RNG variables)
# Because this is new, issue a warning to inform the user, except for RNG, which were the main reason for this feature
from
pytensor.tensor.random.type
import
RandomType
if
not
isinstance
(
arg
.
type
,
RandomType
):
warnings
.
warn
(
(
f
"Output {actual_arg} (index {i}) with type {actual_arg.type} will be treated as untraced variable in scan. "
"Only the last value will be returned, not the entire sequence."
),
UserWarning
,
)
untraced_sit_sot_scan_inputs
.
append
(
actual_arg
)
untraced_sit_sot_inner_inputs
.
append
(
arg
)
n_untraced_sit_sot_outs
+=
1
untraced_sit_sot_rightOrder
.
append
(
i
)
elif
init_out
.
get
(
"taps"
,
None
):
if
np
.
any
(
np
.
array
(
init_out
.
get
(
"taps"
,
[]))
>
0
):
...
...
@@ -802,10 +832,11 @@ def scan(
# a map); in that case we do not have to do anything ..
# Re-order args
max_mit_sot
=
np
.
max
([
-
1
,
*
mit_sot_rightOrder
])
+
1
max_sit_sot
=
np
.
max
([
-
1
,
*
sit_sot_rightOrder
])
+
1
n_elems
=
np
.
max
([
max_mit_sot
,
max_sit_sot
])
_ordered_args
=
[[]
for
x
in
range
(
n_elems
)]
max_mit_sot
=
max
(
mit_sot_rightOrder
,
default
=-
1
)
+
1
max_sit_sot
=
max
(
sit_sot_rightOrder
,
default
=-
1
)
+
1
max_untraced_sit_sot_outs
=
max
(
untraced_sit_sot_rightOrder
,
default
=-
1
)
+
1
n_elems
=
np
.
max
((
max_mit_sot
,
max_sit_sot
,
max_untraced_sit_sot_outs
))
_ordered_args
:
list
[
list
[
Variable
]]
=
[[]
for
x
in
range
(
n_elems
)]
offset
=
0
for
idx
in
range
(
n_mit_sot
):
n_inputs
=
len
(
mit_sot_tap_array
[
idx
])
...
...
@@ -825,6 +856,11 @@ def scan(
else
:
_ordered_args
[
sit_sot_rightOrder
[
idx
]]
=
[
sit_sot_inner_inputs
[
idx
]]
for
idx
in
range
(
n_untraced_sit_sot_outs
):
_ordered_args
[
untraced_sit_sot_rightOrder
[
idx
]]
=
[
untraced_sit_sot_inner_inputs
[
idx
]
]
ordered_args
=
list
(
chain
.
from_iterable
(
_ordered_args
))
if
single_step_requested
:
args
=
inner_slices
+
ordered_args
+
non_seqs
...
...
@@ -939,18 +975,19 @@ def scan(
if
"taps"
in
out
and
out
[
"taps"
]
!=
[
-
1
]:
mit_sot_inner_outputs
.
append
(
outputs
[
i
])
# Step 5.2 Outputs with tap equal to -1
# Step 5.2 Outputs with tap equal to -1
(traced and untraced)
for
i
,
out
in
enumerate
(
outs_info
):
if
"taps"
in
out
and
out
[
"taps"
]
==
[
-
1
]:
sit_sot_inner_outputs
.
append
(
outputs
[
i
])
output
=
outputs
[
i
]
if
isinstance
(
output
.
type
,
HasShape
):
sit_sot_inner_outputs
.
append
(
output
)
else
:
untraced_sit_sot_inner_outputs
.
append
(
output
)
# Step 5.3 Outputs that correspond to update rules of shared variables
# This whole special logic for shared variables is deprecated
sit_sot_shared
:
list
[
Variable
]
=
[]
inner_replacements
=
{}
n_shared_outs
=
0
shared_scan_inputs
=
[]
shared_inner_inputs
=
[]
shared_inner_outputs
=
[]
sit_sot_shared
=
[]
no_update_shared_inputs
=
[]
for
input
in
dummy_inputs
:
if
not
isinstance
(
input
.
variable
,
SharedVariable
):
...
...
@@ -976,8 +1013,8 @@ def scan(
new_var
=
safe_new
(
input
.
variable
)
if
getattr
(
input
.
variable
,
"name"
,
None
)
is
not
None
:
new_var
.
name
=
input
.
variable
.
name
+
"
_copy"
if
input
.
variable
.
name
is
not
None
:
new_var
.
name
=
f
"{input.variable.name}
_copy"
inner_replacements
[
input
.
variable
]
=
new_var
...
...
@@ -1003,10 +1040,10 @@ def scan(
sit_sot_shared
.
append
(
input
.
variable
)
else
:
shared
_inner_inputs
.
append
(
new_var
)
shared
_scan_inputs
.
append
(
input
.
variable
)
shared
_inner_outputs
.
append
(
input
.
update
)
n_
shared
_outs
+=
1
untraced_sit_sot
_inner_inputs
.
append
(
new_var
)
untraced_sit_sot
_scan_inputs
.
append
(
input
.
variable
)
untraced_sit_sot
_inner_outputs
.
append
(
input
.
update
)
n_
untraced_sit_sot
_outs
+=
1
else
:
no_update_shared_inputs
.
append
(
input
)
...
...
@@ -1071,7 +1108,7 @@ def scan(
+
mit_mot_inner_inputs
+
mit_sot_inner_inputs
+
sit_sot_inner_inputs
+
shared
_inner_inputs
+
untraced_sit_sot
_inner_inputs
+
other_shared_inner_args
+
other_inner_args
)
...
...
@@ -1081,7 +1118,7 @@ def scan(
+
mit_sot_inner_outputs
+
sit_sot_inner_outputs
+
nit_sot_inner_outputs
+
shared
_inner_outputs
+
untraced_sit_sot
_inner_outputs
)
if
condition
is
not
None
:
inner_outs
.
append
(
condition
)
...
...
@@ -1101,7 +1138,7 @@ def scan(
mit_mot_out_slices
=
tuple
(
tuple
(
v
)
for
v
in
mit_mot_out_slices
),
mit_sot_in_slices
=
tuple
(
tuple
(
v
)
for
v
in
mit_sot_tap_array
),
sit_sot_in_slices
=
tuple
((
-
1
,)
for
x
in
range
(
n_sit_sot
)),
n_
shared_outs
=
n_shared
_outs
,
n_
untraced_sit_sot_outs
=
n_untraced_sit_sot
_outs
,
n_nit_sot
=
n_nit_sot
,
n_non_seqs
=
len
(
other_shared_inner_args
)
+
len
(
other_inner_args
),
as_while
=
as_while
,
...
...
@@ -1127,7 +1164,7 @@ def scan(
+
mit_mot_scan_inputs
+
mit_sot_scan_inputs
+
sit_sot_scan_inputs
+
shared
_scan_inputs
+
untraced_sit_sot
_scan_inputs
+
[
actual_n_steps
for
x
in
range
(
n_nit_sot
)]
+
other_shared_scan_args
+
other_scan_args
...
...
@@ -1173,13 +1210,26 @@ def scan(
nit_sot_outs
=
remove_dimensions
(
scan_outs
[
offset
:
offset
+
n_nit_sot
])
offset
+=
n_nit_sot
for
idx
,
update_rule
in
enumerate
(
scan_outs
[
offset
:
offset
+
n_shared_outs
]):
update_map
[
shared_scan_inputs
[
idx
]]
=
update_rule
_scan_out_list
=
mit_sot_outs
+
sit_sot_outs
+
nit_sot_outs
# Support for explicit untraced sit_sot
n_explicit_untraced_sit_sot_outs
=
len
(
untraced_sit_sot_rightOrder
)
untraced_sit_sot_outs
=
scan_outs
[
offset
:
offset
+
n_explicit_untraced_sit_sot_outs
]
offset
+=
n_explicit_untraced_sit_sot_outs
for
idx
,
update_rule
in
enumerate
(
scan_outs
[
offset
:]):
update_map
[
untraced_sit_sot_scan_inputs
[
idx
]]
=
update_rule
_scan_out_list
=
mit_sot_outs
+
sit_sot_outs
+
nit_sot_outs
+
untraced_sit_sot_outs
# Step 10. I need to reorder the outputs to be in the order expected by
# the user
rightOrder
=
mit_sot_rightOrder
+
sit_sot_rightOrder
+
nit_sot_rightOrder
rightOrder
=
(
mit_sot_rightOrder
+
sit_sot_rightOrder
+
untraced_sit_sot_rightOrder
+
nit_sot_rightOrder
)
scan_out_list
=
[
None
]
*
len
(
rightOrder
)
for
idx
,
pos
in
enumerate
(
rightOrder
):
if
pos
>=
0
:
...
...
pytensor/scan/op.py
浏览文件 @
1d19c375
...
...
@@ -46,6 +46,7 @@ relies on the following elements to work properly :
import
dataclasses
import
logging
import
time
import
warnings
from
collections.abc
import
Callable
,
Iterable
from
copy
import
copy
from
itertools
import
chain
,
product
...
...
@@ -208,10 +209,19 @@ class ScanInfo:
mit_sot_in_slices
:
tuple
sit_sot_in_slices
:
tuple
n_nit_sot
:
int
n_
shared
_outs
:
int
n_
untraced_sit_sot
_outs
:
int
n_non_seqs
:
int
as_while
:
bool
@property
def
n_shared_outs
(
self
):
warnings
.
warn
(
"The 'n_shared_outs' property is deprecated. Use 'n_untraced_sit_sot_outs' instead."
,
DeprecationWarning
,
stacklevel
=
2
,
)
return
self
.
n_untraced_sit_sot_outs
@property
def
n_mit_mot
(
self
):
return
len
(
self
.
mit_mot_in_slices
)
...
...
@@ -239,7 +249,7 @@ class ScanInfo:
+
sum
(
len
(
x
)
for
x
in
self
.
mit_mot_in_slices
)
+
sum
(
len
(
x
)
for
x
in
self
.
mit_sot_in_slices
)
+
self
.
n_sit_sot
+
self
.
n_
shared
_outs
+
self
.
n_
untraced_sit_sot
_outs
+
self
.
n_non_seqs
)
...
...
@@ -250,7 +260,7 @@ class ScanInfo:
+
self
.
n_mit_sot
+
self
.
n_sit_sot
+
self
.
n_nit_sot
+
self
.
n_
shared
_outs
+
self
.
n_
untraced_sit_sot
_outs
+
int
(
self
.
as_while
)
)
...
...
@@ -263,7 +273,7 @@ class ScanInfo:
+
self
.
n_mit_sot
+
self
.
n_sit_sot
+
self
.
n_nit_sot
+
self
.
n_
shared
_outs
+
self
.
n_
untraced_sit_sot
_outs
+
self
.
n_non_seqs
)
...
...
@@ -274,7 +284,7 @@ class ScanInfo:
+
self
.
n_mit_sot
+
self
.
n_sit_sot
+
self
.
n_nit_sot
+
self
.
n_
shared
_outs
+
self
.
n_
untraced_sit_sot
_outs
)
...
...
@@ -381,7 +391,7 @@ class ScanMethodsMixin:
+
self
.
info
.
n_mit_mot
+
self
.
info
.
n_mit_sot
+
self
.
info
.
n_sit_sot
+
self
.
info
.
n_
shared
_outs
+
self
.
info
.
n_
untraced_sit_sot
_outs
)
return
list_inputs
[
offset
:
offset
+
self
.
info
.
n_nit_sot
]
...
...
@@ -394,15 +404,23 @@ class ScanMethodsMixin:
offset
=
self
.
info
.
n_mit_mot
+
self
.
info
.
n_mit_sot
+
self
.
info
.
n_sit_sot
return
list_outputs
[
offset
:
offset
+
self
.
info
.
n_nit_sot
]
def
inner_
shared
(
self
,
list_inputs
):
def
inner_
untraced_sit_sot
(
self
,
list_inputs
):
n_taps_upto_sit_sot
=
sum
(
len
(
x
)
for
x
in
chain
(
self
.
info
.
mit_mot_in_slices
,
self
.
info
.
mit_sot_in_slices
)
)
offset
=
self
.
info
.
n_seqs
+
n_taps_upto_sit_sot
+
self
.
info
.
n_sit_sot
return
list_inputs
[
offset
:
offset
+
self
.
info
.
n_
shared
_outs
]
return
list_inputs
[
offset
:
offset
+
self
.
info
.
n_
untraced_sit_sot
_outs
]
def
outer_shared
(
self
,
list_inputs
):
def
inner_shared
(
self
,
list_inputs
):
warnings
.
warn
(
"The 'inner_shared' method is deprecated. Use 'inner_untraced_sit_sot' instead."
,
DeprecationWarning
,
stacklevel
=
2
,
)
return
self
.
inner_untraced_sit_sot
(
list_inputs
)
def
outer_untraced_sit_sot
(
self
,
list_inputs
):
offset
=
(
1
+
self
.
info
.
n_seqs
...
...
@@ -410,23 +428,47 @@ class ScanMethodsMixin:
+
self
.
info
.
n_mit_sot
+
self
.
info
.
n_sit_sot
)
return
list_inputs
[
offset
:
offset
+
self
.
info
.
n_
shared
_outs
]
return
list_inputs
[
offset
:
offset
+
self
.
info
.
n_
untraced_sit_sot
_outs
]
def
inner_shared_outs
(
self
,
list_outputs
):
def
outer_shared
(
self
,
list_inputs
):
warnings
.
warn
(
"The 'outer_shared' method is deprecated. Use 'outer_untraced_sit_sot' instead."
,
DeprecationWarning
,
stacklevel
=
2
,
)
return
self
.
outer_untraced_sit_sot
(
list_inputs
)
def
inner_untraced_sit_sot_outs
(
self
,
list_outputs
):
n_taps
=
sum
(
len
(
x
)
for
x
in
self
.
info
.
mit_mot_out_slices
)
offset
=
(
self
.
info
.
n_mit_sot
+
n_taps
+
self
.
info
.
n_sit_sot
+
self
.
info
.
n_nit_sot
)
return
list_outputs
[
offset
:
offset
+
self
.
info
.
n_
shared
_outs
]
return
list_outputs
[
offset
:
offset
+
self
.
info
.
n_
untraced_sit_sot
_outs
]
def
outer_shared_outs
(
self
,
list_outputs
):
def
inner_shared_outs
(
self
,
list_outputs
):
warnings
.
warn
(
"The 'inner_shared_outs' method is deprecated. Use 'inner_untraced_sit_sot_outs' instead."
,
DeprecationWarning
,
stacklevel
=
2
,
)
return
self
.
inner_untraced_sit_sot_outs
(
list_outputs
)
def
outer_untraced_sit_sot_outs
(
self
,
list_outputs
):
offset
=
(
self
.
info
.
n_mit_mot
+
self
.
info
.
n_mit_sot
+
self
.
info
.
n_sit_sot
+
self
.
info
.
n_nit_sot
)
return
list_outputs
[
offset
:
offset
+
self
.
info
.
n_shared_outs
]
return
list_outputs
[
offset
:
offset
+
self
.
info
.
n_untraced_sit_sot_outs
]
def
outer_shared_outs
(
self
,
list_outputs
):
warnings
.
warn
(
"The 'outer_shared_outs' method is deprecated. Use 'outer_untraced_sit_sot_outs' instead."
,
DeprecationWarning
,
stacklevel
=
2
,
)
return
self
.
outer_untraced_sit_sot_outs
(
list_outputs
)
def
inner_non_seqs
(
self
,
list_inputs
):
n_taps_upto_sit_sot
=
sum
(
...
...
@@ -437,7 +479,7 @@ class ScanMethodsMixin:
self
.
info
.
n_seqs
+
n_taps_upto_sit_sot
+
self
.
info
.
n_sit_sot
+
self
.
info
.
n_
shared
_outs
+
self
.
info
.
n_
untraced_sit_sot
_outs
)
return
list_inputs
[
offset
:]
...
...
@@ -449,7 +491,7 @@ class ScanMethodsMixin:
+
self
.
info
.
n_mit_sot
+
self
.
info
.
n_sit_sot
+
self
.
info
.
n_nit_sot
+
self
.
info
.
n_
shared
_outs
+
self
.
info
.
n_
untraced_sit_sot
_outs
)
return
list_inputs
[
offset
:]
...
...
@@ -525,8 +567,8 @@ class ScanMethodsMixin:
outer_oidx
+=
1
# This is needed because, for outer inputs (and for outer inputs only)
# nitsots come *after*
shared
variables.
outer_iidx
+=
self
.
info
.
n_
shared
_outs
# nitsots come *after*
untraced_sitsot
variables.
outer_iidx
+=
self
.
info
.
n_
untraced_sit_sot
_outs
# Handle nitsots variables
for
i
in
range
(
self
.
info
.
n_nit_sot
):
...
...
@@ -541,11 +583,11 @@ class ScanMethodsMixin:
outer_oidx
+=
1
# This is needed because, for outer inputs (and for outer inputs only)
# nitsots come *after*
shared
variables.
outer_iidx
-=
self
.
info
.
n_
shared
_outs
+
self
.
info
.
n_nit_sot
# nitsots come *after*
untraced_sit_sot
variables.
outer_iidx
-=
self
.
info
.
n_
untraced_sit_sot
_outs
+
self
.
info
.
n_nit_sot
# Handle
shared
states
for
i
in
range
(
self
.
info
.
n_
shared
_outs
):
# Handle
untraced_sitsot
states
for
i
in
range
(
self
.
info
.
n_
untraced_sit_sot
_outs
):
outer_input_indices
.
append
(
outer_iidx
)
inner_input_indices
.
append
([
inner_iidx
])
inner_output_indices
.
append
([
inner_oidx
])
...
...
@@ -557,7 +599,7 @@ class ScanMethodsMixin:
outer_oidx
+=
1
# This is needed because, for outer inputs (and for outer inputs only)
# nitsots come *after*
shared
variables.
# nitsots come *after*
untraced_sitsot
variables.
outer_iidx
+=
self
.
info
.
n_nit_sot
# Handle non-sequence inputs
...
...
@@ -708,7 +750,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
Inputs of the inner function of `Scan`.
These take the following general form:
sequences + mit-mot-inputs + mit-sot-inputs + sit-sot-inputs + shared-inputs + non-sequences
sequences + mit-mot-inputs + mit-sot-inputs + sit-sot-inputs +
untraced-sit-sot-inputs +
shared-inputs + non-sequences
where each term is a list of `Variable`\s.
...
...
@@ -716,7 +758,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
Outputs of the inner function of `Scan`.
These take the following general form:
mit-mot-outputs + mit-sot-outputs + sit-sot-outputs + nit-sots +
shared
-outputs [+ while-condition]
mit-mot-outputs + mit-sot-outputs + sit-sot-outputs + nit-sots +
untraced-sit-sot
-outputs [+ while-condition]
where each term is a list of `Variable`\s.
...
...
@@ -817,7 +859,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
typeConstructor
((
None
,
*
o
.
type
.
shape
),
o
.
type
.
dtype
)
)
#
shared
outputs + possibly the ending condition
#
untraced_sit_sot
outputs + possibly the ending condition
for
o
in
self
.
fgraph
.
outputs
[
end
:]:
self
.
output_types
.
append
(
o
.
type
)
...
...
@@ -836,10 +878,12 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
]
self
.
mintaps
+=
[
0
for
x
in
range
(
info
.
n_nit_sot
)]
self
.
seqs_arg_offset
=
1
+
info
.
n_seqs
self
.
shared
_arg_offset
=
(
self
.
untraced_sit_sot
_arg_offset
=
(
self
.
seqs_arg_offset
+
info
.
n_mit_mot
+
info
.
n_mit_sot
+
info
.
n_sit_sot
)
self
.
nit_sot_arg_offset
=
self
.
shared_arg_offset
+
info
.
n_shared_outs
self
.
nit_sot_arg_offset
=
(
self
.
untraced_sit_sot_arg_offset
+
info
.
n_untraced_sit_sot_outs
)
# XXX: This doesn't include `info.n_nit_sot`s, so it's really a count
# of the number of outputs generated by taps with inputs
self
.
n_outs
=
info
.
n_mit_mot
+
info
.
n_mit_sot
+
info
.
n_sit_sot
...
...
@@ -908,7 +952,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
sequences +
mit-mot-inputs + mit-sot-inputs + sit-sot-inputs +
shared-inputs +
untraced-sit-sot-inputs + shared-inputs
nit-sots +
non-sequences
...
...
@@ -923,7 +967,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
[n_steps] +
sequences +
mit-mot-inputs + mit-sot-inputs + sit-sot-inputs +
shared-inputs +
untraced-sit-sot-inputs + shared-inputs
nit-sots +
non-sequences
...
...
@@ -931,7 +975,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
mit-mot-outputs + mit-sot-outputs + sit-sot-outputs +
nit-sots +
shared
-outputs
untraced-sit-sot
-outputs
These outer-outputs essentially follow the same form as their
corresponding inner-outputs, excluding the final "while" condition
...
...
@@ -949,7 +993,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
+
len
(
self
.
info
.
mit_mot_in_slices
)
+
len
(
self
.
info
.
mit_sot_in_slices
)
+
len
(
self
.
inner_sitsot
(
self
.
inner_inputs
))
+
len
(
self
.
inner_
shared
(
self
.
inner_inputs
))
+
len
(
self
.
inner_
untraced_sit_sot
(
self
.
inner_inputs
))
+
len
(
self
.
inner_non_seqs
(
self
.
inner_inputs
))
)
...
...
@@ -1134,60 +1178,60 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
)
argoffset
+=
len
(
self
.
outer_sitsot
(
inputs
))
# Check that the
shared
variable and their update rule have the same
# Check that the
untraced (u) sit-sot
variable and their update rule have the same
# dtype. Maybe even same type ?!
for
idx
,
(
inner_
shared
,
inner_shared_out
,
_outer_shared
)
in
enumerate
(
for
idx
,
(
inner_
u_sitsot
,
inner_u_sitsot_out
,
_outer_u_sitsot
)
in
enumerate
(
zip
(
self
.
inner_
shared
(
self
.
inner_inputs
),
self
.
inner_
shared
_outs
(
self
.
inner_outputs
),
self
.
outer_
shared
(
inputs
),
self
.
inner_
untraced_sit_sot
(
self
.
inner_inputs
),
self
.
inner_
untraced_sit_sot
_outs
(
self
.
inner_outputs
),
self
.
outer_
untraced_sit_sot
(
inputs
),
strict
=
True
,
)
):
outer_
shared
=
copy_var_format
(
_outer_shared
,
as_var
=
inner_shared
)
new_inputs
.
append
(
outer_
shared
)
outer_
u_sitsot
=
copy_var_format
(
_outer_u_sitsot
,
as_var
=
inner_u_sitsot
)
new_inputs
.
append
(
outer_
u_sitsot
)
if
(
hasattr
(
outer_
shared
,
"dtype"
)
and
outer_
shared
.
dtype
!=
inner_shared
_out
.
dtype
hasattr
(
outer_
u_sitsot
,
"dtype"
)
and
outer_
u_sitsot
.
dtype
!=
inner_u_sitsot
_out
.
dtype
):
raise
ValueError
(
err_msg2
%
(
str
(
outer_
shared
),
str
(
outer_
u_sitsot
),
idx
+
argoffset
,
outer_
shared
.
dtype
,
inner_
shared
_out
.
dtype
,
outer_
u_sitsot
.
dtype
,
inner_
u_sitsot
_out
.
dtype
,
)
)
if
(
hasattr
(
outer_
shared
,
"dtype"
)
and
outer_
shared
.
ndim
!=
inner_shared
_out
.
ndim
hasattr
(
outer_
u_sitsot
,
"dtype"
)
and
outer_
u_sitsot
.
ndim
!=
inner_u_sitsot
_out
.
ndim
):
raise
ValueError
(
err_msg3
%
(
str
(
outer_
shared
),
str
(
outer_
u_sitsot
),
idx
+
argoffset
,
outer_
shared
.
ndim
,
inner_
shared
_out
.
ndim
,
outer_
u_sitsot
.
ndim
,
inner_
u_sitsot
_out
.
ndim
,
)
)
if
hasattr
(
outer_
shared
,
"dtype"
)
and
(
outer_
shared
.
dtype
!=
inner_shared
.
dtype
or
outer_
shared
.
ndim
!=
inner_shared
.
ndim
if
hasattr
(
outer_
u_sitsot
,
"dtype"
)
and
(
outer_
u_sitsot
.
dtype
!=
inner_u_sitsot
.
dtype
or
outer_
u_sitsot
.
ndim
!=
inner_u_sitsot
.
ndim
):
raise
ValueError
(
err_msg1
%
(
"initial state (outputs_info in scan nomenclature) "
,
str
(
outer_
shared
),
str
(
outer_
u_sitsot
),
argoffset
+
idx
,
outer_
shared
.
dtype
,
outer_
shared
.
ndim
,
str
(
inner_
shared
),
inner_
shared
.
dtype
,
inner_
shared
.
ndim
,
outer_
u_sitsot
.
dtype
,
outer_
u_sitsot
.
ndim
,
str
(
inner_
u_sitsot
),
inner_
u_sitsot
.
dtype
,
inner_
u_sitsot
.
ndim
,
)
)
# We do not need to call `copy_var_format` on outer_nisot arguments.
...
...
@@ -1585,7 +1629,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
try
:
t_fn
,
n_steps
=
scan_perform_ext
.
perform
(
self
.
info
.
n_
shared
_outs
,
self
.
info
.
n_
untraced_sit_sot
_outs
,
self
.
info
.
n_mit_mot_outs
,
self
.
info
.
n_seqs
,
self
.
info
.
n_mit_mot
,
...
...
@@ -1719,7 +1763,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
# The length of each output
store_steps
=
[
arg
.
shape
[
0
]
for
arg
in
inputs
[
self
.
seqs_arg_offset
:
self
.
shared
_arg_offset
]
for
arg
in
inputs
[
self
.
seqs_arg_offset
:
self
.
untraced_sit_sot
_arg_offset
]
]
store_steps
+=
list
(
inputs
[
self
.
nit_sot_arg_offset
:
self
.
nit_sot_arg_offset
+
info
.
n_nit_sot
]
...
...
@@ -1784,7 +1828,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
info
.
sit_sot_in_slices
,
)
)
+
info
.
n_
shared
_outs
+
info
.
n_
untraced_sit_sot
_outs
)
for
idx
in
range
(
len
(
other_args
)):
inner_input_storage
[
idx
+
offset
]
.
storage
[
0
]
=
other_args
[
idx
]
...
...
@@ -1827,14 +1871,14 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
]
offset
+=
1
a_offset
=
self
.
shared
_arg_offset
a_offset
=
self
.
untraced_sit_sot
_arg_offset
o_offset
=
self
.
n_outs
+
info
.
n_nit_sot
if
i
==
0
:
for
j
in
range
(
info
.
n_
shared
_outs
):
for
j
in
range
(
info
.
n_
untraced_sit_sot
_outs
):
inner_input_storage
[
offset
]
.
storage
[
0
]
=
inputs
[
a_offset
+
j
]
offset
+=
1
else
:
for
j
in
range
(
info
.
n_
shared
_outs
):
for
j
in
range
(
info
.
n_
untraced_sit_sot
_outs
):
inner_input_storage
[
offset
]
.
storage
[
0
]
=
output_storage
[
o_offset
+
j
][
0
]
...
...
@@ -1866,14 +1910,14 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
for
idx
in
range
(
self
.
n_outs
+
info
.
n_nit_sot
-
info
.
n_mit_mot
):
inner_output_storage
[
idx
+
offset
]
.
storage
[
0
]
=
None
# 4.3. Collect slices for
shared
outputs
# 4.3. Collect slices for
untraced sitsot
outputs
offset
+=
self
.
n_outs
+
info
.
n_nit_sot
-
info
.
n_mit_mot
for
idx
in
range
(
info
.
n_
shared
_outs
):
for
idx
in
range
(
info
.
n_
untraced_sit_sot
_outs
):
inner_output_storage
[
idx
+
offset
]
.
storage
[
0
]
=
None
# 4.4. If there is a condition add it to the mix
if
info
.
as_while
:
pdx
=
offset
+
info
.
n_
shared
_outs
pdx
=
offset
+
info
.
n_
untraced_sit_sot
_outs
inner_output_storage
[
pdx
]
.
storage
[
0
]
=
None
# 4.5. Keep a reference to the variables (ndarrays,
...
...
@@ -1942,7 +1986,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
dt_fn
=
time
.
perf_counter
()
-
t0_fn
if
info
.
as_while
:
pdx
=
offset
+
info
.
n_
shared
_outs
pdx
=
offset
+
info
.
n_
untraced_sit_sot
_outs
cond
=
inner_output_storage
[
pdx
]
.
storage
[
0
]
==
0
t_fn
+=
dt_fn
...
...
@@ -2089,10 +2133,10 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
j
+
offset_out
]
.
storage
[
0
]
# 5.6 Copy over the values for outputs corresponding to
shared
# 5.6 Copy over the values for outputs corresponding to
untraced sitsot
# variables
begin
=
end
end
+=
info
.
n_
shared
_outs
end
+=
info
.
n_
untraced_sit_sot
_outs
for
j
in
range
(
begin
,
end
):
jout
=
j
+
offset_out
output_storage
[
j
][
0
]
=
inner_output_storage
[
jout
]
.
storage
[
0
]
...
...
@@ -2240,13 +2284,13 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
# out_equivalent[self.inner_inputs[inner_inp_idx]] = corresponding_tap
outer_inp_idx
+=
1
#
shared_o
uts
#
untraced sit_sot outp
uts
offset
=
1
+
info
.
n_seqs
+
n_outs
for
idx
in
range
(
info
.
n_
shared
_outs
):
for
idx
in
range
(
info
.
n_
untraced_sit_sot
_outs
):
outs_shape
+=
[
input_shapes
[
idx
+
offset
]]
# non_sequences
offset
+=
info
.
n_nit_sot
+
info
.
n_
shared
_outs
offset
+=
info
.
n_nit_sot
+
info
.
n_
untraced_sit_sot
_outs
inner_ins_shapes
=
seqs_shape
+
outs_shape
+
input_shapes
[
offset
:]
assert
len
(
inner_ins_shapes
)
==
len
(
self
.
inner_inputs
)
...
...
@@ -2288,7 +2332,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
# in the inner function.
r
=
node
.
outputs
[
n_outs
+
x
]
assert
r
.
ndim
==
1
+
len
(
out_shape_x
)
shp
=
[
node
.
inputs
[
offset
+
info
.
n_
shared
_outs
+
x
]]
shp
=
[
node
.
inputs
[
offset
+
info
.
n_
untraced_sit_sot
_outs
+
x
]]
for
i
,
shp_i
in
zip
(
range
(
1
,
r
.
ndim
),
out_shape_x
,
strict
=
True
):
# Validate shp_i. v_shape_i is either None (if invalid),
# or a (variable, Boolean) tuple. The Boolean indicates
...
...
@@ -2305,7 +2349,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
shp
.
append
(
v_shp_i
[
0
])
scan_outs
.
append
(
tuple
(
shp
))
scan_outs
+=
list
(
input_shapes
[
offset
:
offset
+
info
.
n_
shared
_outs
])
scan_outs
+=
list
(
input_shapes
[
offset
:
offset
+
info
.
n_
untraced_sit_sot
_outs
])
# if we are dealing with a repeat-until, then we do not know the
# leading dimension so we replace it for every entry with Shape_i
if
info
.
as_while
:
...
...
@@ -2735,7 +2779,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
mitmot_inp_taps
.
append
([])
mitmot_out_taps
.
append
([])
undefined_msg
=
None
through_
shar
ed
=
False
through_
untrac
ed
=
False
disconnected
=
True
for
mit_mot_out_slice
in
info
.
mit_mot_out_slices
[
idx
]:
...
...
@@ -2779,9 +2823,9 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
disconnected
&=
disconnected_dC_dinps_t
[
ins_pos
]
through_
shar
ed
=
any
(
through_
untrac
ed
=
any
(
_sh
in
graph_inputs
([
dC_dinps_t
[
ins_pos
]])
for
_sh
in
self
.
inner_
shared
(
self_inputs
)
for
_sh
in
self
.
inner_
untraced_sit_sot
(
self_inputs
)
)
ins_pos
+=
1
...
...
@@ -2795,8 +2839,8 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
if
undefined_msg
:
type_outs
.
append
(
undefined_msg
)
elif
through_
shar
ed
:
type_outs
.
append
(
"through_
shar
ed"
)
elif
through_
untrac
ed
:
type_outs
.
append
(
"through_
untrac
ed"
)
elif
disconnected
:
type_outs
.
append
(
"disconnected"
)
else
:
...
...
@@ -2814,7 +2858,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
out_pos
+=
1
n_mitmot_inps
+=
1
undefined_msg
=
None
through_
shar
ed
=
False
through_
untrac
ed
=
False
disconnected
=
True
mitmot_inp_taps
[
idx
+
offset
]
.
append
(
0
)
for
tap
in
taps
:
...
...
@@ -2836,9 +2880,9 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
disconnected
&=
disconnected_dC_dinps_t
[
ins_pos
]
through_
shar
ed
=
any
(
through_
untrac
ed
=
any
(
_sh
in
graph_inputs
([
dC_dinps_t
[
ins_pos
]])
for
_sh
in
self
.
inner_
shared
(
self_inputs
)
for
_sh
in
self
.
inner_
untraced_sit_sot
(
self_inputs
)
)
n_mitmot_inps
+=
1
...
...
@@ -2847,8 +2891,8 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
if
undefined_msg
:
type_outs
.
append
(
undefined_msg
)
elif
through_
shar
ed
:
type_outs
.
append
(
"through_
shar
ed"
)
elif
through_
untrac
ed
:
type_outs
.
append
(
"through_
untrac
ed"
)
elif
disconnected
:
type_outs
.
append
(
"disconnected"
)
else
:
...
...
@@ -2884,15 +2928,15 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
else
:
inner_out_mitmot
.
append
(
dC_dinps_t
[
ins_pos
])
through_
shar
ed
=
any
(
through_
untrac
ed
=
any
(
_sh
in
graph_inputs
([
dC_dinps_t
[
ins_pos
]])
for
_sh
in
self
.
inner_
shared
(
self_inputs
)
for
_sh
in
self
.
inner_
untraced_sit_sot
(
self_inputs
)
)
if
isinstance
(
dC_dinps_t
[
ins_pos
]
.
type
,
NullType
):
type_outs
.
append
(
dC_dinps_t
[
ins_pos
]
.
type
.
why_null
)
elif
through_
shar
ed
:
type_outs
.
append
(
"through_
shar
ed"
)
elif
through_
untrac
ed
:
type_outs
.
append
(
"through_
untrac
ed"
)
elif
disconnected_dC_dinps_t
[
ins_pos
]:
type_outs
.
append
(
"disconnected"
)
else
:
...
...
@@ -2911,10 +2955,10 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
inner_out_nitsot
=
dC_dinps_t
[:
info
.
n_seqs
]
inner_out_sitsot
=
dC_dinps_t
[
ins_pos
:]
for
_p
,
vl
in
enumerate
(
inner_out_sitsot
):
through_
shar
ed
=
False
for
_sh
in
self
.
inner_
shared
(
self_inputs
):
through_
untrac
ed
=
False
for
_sh
in
self
.
inner_
untraced_sit_sot
(
self_inputs
):
if
_sh
in
graph_inputs
([
vl
]):
through_
shar
ed
=
True
through_
untrac
ed
=
True
if
isinstance
(
vl
.
type
,
NullType
):
type_outs
.
append
(
vl
.
type
.
why_null
)
# Replace the inner output with a zero tensor of
...
...
@@ -2922,18 +2966,18 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
inner_out_sitsot
[
_p
]
=
pt
.
zeros
(
diff_inputs
[
ins_pos
+
_p
]
.
shape
,
dtype
=
config
.
floatX
)
elif
through_
shar
ed
:
type_outs
.
append
(
"through_
shar
ed"
)
elif
through_
untrac
ed
:
type_outs
.
append
(
"through_
untrac
ed"
)
elif
disconnected_dC_dinps_t
[
_p
+
ins_pos
]:
type_outs
.
append
(
"disconnected"
)
else
:
type_outs
.
append
(
"connected"
)
for
_p
,
vl
in
enumerate
(
inner_out_nitsot
):
through_
shar
ed
=
False
for
_sh
in
self
.
inner_
shared
(
self_inputs
):
through_
untrac
ed
=
False
for
_sh
in
self
.
inner_
untraced_sit_sot
(
self_inputs
):
if
_sh
in
graph_inputs
([
vl
]):
through_
shar
ed
=
True
through_
untrac
ed
=
True
if
isinstance
(
vl
.
type
,
NullType
):
type_outs
.
append
(
vl
.
type
.
why_null
)
# Replace the inner output with a zero tensor of
...
...
@@ -2942,8 +2986,8 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
diff_inputs
[
_p
]
.
shape
,
dtype
=
config
.
floatX
)
if
through_
shar
ed
:
type_outs
.
append
(
"through_
shar
ed"
)
if
through_
untrac
ed
:
type_outs
.
append
(
"through_
untrac
ed"
)
elif
disconnected_dC_dinps_t
[
_p
]:
type_outs
.
append
(
"disconnected"
)
else
:
...
...
@@ -2983,7 +3027,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
+
outer_inp_mitmot
+
outer_inp_sitsot
+
[
n_steps
if
info
.
as_while
else
inputs
[
0
]
for
_
in
range
(
n_nit_sot
)]
+
self
.
outer_
shared
(
inputs
)
+
self
.
outer_
untraced_sit_sot
(
inputs
)
+
self
.
outer_non_seqs
(
inputs
)
)
...
...
@@ -2991,7 +3035,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
inner_inp_seqs
+
inner_inp_mitmot
+
inner_inp_sitsot
+
self
.
inner_
shared
(
self_inputs
)
+
self
.
inner_
untraced_sit_sot
(
self_inputs
)
+
self
.
inner_non_seqs
(
self_inputs
)
)
inner_gfn_outs
=
inner_out_mitmot
+
inner_out_sitsot
+
inner_out_nitsot
...
...
@@ -3003,8 +3047,8 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
mit_sot_in_slices
=
(),
sit_sot_in_slices
=
tuple
((
-
1
,)
for
k
in
range
(
n_sitsot_outs
)),
n_nit_sot
=
n_nit_sot
,
n_
shared
_outs
=
0
,
n_non_seqs
=
len
(
self
.
outer_
shared
(
inputs
))
n_
untraced_sit_sot
_outs
=
0
,
n_non_seqs
=
len
(
self
.
outer_
untraced_sit_sot
(
inputs
))
+
len
(
self
.
outer_non_seqs
(
inputs
)),
as_while
=
False
,
)
...
...
@@ -3047,10 +3091,10 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
gradients
.
append
(
x
[::
-
1
])
elif
t
==
"disconnected"
:
gradients
.
append
(
DisconnectedType
()())
elif
t
==
"through_
shar
ed"
:
elif
t
==
"through_
untrac
ed"
:
gradients
.
append
(
grad_undefined
(
self
,
p
+
1
,
inputs
[
p
+
1
],
"Depends on a
shar
ed variable"
self
,
p
+
1
,
inputs
[
p
+
1
],
"Depends on a
untrac
ed variable"
)
)
else
:
...
...
@@ -3075,13 +3119,13 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
gradients
.
append
(
x
[::
-
1
])
elif
t
==
"disconnected"
:
gradients
.
append
(
DisconnectedType
()())
elif
t
==
"through_
shar
ed"
:
elif
t
==
"through_
untrac
ed"
:
gradients
.
append
(
grad_undefined
(
self
,
p
+
1
+
info
.
n_seqs
,
inputs
[
p
+
1
+
info
.
n_seqs
],
"Depends on a
shar
ed variable"
,
"Depends on a
n untrac
ed variable"
,
)
)
else
:
...
...
@@ -3090,7 +3134,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
start
=
len
(
gradients
)
node
=
outs
[
0
]
.
owner
for
idx
in
range
(
info
.
n_
shared
_outs
):
for
idx
in
range
(
info
.
n_
untraced_sit_sot
_outs
):
disconnected
=
True
connected_flags
=
self
.
connection_pattern
(
node
)[
idx
+
start
]
for
dC_dout
,
connected
in
zip
(
dC_douts
,
connected_flags
,
strict
=
True
):
...
...
@@ -3116,13 +3160,13 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
gradients
.
append
(
x
[
-
1
])
elif
t
==
"disconnected"
:
gradients
.
append
(
DisconnectedType
()())
elif
t
==
"through_
shar
ed"
:
elif
t
==
"through_
untrac
ed"
:
gradients
.
append
(
grad_undefined
(
self
,
p
+
begin
+
1
,
inputs
[
p
+
begin
+
1
],
"Depends on a
shar
ed variable"
,
"Depends on a
untrac
ed variable"
,
)
)
else
:
...
...
@@ -3152,7 +3196,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
self_inputs
=
self
.
inner_inputs
rop_of_inputs
=
(
self_inputs
[:
info
.
n_seqs
+
self
.
n_outs
]
+
self_inputs
[
info
.
n_seqs
+
self
.
n_outs
+
info
.
n_
shared
_outs
:]
+
self_inputs
[
info
.
n_seqs
+
self
.
n_outs
+
info
.
n_
untraced_sit_sot
_outs
:]
)
self_outputs
=
self
.
inner_outputs
...
...
@@ -3162,8 +3206,8 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
rop_self_outputs
=
self_outputs
[:
-
1
]
else
:
rop_self_outputs
=
self_outputs
if
info
.
n_
shared
_outs
>
0
:
rop_self_outputs
=
rop_self_outputs
[:
-
info
.
n_
shared
_outs
]
if
info
.
n_
untraced_sit_sot
_outs
>
0
:
rop_self_outputs
=
rop_self_outputs
[:
-
info
.
n_
untraced_sit_sot
_outs
]
rop_outs
=
Rop
(
rop_self_outputs
,
rop_of_inputs
,
...
...
@@ -3247,13 +3291,13 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
scan_sit_sot
=
inputs
[
b
:
e
]
+
clean_eval_points
inner_sit_sot
=
self_inputs
[
ib
:
ie
]
+
inner_eval_points
[
ib
:
ie
]
#
Shar
ed outs ...
#
Untrac
ed outs ...
b
=
e
e
=
e
+
info
.
n_
shared
_outs
e
=
e
+
info
.
n_
untraced_sit_sot
_outs
ib
=
ie
ie
=
ie
+
info
.
n_
shared
_outs
scan_
shar
ed
=
inputs
[
b
:
e
]
inner_
shar
ed
=
self_inputs
[
ib
:
ie
]
ie
=
ie
+
info
.
n_
untraced_sit_sot
_outs
scan_
untrac
ed
=
inputs
[
b
:
e
]
inner_
untrac
ed
=
self_inputs
[
ib
:
ie
]
# NIT_SOT sequences
b
=
e
...
...
@@ -3268,7 +3312,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
else
:
clean_eval_points
.
append
(
inp
.
zeros_like
())
scan_other
=
inputs
[
e
:]
+
clean_eval_points
# inner_eval_points do not have entries for
shar
ed variables
# inner_eval_points do not have entries for
untrac
ed variables
inner_other
=
self_inputs
[
ie
:]
+
inner_eval_points
[
ib
:]
# Outputs
...
...
@@ -3287,15 +3331,15 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
e
=
e
+
info
.
n_nit_sot
inner_out_nit_sot
=
self_outputs
[
b
:
e
]
+
rop_outs
[
b
:
e
]
b
=
e
e
=
e
+
info
.
n_
shared
_outs
inner_out_
shar
ed
=
self_outputs
[
b
:
e
]
e
=
e
+
info
.
n_
untraced_sit_sot
_outs
inner_out_
untrac
ed
=
self_outputs
[
b
:
e
]
inner_ins
=
(
inner_seqs
+
inner_mit_mot
+
inner_mit_sot
+
inner_sit_sot
+
inner_
shar
ed
+
inner_
untrac
ed
+
inner_other
)
inner_outs
=
(
...
...
@@ -3303,7 +3347,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
+
inner_out_mit_sot
+
inner_out_sit_sot
+
inner_out_nit_sot
+
inner_out_
shar
ed
+
inner_out_
untrac
ed
)
if
info
.
as_while
:
...
...
@@ -3314,7 +3358,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
*
scan_mit_mot
,
*
scan_mit_sot
,
*
scan_sit_sot
,
*
scan_
shar
ed
,
*
scan_
untrac
ed
,
*
scan_nit_sot
,
*
scan_other
,
]
...
...
@@ -3326,7 +3370,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
mit_sot_in_slices
=
new_mit_sot_in_slices
,
sit_sot_in_slices
=
new_sit_sot_in_slices
,
n_nit_sot
=
info
.
n_nit_sot
*
2
,
n_
shared_outs
=
info
.
n_shared
_outs
,
n_
untraced_sit_sot_outs
=
info
.
n_untraced_sit_sot
_outs
,
n_non_seqs
=
len
(
inner_other
),
as_while
=
info
.
as_while
,
)
...
...
@@ -3358,7 +3402,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
b
=
e
+
info
.
n_nit_sot
e
=
e
+
info
.
n_nit_sot
*
2
final_outs
+=
outputs
[
b
:
e
]
final_outs
+=
[
None
]
*
info
.
n_
shared
_outs
final_outs
+=
[
None
]
*
info
.
n_
untraced_sit_sot
_outs
return
final_outs
...
...
pytensor/scan/rewriting.py
浏览文件 @
1d19c375
...
...
@@ -110,7 +110,7 @@ def remove_constants_and_unused_inputs_scan(fgraph, node):
sum
(
len
(
x
)
for
x
in
chain
(
op_info
.
mit_mot_in_slices
,
op_info
.
mit_sot_in_slices
))
)
st
+=
op_info
.
n_sit_sot
st
+=
op_info
.
n_
shared
_outs
st
+=
op_info
.
n_
untraced_sit_sot
_outs
op_ins
=
op
.
inner_inputs
op_outs
=
op
.
inner_outputs
...
...
@@ -126,7 +126,7 @@ def remove_constants_and_unused_inputs_scan(fgraph, node):
+
op_info
.
n_mit_sot
+
op_info
.
n_sit_sot
+
op_info
.
n_nit_sot
+
op_info
.
n_
shared
_outs
+
op_info
.
n_
untraced_sit_sot
_outs
+
1
)
outer_non_seqs
=
node
.
inputs
[
st
:]
...
...
@@ -983,7 +983,7 @@ class ScanInplaceOptimizer(GraphRewriter):
ls
=
op
.
outer_mitmot
(
node
.
inputs
)
ls
+=
op
.
outer_mitsot
(
node
.
inputs
)
ls
+=
op
.
outer_sitsot
(
node
.
inputs
)
ls_end
=
op
.
outer_
shared
(
node
.
inputs
)
ls_end
=
op
.
outer_
untraced_sit_sot
(
node
.
inputs
)
ls_end
+=
op
.
outer_nitsot
(
node
.
inputs
)
ls_end
+=
op
.
outer_non_seqs
(
node
.
inputs
)
...
...
@@ -1628,7 +1628,7 @@ def scan_save_mem_rewrite(fgraph, node, backend_supports_output_pre_allocation:
+
idx
+
op_info
.
n_seqs
+
1
+
op_info
.
n_
shared
_outs
+
op_info
.
n_
untraced_sit_sot
_outs
)
if
nw_inputs
[
pos
]
==
node
.
inputs
[
0
]:
nw_inputs
[
pos
]
=
1
if
required_orphan
else
val
...
...
@@ -1662,7 +1662,7 @@ def scan_save_mem_rewrite(fgraph, node, backend_supports_output_pre_allocation:
elif
(
idx
<
op_info
.
n_mit_sot
+
op_info
.
n_sit_sot
+
op_info
.
n_nit_sot
):
in_idx
=
offset
+
idx
+
op_info
.
n_
shared
_outs
in_idx
=
offset
+
idx
+
op_info
.
n_
untraced_sit_sot
_outs
if
nw_inputs
[
in_idx
]
==
node
.
inputs
[
0
]:
nw_inputs
[
in_idx
]
=
nw_steps
...
...
@@ -1886,8 +1886,8 @@ class ScanMerge(GraphRewriter):
for
idx
,
nd
in
enumerate
(
nodes
):
# Shared
inner_ins
[
idx
]
.
append
(
nd
.
op
.
inner_
shared
(
nd
.
op
.
inner_inputs
))
outer_ins
+=
nd
.
op
.
outer_
shared
(
nd
.
inputs
)
inner_ins
[
idx
]
.
append
(
nd
.
op
.
inner_
untraced_sit_sot
(
nd
.
op
.
inner_inputs
))
outer_ins
+=
nd
.
op
.
outer_
untraced_sit_sot
(
nd
.
inputs
)
for
idx
,
nd
in
enumerate
(
nodes
):
# NitSot
...
...
@@ -1897,8 +1897,10 @@ class ScanMerge(GraphRewriter):
for
idx
,
nd
in
enumerate
(
nodes
):
# Shared
outer_outs
+=
nd
.
op
.
outer_shared_outs
(
nd
.
outputs
)
inner_outs
[
idx
]
.
append
(
nd
.
op
.
inner_shared_outs
(
nd
.
op
.
inner_outputs
))
outer_outs
+=
nd
.
op
.
outer_untraced_sit_sot_outs
(
nd
.
outputs
)
inner_outs
[
idx
]
.
append
(
nd
.
op
.
inner_untraced_sit_sot_outs
(
nd
.
op
.
inner_outputs
)
)
n_non_seqs
=
0
for
idx
,
nd
in
enumerate
(
nodes
):
...
...
@@ -1978,7 +1980,9 @@ class ScanMerge(GraphRewriter):
mit_sot_in_slices
=
mit_sot_in_slices
,
sit_sot_in_slices
=
sit_sot_in_slices
,
n_nit_sot
=
sum
(
nd
.
op
.
info
.
n_nit_sot
for
nd
in
nodes
),
n_shared_outs
=
sum
(
nd
.
op
.
info
.
n_shared_outs
for
nd
in
nodes
),
n_untraced_sit_sot_outs
=
sum
(
nd
.
op
.
info
.
n_untraced_sit_sot_outs
for
nd
in
nodes
),
n_non_seqs
=
n_non_seqs
,
as_while
=
as_while
,
)
...
...
@@ -2360,7 +2364,7 @@ def scan_push_out_dot1(fgraph, node):
# When seq[t] is a vector/matrix and `value` is a matrix
# Note that this works when only you need X[-1] in the end
# and assumes dimshuffle are applied to vectors before calling dot
op
=
node
.
op
op
:
Scan
=
node
.
op
sitsot_ins
=
op
.
inner_sitsot
(
op
.
inner_inputs
)
sitsot_outs
=
op
.
inner_sitsot_outs
(
op
.
inner_outputs
)
outer_sitsot
=
op
.
outer_sitsot_outs
(
node
.
outputs
)
...
...
@@ -2416,9 +2420,13 @@ def scan_push_out_dot1(fgraph, node):
inner_sitsot_outs
=
op
.
inner_sitsot_outs
(
op
.
inner_outputs
)
outer_nitsot
=
op
.
outer_nitsot
(
node
.
inputs
)
inner_nitsot_outs
=
op
.
inner_nitsot_outs
(
op
.
inner_outputs
)
inner_shared
=
op
.
inner_shared
(
op
.
inner_inputs
)
outer_shared
=
op
.
outer_shared
(
node
.
inputs
)
inner_shared_outs
=
op
.
inner_shared_outs
(
op
.
inner_outputs
)
inner_untraced_sitsot
=
op
.
inner_untraced_sitsot
(
op
.
inner_inputs
)
outer_untraced_sitsot_outs
=
op
.
outer_untraced_sitsot_outs
(
node
.
inputs
)
inner_untraced_sitsot_outs
=
op
.
inner_untraced_sitsot_outs
(
op
.
inner_outputs
)
inner_non_seqs
=
op
.
inner_non_seqs
(
op
.
inner_inputs
)
outer_non_seqs
=
op
.
outer_non_seqs
(
node
.
inputs
)
...
...
@@ -2441,7 +2449,7 @@ def scan_push_out_dot1(fgraph, node):
+
inner_mitmot
+
inner_mitsot
+
inner_sitsot
+
inner_
shared
+
inner_
untraced_sitsot
+
inner_non_seqs
)
_new_inner_outs
=
(
...
...
@@ -2449,7 +2457,7 @@ def scan_push_out_dot1(fgraph, node):
+
inner_mitsot_outs
+
inner_sitsot_outs
+
inner_nitsot_outs
+
inner_
shared
_outs
+
inner_
untraced_sitsot
_outs
)
new_inner_inps
,
new_inner_outs
=
reconstruct_graph
(
_new_inner_inps
,
_new_inner_outs
...
...
@@ -2471,7 +2479,7 @@ def scan_push_out_dot1(fgraph, node):
*
outer_mitmot
,
*
outer_mitsot
,
*
outer_sitsot
,
*
outer_
shared
,
*
outer_
untraced_sitsot_outs
,
*
outer_nitsot
,
node
.
inputs
[
0
],
*
outer_non_seqs
,
...
...
pytensor/scan/utils.py
浏览文件 @
1d19c375
...
...
@@ -370,7 +370,9 @@ def scan_can_remove_outs(op, out_idxs):
out_ins
+=
[
op
.
inner_inputs
[
offset
:
offset
+
n_ins
]]
offset
+=
n_ins
out_ins
+=
[[]
for
k
in
range
(
op
.
info
.
n_nit_sot
)]
out_ins
+=
[[
op
.
inner_inputs
[
offset
+
k
]]
for
k
in
range
(
op
.
info
.
n_shared_outs
)]
out_ins
+=
[
[
op
.
inner_inputs
[
offset
+
k
]]
for
k
in
range
(
op
.
info
.
n_untraced_sit_sot_outs
)
]
added
=
True
out_idxs_mask
=
[
1
for
idx
in
out_idxs
]
...
...
@@ -409,7 +411,7 @@ def compress_outs(op, not_required, inputs):
mit_sot_in_slices
=
(),
sit_sot_in_slices
=
(),
n_nit_sot
=
0
,
n_
shared
_outs
=
0
,
n_
untraced_sit_sot
_outs
=
0
,
n_non_seqs
=
0
,
as_while
=
op_info
.
as_while
,
)
...
...
@@ -515,17 +517,19 @@ def compress_outs(op, not_required, inputs):
info
=
dataclasses
.
replace
(
info
,
n_nit_sot
=
info
.
n_nit_sot
+
1
)
op_outputs
+=
[
op
.
inner_outputs
[
o_offset
]]
o_offset
+=
1
nit_sot_ins
+=
[
inputs
[
ni_offset
+
idx
+
op_info
.
n_
shared
_outs
]]
nit_sot_ins
+=
[
inputs
[
ni_offset
+
idx
+
op_info
.
n_
untraced_sit_sot
_outs
]]
else
:
o_offset
+=
1
offset
+=
op_info
.
n_nit_sot
shared_ins
=
[]
for
idx
in
range
(
op_info
.
n_
shared
_outs
):
for
idx
in
range
(
op_info
.
n_
untraced_sit_sot
_outs
):
if
offset
+
idx
not
in
not_required
:
map_old_new
[
offset
+
idx
]
=
curr_pos
curr_pos
+=
1
info
=
dataclasses
.
replace
(
info
,
n_shared_outs
=
info
.
n_shared_outs
+
1
)
info
=
dataclasses
.
replace
(
info
,
n_untraced_sit_sot_outs
=
info
.
n_untraced_sit_sot_outs
+
1
)
op_outputs
+=
[
op
.
inner_outputs
[
o_offset
]]
o_offset
+=
1
op_inputs
+=
[
op
.
inner_inputs
[
i_offset
]]
...
...
@@ -539,7 +543,9 @@ def compress_outs(op, not_required, inputs):
# other stuff
op_inputs
+=
op
.
inner_inputs
[
i_offset
:]
info
=
dataclasses
.
replace
(
info
,
n_non_seqs
=
len
(
op
.
inner_inputs
[
i_offset
:]))
node_inputs
+=
inputs
[
ni_offset
+
op_info
.
n_shared_outs
+
op_info
.
n_nit_sot
:]
node_inputs
+=
inputs
[
ni_offset
+
op_info
.
n_untraced_sit_sot_outs
+
op_info
.
n_nit_sot
:
]
if
op_info
.
as_while
:
op_outputs
+=
[
op
.
inner_outputs
[
o_offset
]]
map_old_new
[
o_offset
]
=
len
(
op_outputs
)
-
1
...
...
@@ -658,11 +664,11 @@ class ScanArgs:
p
+=
n_sit_sot
q
+=
n_sit_sot
n_
shared_outs
=
info
.
n_shared
_outs
self
.
outer_in_shared
=
list
(
outer_inputs
[
p
:
p
+
n_
shared
_outs
])
self
.
inner_in_shared
=
list
(
inner_inputs
[
q
:
q
+
n_
shared
_outs
])
p
+=
n_
shared
_outs
q
+=
n_
shared
_outs
n_
untraced_sit_sot_outs
=
info
.
n_untraced_sit_sot
_outs
self
.
outer_in_shared
=
list
(
outer_inputs
[
p
:
p
+
n_
untraced_sit_sot
_outs
])
self
.
inner_in_shared
=
list
(
inner_inputs
[
q
:
q
+
n_
untraced_sit_sot
_outs
])
p
+=
n_
untraced_sit_sot
_outs
q
+=
n_
untraced_sit_sot
_outs
n_nit_sot
=
info
.
n_nit_sot
self
.
outer_in_nit_sot
=
list
(
outer_inputs
[
p
:
p
+
n_nit_sot
])
...
...
@@ -702,10 +708,10 @@ class ScanArgs:
p
+=
n_nit_sot
q
+=
n_nit_sot
self
.
outer_out_shared
=
list
(
outer_outputs
[
p
:
p
+
n_
shared
_outs
])
self
.
inner_out_shared
=
list
(
inner_outputs
[
q
:
q
+
n_
shared
_outs
])
p
+=
n_
shared
_outs
q
+=
n_
shared
_outs
self
.
outer_out_shared
=
list
(
outer_outputs
[
p
:
p
+
n_
untraced_sit_sot
_outs
])
self
.
inner_out_shared
=
list
(
inner_outputs
[
q
:
q
+
n_
untraced_sit_sot
_outs
])
p
+=
n_
untraced_sit_sot
_outs
q
+=
n_
untraced_sit_sot
_outs
assert
p
==
len
(
outer_outputs
)
assert
q
==
len
(
inner_outputs
)
...
...
@@ -816,7 +822,7 @@ class ScanArgs:
mit_sot_in_slices
=
tuple
(
tuple
(
v
)
for
v
in
self
.
mit_sot_in_slices
),
sit_sot_in_slices
=
((
-
1
,),)
*
len
(
self
.
inner_in_sit_sot
),
n_nit_sot
=
len
(
self
.
outer_in_nit_sot
),
n_
shared
_outs
=
len
(
self
.
outer_in_shared
),
n_
untraced_sit_sot
_outs
=
len
(
self
.
outer_in_shared
),
n_non_seqs
=
len
(
self
.
inner_in_non_seqs
),
as_while
=
self
.
as_while
,
)
...
...
tests/link/numba/test_scan.py
浏览文件 @
1d19c375
...
...
@@ -85,7 +85,7 @@ from tests.link.numba.test_basic import compare_numba_and_py
3
,
[],
[
np
.
array
([
0.50100236
,
2.16822932
,
1.36326596
])],
lambda
op
:
op
.
info
.
n_
shared
_outs
>
0
,
lambda
op
:
op
.
info
.
n_
untraced_sit_sot
_outs
>
0
,
),
# mit-sot (that's also a type of sit-sot)
(
...
...
tests/scan/test_basic.py
浏览文件 @
1d19c375
...
...
@@ -42,11 +42,13 @@ from pytensor.tensor.math import all as pt_all
from
pytensor.tensor.math
import
dot
,
exp
,
mean
,
sigmoid
,
tanh
from
pytensor.tensor.math
import
sum
as
pt_sum
from
pytensor.tensor.random
import
normal
from
pytensor.tensor.random.type
import
RandomGeneratorType
,
random_generator_type
from
pytensor.tensor.random.utils
import
RandomStream
from
pytensor.tensor.shape
import
Shape_i
,
reshape
,
specify_shape
from
pytensor.tensor.sharedvar
import
SharedVariable
from
pytensor.tensor.subtensor
import
Subtensor
from
pytensor.tensor.type
import
(
TensorType
,
dcol
,
dmatrix
,
dscalar
,
...
...
@@ -4007,7 +4009,7 @@ class TestExamples:
[{}],
[],
3
,
lambda
op
:
op
.
info
.
n_
shared
_outs
>
0
,
lambda
op
:
op
.
info
.
n_
untraced_sit_sot
_outs
>
0
,
),
# mit-sot (that's also a type of sit-sot)
(
...
...
@@ -4106,3 +4108,34 @@ def test_output_storage_reuse(linker_mode):
res
=
f_cvm
()
assert
np
.
array_equal
(
res
,
np
.
array
([
3
,
1
,
0
]))
def
test_rng_outputs_info
():
rng_init
=
random_generator_type
(
"rng"
)
rng_x0
,
x0
=
pt
.
random
.
normal
(
0
,
rng
=
rng_init
,
dtype
=
"float64"
)
.
owner
.
outputs
def
step
(
prev_x
,
prev_rng
):
next_rng
,
next_x
=
pt
.
random
.
normal
(
prev_x
,
rng
=
prev_rng
,
dtype
=
"float64"
)
.
owner
.
outputs
return
next_x
,
next_rng
[
xs
,
rng_final
],
updates
=
scan
(
fn
=
step
,
outputs_info
=
[
x0
,
rng_x0
],
n_steps
=
10
,
)
assert
isinstance
(
xs
.
type
,
TensorType
)
assert
isinstance
(
rng_final
.
type
,
RandomGeneratorType
)
assert
not
updates
fn
=
function
([
rng_init
],
[
xs
,
rng_final
])
xs_eval
,
rng_final_eval
=
fn
(
np
.
random
.
default_rng
(
0
))
rng_ref
=
np
.
random
.
default_rng
(
0
)
assert
not
random_generator_type
.
values_eq
(
rng_ref
,
rng_final_eval
)
xs_ref
=
[
rng_ref
.
normal
(
0
)]
for
i
in
range
(
10
):
xs_ref
.
append
(
rng_ref
.
normal
(
xs_ref
[
-
1
]))
assert
random_generator_type
.
values_eq
(
rng_ref
,
rng_final_eval
)
np
.
testing
.
assert_allclose
(
xs_eval
,
xs_ref
[
1
:])
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论