Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
f3a7d94f
提交
f3a7d94f
authored
11月 27, 2025
作者:
Ricardo Vieira
提交者:
Ricardo Vieira
12月 05, 2025
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Scan dispatches: correct handling of signed mitmot taps
Unlike MIT-SOT and SIT-SOT taps, MIT-MOT taps can be positive or negative, depending on the order of differentiation.
上级
ebc0de09
隐藏空白字符变更
内嵌
并排
正在显示
6 个修改的文件
包含
107 行增加
和
45 行删除
+107
-45
scan.py
pytensor/link/jax/dispatch/scan.py
+5
-2
scan.py
pytensor/link/numba/dispatch/scan.py
+39
-43
op.py
pytensor/scan/op.py
+20
-0
test_scan.py
tests/link/jax/test_scan.py
+5
-0
test_scan.py
tests/link/numba/test_scan.py
+5
-0
test_basic.py
tests/scan/test_basic.py
+33
-0
没有找到文件。
pytensor/link/jax/dispatch/scan.py
浏览文件 @
f3a7d94f
...
@@ -90,7 +90,7 @@ def jax_funcify_Scan(op: Scan, **kwargs):
...
@@ -90,7 +90,7 @@ def jax_funcify_Scan(op: Scan, **kwargs):
chain
.
from_iterable
(
chain
.
from_iterable
(
buffer
[(
i
+
np
.
array
(
taps
))]
buffer
[(
i
+
np
.
array
(
taps
))]
for
buffer
,
taps
in
zip
(
for
buffer
,
taps
in
zip
(
inner_mit_mot
,
info
.
mit_mot_in_slices
,
strict
=
True
inner_mit_mot
,
info
.
normalized_
mit_mot_in_slices
,
strict
=
True
)
)
)
)
)
)
...
@@ -140,7 +140,10 @@ def jax_funcify_Scan(op: Scan, **kwargs):
...
@@ -140,7 +140,10 @@ def jax_funcify_Scan(op: Scan, **kwargs):
new_mit_mot
=
[
new_mit_mot
=
[
buffer
.
at
[
i
+
np
.
array
(
taps
)]
.
set
(
new_vals
)
buffer
.
at
[
i
+
np
.
array
(
taps
)]
.
set
(
new_vals
)
for
buffer
,
new_vals
,
taps
in
zip
(
for
buffer
,
new_vals
,
taps
in
zip
(
old_mit_mot
,
new_mit_mot_vals
,
info
.
mit_mot_out_slices
,
strict
=
True
old_mit_mot
,
new_mit_mot_vals
,
info
.
normalized_mit_mot_out_slices
,
strict
=
True
,
)
)
]
]
# Discard oldest MIT-SOT and append newest value
# Discard oldest MIT-SOT and append newest value
...
...
pytensor/link/numba/dispatch/scan.py
浏览文件 @
f3a7d94f
...
@@ -27,9 +27,8 @@ def idx_to_str(
...
@@ -27,9 +27,8 @@ def idx_to_str(
idx_symbol
:
str
=
"i"
,
idx_symbol
:
str
=
"i"
,
allow_scalar
=
False
,
allow_scalar
=
False
,
)
->
str
:
)
->
str
:
if
offset
<
0
:
assert
offset
>=
0
indices
=
f
"{idx_symbol} + {array_name}.shape[0] - {offset}"
if
offset
>
0
:
elif
offset
>
0
:
indices
=
f
"{idx_symbol} + {offset}"
indices
=
f
"{idx_symbol} + {offset}"
else
:
else
:
indices
=
idx_symbol
indices
=
idx_symbol
...
@@ -226,33 +225,16 @@ def numba_funcify_Scan(op: Scan, node, **kwargs):
...
@@ -226,33 +225,16 @@ def numba_funcify_Scan(op: Scan, node, **kwargs):
# storage array like a circular buffer, and that's why we need to track the
# storage array like a circular buffer, and that's why we need to track the
# storage size along with the taps length/indexing offset.
# storage size along with the taps length/indexing offset.
def
add_output_storage_post_proc_stmt
(
def
add_output_storage_post_proc_stmt
(
outer_in_name
:
str
,
tap_sizes
:
tuple
[
int
,
...
]
,
storage_size
:
str
outer_in_name
:
str
,
max_offset
:
int
,
storage_size
:
str
):
):
tap_size
=
max
(
tap_sizes
)
# Rotate the storage so that the last computed value is at the end of the storage array.
if
op
.
info
.
as_while
:
# While loops need to truncate the output storage to a length given
# by the number of iterations performed.
output_storage_post_proc_stmts
.
append
(
dedent
(
f
"""
if i + {tap_size} < {storage_size}:
{storage_size} = i + {tap_size}
{outer_in_name} = {outer_in_name}[:{storage_size}]
"""
)
.
strip
()
)
# Rotate the storage so that the last computed value is at the end of
# the storage array.
# This is needed when the output storage array does not have a length
# This is needed when the output storage array does not have a length
# equal to the number of taps plus `n_steps`.
# equal to the number of taps plus `n_steps`.
# If the storage size only allows one entry, there's nothing to rotate
output_storage_post_proc_stmts
.
append
(
output_storage_post_proc_stmts
.
append
(
dedent
(
dedent
(
f
"""
f
"""
if 1 < {storage_size} < (i + {
tap_size
}):
if 1 < {storage_size} < (i + {
max_offset
}):
{outer_in_name}_shift = (i + {
tap_size
})
%
({storage_size})
{outer_in_name}_shift = (i + {
max_offset
})
%
({storage_size})
if {outer_in_name}_shift > 0:
if {outer_in_name}_shift > 0:
{outer_in_name}_left = {outer_in_name}[:{outer_in_name}_shift]
{outer_in_name}_left = {outer_in_name}[:{outer_in_name}_shift]
{outer_in_name}_right = {outer_in_name}[{outer_in_name}_shift:]
{outer_in_name}_right = {outer_in_name}[{outer_in_name}_shift:]
...
@@ -261,6 +243,18 @@ def numba_funcify_Scan(op: Scan, node, **kwargs):
...
@@ -261,6 +243,18 @@ def numba_funcify_Scan(op: Scan, node, **kwargs):
)
.
strip
()
)
.
strip
()
)
)
if
op
.
info
.
as_while
:
# While loops need to truncate the output storage to a length given
# by the number of iterations performed.
output_storage_post_proc_stmts
.
append
(
dedent
(
f
"""
elif {storage_size} > (i + {max_offset}):
{outer_in_name} = {outer_in_name}[:i + {max_offset}]
"""
)
.
strip
()
)
# Special in-loop statements that create (nit-sot) storage arrays after a
# Special in-loop statements that create (nit-sot) storage arrays after a
# single iteration is performed. This is necessary because we don't know
# single iteration is performed. This is necessary because we don't know
# the exact shapes of the storage arrays that need to be allocated until
# the exact shapes of the storage arrays that need to be allocated until
...
@@ -288,12 +282,11 @@ def numba_funcify_Scan(op: Scan, node, **kwargs):
...
@@ -288,12 +282,11 @@ def numba_funcify_Scan(op: Scan, node, **kwargs):
storage_size_name
=
f
"{outer_in_name}_len"
storage_size_name
=
f
"{outer_in_name}_len"
storage_size_stmt
=
f
"{storage_size_name} = {outer_in_name}.shape[0]"
storage_size_stmt
=
f
"{storage_size_name} = {outer_in_name}.shape[0]"
input_taps
=
inner_in_names_to_input_taps
[
outer_in_name
]
input_taps
=
inner_in_names_to_input_taps
[
outer_in_name
]
tap_storage_size
=
-
min
(
input_taps
)
max_lookback_inp_tap
=
-
min
(
0
,
min
(
input_taps
)
)
assert
tap_storage_size
>=
0
assert
max_lookback_inp_tap
>=
0
for
in_tap
in
input_taps
:
for
in_tap
in
input_taps
:
tap_offset
=
in_tap
+
tap_storage_size
tap_offset
=
max_lookback_inp_tap
+
in_tap
assert
tap_offset
>=
0
is_vector
=
outer_in_var
.
ndim
==
1
is_vector
=
outer_in_var
.
ndim
==
1
add_inner_in_expr
(
add_inner_in_expr
(
outer_in_name
,
outer_in_name
,
...
@@ -302,22 +295,25 @@ def numba_funcify_Scan(op: Scan, node, **kwargs):
...
@@ -302,22 +295,25 @@ def numba_funcify_Scan(op: Scan, node, **kwargs):
vector_slice_opt
=
is_vector
,
vector_slice_opt
=
is_vector
,
)
)
output_taps
=
inner_in_names_to_output_taps
.
get
(
output_taps
=
inner_in_names_to_output_taps
.
get
(
outer_in_name
,
[
0
])
outer_in_name
,
[
tap_storage_size
]
for
out_tap
in
output_taps
:
)
tap_offset
=
max_lookback_inp_tap
+
out_tap
inner_out_to_outer_in_stmts
.
extend
(
assert
tap_offset
>=
0
idx_to_str
(
inner_out_to_outer_in_stmts
.
append
(
storage_name
,
idx_to_str
(
out_tap
,
storage_name
,
size
=
storage_size_name
,
tap_offset
,
allow_scalar
=
True
,
size
=
storage_size_name
,
allow_scalar
=
True
,
)
)
)
for
out_tap
in
output_taps
)
add_output_storage_post_proc_stmt
(
if
outer_in_name
not
in
outer_in_mit_mot_names
:
storage_name
,
output_taps
,
storage_size_name
# MIT-SOT and SIT-SOT may require buffer rolling/truncation after the main loop
)
max_offset_out_tap
=
max
(
output_taps
)
+
max_lookback_inp_tap
add_output_storage_post_proc_stmt
(
storage_name
,
max_offset_out_tap
,
storage_size_name
)
else
:
else
:
storage_size_stmt
=
""
storage_size_stmt
=
""
...
@@ -351,7 +347,7 @@ def numba_funcify_Scan(op: Scan, node, **kwargs):
...
@@ -351,7 +347,7 @@ def numba_funcify_Scan(op: Scan, node, **kwargs):
inner_out_to_outer_in_stmts
.
append
(
inner_out_to_outer_in_stmts
.
append
(
idx_to_str
(
storage_name
,
0
,
size
=
storage_size_name
,
allow_scalar
=
True
)
idx_to_str
(
storage_name
,
0
,
size
=
storage_size_name
,
allow_scalar
=
True
)
)
)
add_output_storage_post_proc_stmt
(
storage_name
,
(
0
,)
,
storage_size_name
)
add_output_storage_post_proc_stmt
(
storage_name
,
0
,
storage_size_name
)
# In case of nit-sots we are provided the length of the array in
# In case of nit-sots we are provided the length of the array in
# the iteration dimension instead of actual arrays, hence we
# the iteration dimension instead of actual arrays, hence we
...
...
pytensor/scan/op.py
浏览文件 @
f3a7d94f
...
@@ -288,6 +288,26 @@ class ScanInfo:
...
@@ -288,6 +288,26 @@ class ScanInfo:
+
self
.
n_untraced_sit_sot_outs
+
self
.
n_untraced_sit_sot_outs
)
)
@property
def normalized_mit_mot_in_slices(self) -> tuple[tuple[int, ...], ...]:
    """Return the mit-mot input taps shifted so that none is negative.

    Every slice in ``self.mit_mot_in_slices`` is shifted by its maximum
    lookback, ``-min(0, min(slice))``. A slice that contains a negative tap
    therefore starts at offset 0 (its oldest tap), while a slice whose taps
    are all non-negative is returned unchanged.

    Returns
    -------
    tuple of tuple of int
        One tuple of non-negative offsets per mit-mot input slice, in the
        same order as ``self.mit_mot_in_slices``.
    """
    # TODO: Make this the canonical representation
    normalized = []
    for in_slice in self.mit_mot_in_slices:
        # Seeding the candidates with 0 keeps all-positive slices unshifted
        # and gives an empty slice a lookback of 0 instead of raising.
        max_lookback = -min((0, *in_slice))
        normalized.append(tuple(tap + max_lookback for tap in in_slice))
    return tuple(normalized)
@property
def normalized_mit_mot_out_slices(self) -> tuple[tuple[int, ...], ...]:
    """Return the mit-mot output taps shifted so that none is negative.

    Every slice in ``self.mit_mot_out_slices`` is shifted by its own maximum
    lookback, ``-min(0, min(slice))``. A slice that contains a negative tap
    therefore starts at offset 0 (its oldest tap), while a slice whose taps
    are all non-negative is returned unchanged.

    Returns
    -------
    tuple of tuple of int
        One tuple of non-negative offsets per mit-mot output slice, in the
        same order as ``self.mit_mot_out_slices``.
    """
    # TODO: Make this the canonical representation
    normalized = []
    for out_slice in self.mit_mot_out_slices:
        # Seeding the candidates with 0 keeps all-positive slices unshifted
        # and gives an empty slice a lookback of 0 instead of raising.
        max_lookback = -min((0, *out_slice))
        normalized.append(tuple(tap + max_lookback for tap in out_slice))
    return tuple(normalized)
TensorConstructorType
=
Callable
[
TensorConstructorType
=
Callable
[
[
Iterable
[
bool
|
int
|
None
],
str
|
np
.
generic
],
TensorType
[
Iterable
[
bool
|
int
|
None
],
str
|
np
.
generic
],
TensorType
...
...
tests/link/jax/test_scan.py
浏览文件 @
f3a7d94f
...
@@ -15,6 +15,7 @@ from pytensor.tensor import random
...
@@ -15,6 +15,7 @@ from pytensor.tensor import random
from
pytensor.tensor.math
import
gammaln
,
log
from
pytensor.tensor.math
import
gammaln
,
log
from
pytensor.tensor.type
import
dmatrix
,
dvector
,
matrix
,
scalar
,
vector
from
pytensor.tensor.type
import
dmatrix
,
dvector
,
matrix
,
scalar
,
vector
from
tests.link.jax.test_basic
import
compare_jax_and_py
from
tests.link.jax.test_basic
import
compare_jax_and_py
from
tests.scan.test_basic
import
ScanCompatibilityTests
jax
=
pytest
.
importorskip
(
"jax"
)
jax
=
pytest
.
importorskip
(
"jax"
)
...
@@ -626,3 +627,7 @@ def test_scan_benchmark(model, mode, gradient_backend, benchmark):
...
@@ -626,3 +627,7 @@ def test_scan_benchmark(model, mode, gradient_backend, benchmark):
block_until_ready
(
*
test_input_vals
)
# Warmup
block_until_ready
(
*
test_input_vals
)
# Warmup
benchmark
.
pedantic
(
block_until_ready
,
test_input_vals
,
rounds
=
200
,
iterations
=
1
)
benchmark
.
pedantic
(
block_until_ready
,
test_input_vals
,
rounds
=
200
,
iterations
=
1
)
def test_higher_order_derivatives():
    """Run the shared higher-order-derivative Scan check with ``mode="JAX"``."""
    ScanCompatibilityTests.check_higher_order_derivative(mode="JAX")
tests/link/numba/test_scan.py
浏览文件 @
f3a7d94f
...
@@ -16,6 +16,7 @@ from pytensor.tensor.elemwise import Elemwise
...
@@ -16,6 +16,7 @@ from pytensor.tensor.elemwise import Elemwise
from
pytensor.tensor.random.utils
import
RandomStream
from
pytensor.tensor.random.utils
import
RandomStream
from
tests
import
unittest_tools
as
utt
from
tests
import
unittest_tools
as
utt
from
tests.link.numba.test_basic
import
compare_numba_and_py
from
tests.link.numba.test_basic
import
compare_numba_and_py
from
tests.scan.test_basic
import
ScanCompatibilityTests
@pytest.mark.parametrize
(
@pytest.mark.parametrize
(
...
@@ -652,3 +653,7 @@ class TestScanMITSOTBuffer:
...
@@ -652,3 +653,7 @@ class TestScanMITSOTBuffer:
def
test_mit_sot_buffer_benchmark
(
self
,
constant_n_steps
,
n_steps_val
,
benchmark
):
def
test_mit_sot_buffer_benchmark
(
self
,
constant_n_steps
,
n_steps_val
,
benchmark
):
self
.
buffer_tester
(
constant_n_steps
,
n_steps_val
,
benchmark
=
benchmark
)
self
.
buffer_tester
(
constant_n_steps
,
n_steps_val
,
benchmark
=
benchmark
)
def test_higher_order_derivatives():
    """Run the shared higher-order-derivative Scan check with ``mode="NUMBA"``."""
    ScanCompatibilityTests.check_higher_order_derivative(mode="NUMBA")
tests/scan/test_basic.py
浏览文件 @
f3a7d94f
...
@@ -4082,6 +4082,9 @@ class TestExamples:
...
@@ -4082,6 +4082,9 @@ class TestExamples:
# Also, the purpose of this test is not clear.
# Also, the purpose of this test is not clear.
self
.
_grad_mout_helper
(
1
,
None
)
self
.
_grad_mout_helper
(
1
,
None
)
def test_higher_order_derivatives(self):
    """Run the shared higher-order-derivative Scan check with ``mode=None``."""
    ScanCompatibilityTests.check_higher_order_derivative(mode=None)
@pytest.mark.parametrize
(
@pytest.mark.parametrize
(
"fn, sequences, outputs_info, non_sequences, n_steps, op_check"
,
"fn, sequences, outputs_info, non_sequences, n_steps, op_check"
,
...
@@ -4398,3 +4401,33 @@ def test_scan_mode_compatibility(scan_mode):
...
@@ -4398,3 +4401,33 @@ def test_scan_mode_compatibility(scan_mode):
# Expected value computed by running correct Scan once
# Expected value computed by running correct Scan once
np
.
testing
.
assert_allclose
(
fn
(
*
numerical_inputs
),
[
44
,
38
])
np
.
testing
.
assert_allclose
(
fn
(
*
numerical_inputs
),
[
44
,
38
])
class ScanCompatibilityTests:
    """Collection of tests of subtle required behaviors of Scan, that can be reused by different backends."""

    @staticmethod
    def check_higher_order_derivative(mode):
        """This tests different mit-mot taps signs.

        Builds a 4-step Scan that squares its previous output starting from a
        scalar ``x`` (so the last output is ``x ** 16``), differentiates that
        last output three times with respect to ``x``, and compiles everything
        into one function using ``mode``.

        The value, first derivative and second derivative are compared against
        their closed forms at ``x = 0.95``. The third derivative is computed
        but not asserted; see the FIXME at the end.

        Parameters
        ----------
        mode
            Forwarded unchanged as the ``mode`` argument of ``function``.
        """
        x = pt.dscalar("x")

        # xs[-1] is equivalent to x ** 16
        xs = scan(
            fn=lambda xtm1: xtm1**2,
            outputs_info=[x],
            n_steps=4,
            return_updates=False,
        )
        r = xs[-1]
        # Successive derivatives of the final Scan output with respect to x
        g = grad(r, x)
        gg = grad(g, x)
        ggg = grad(gg, x)

        fn = function([x], [r, g, gg, ggg], mode=mode)
        x_test = np.array(0.95, dtype=x.type.dtype)
        r_res, g_res, gg_res, _ggg_res = fn(x_test)
        np.testing.assert_allclose(r_res, x_test**16)
        np.testing.assert_allclose(g_res, 16 * x_test**15)
        np.testing.assert_allclose(gg_res, (16 * 15) * x_test**14)
        # FIXME: All implementations of Scan seem to get this one wrong!
        # np.testing.assert_allclose(ggg_res, (16 * 15 * 14) * x_test**13)
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论