Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
14e6c781
提交
14e6c781
authored
10月 10, 2025
作者:
Ricardo Vieira
提交者:
Ricardo Vieira
10月 27, 2025
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Reimplement JAX Scan dispatcher with MIT-MOT support
Co-authored-by:
Jesse Grabowski
<
48652735+jessegrabowski@users.noreply.github.com
>
上级
97797975
全部展开
显示空白字符变更
内嵌
并排
正在显示
3 个修改的文件
包含
62 行增加
和
14 行删除
+62
-14
scan.py
pytensor/link/jax/dispatch/scan.py
+0
-0
op.py
pytensor/scan/op.py
+11
-0
test_scan.py
tests/link/jax/test_scan.py
+51
-14
没有找到文件。
pytensor/link/jax/dispatch/scan.py
浏览文件 @
14e6c781
差异被折叠。
点击展开。
pytensor/scan/op.py
浏览文件 @
14e6c781
...
@@ -307,6 +307,17 @@ class ScanMethodsMixin:
...
@@ -307,6 +307,17 @@ class ScanMethodsMixin:
n_taps
=
sum
(
len
(
x
)
for
x
in
self
.
info
.
mit_mot_out_slices
)
n_taps
=
sum
(
len
(
x
)
for
x
in
self
.
info
.
mit_mot_out_slices
)
return
list_outputs
[:
n_taps
]
return
list_outputs
[:
n_taps
]
def
inner_mitmot_outs_grouped
(
self
,
list_outputs
):
# Like inner_mitmot_outs but returns a list of lists, one per mitmot
# Instead of a flat list
n_taps
=
[
len
(
x
)
for
x
in
self
.
info
.
mit_mot_out_slices
]
grouped_outs
=
[]
offset
=
0
for
nt
in
n_taps
:
grouped_outs
.
append
(
list_outputs
[
offset
:
offset
+
nt
])
offset
+=
nt
return
grouped_outs
def
outer_mitmot_outs
(
self
,
list_outputs
):
def
outer_mitmot_outs
(
self
,
list_outputs
):
return
list_outputs
[:
self
.
info
.
n_mit_mot
]
return
list_outputs
[:
self
.
info
.
n_mit_mot
]
...
...
tests/link/jax/test_scan.py
浏览文件 @
14e6c781
...
@@ -7,6 +7,8 @@ import pytensor.tensor as pt
...
@@ -7,6 +7,8 @@ import pytensor.tensor as pt
from
pytensor
import
function
,
ifelse
,
shared
from
pytensor
import
function
,
ifelse
,
shared
from
pytensor.compile
import
get_mode
from
pytensor.compile
import
get_mode
from
pytensor.configdefaults
import
config
from
pytensor.configdefaults
import
config
from
pytensor.graph
import
Apply
,
Op
from
pytensor.link.jax.dispatch.basic
import
jax_funcify
from
pytensor.scan
import
until
from
pytensor.scan
import
until
from
pytensor.scan.basic
import
scan
from
pytensor.scan.basic
import
scan
from
pytensor.scan.op
import
Scan
from
pytensor.scan.op
import
Scan
...
@@ -98,16 +100,26 @@ def test_scan_nit_sot(view):
...
@@ -98,16 +100,26 @@ def test_scan_nit_sot(view):
assert
len
(
scan_nodes
)
==
1
assert
len
(
scan_nodes
)
==
1
@pytest.mark.xfail
(
raises
=
NotImplementedError
)
def
test_scan_mit_mot
():
def
test_scan_mit_mot
():
xs
=
pt
.
vector
(
"xs"
,
shape
=
(
10
,))
def
step
(
xtm1
,
ytm3
,
ytm1
,
rho
):
ys
,
_
=
scan
(
return
(
xtm1
+
ytm1
)
*
rho
,
ytm3
*
(
1
-
rho
)
+
ytm1
*
rho
lambda
xtm2
,
xtm1
:
(
xtm2
+
xtm1
),
outputs_info
=
[{
"initial"
:
xs
,
"taps"
:
[
-
2
,
-
1
]}],
rho
=
pt
.
scalar
(
"rho"
,
dtype
=
"float64"
)
x0
=
pt
.
vector
(
"xs"
,
shape
=
(
2
,))
y0
=
pt
.
vector
(
"ys"
,
shape
=
(
3
,))
[
outs
,
_
],
_
=
scan
(
step
,
outputs_info
=
[
x0
,
{
"initial"
:
y0
,
"taps"
:
[
-
3
,
-
1
]}],
non_sequences
=
[
rho
],
n_steps
=
10
,
n_steps
=
10
,
)
)
grads_wrt_xs
=
pt
.
grad
(
ys
.
sum
(),
wrt
=
xs
)
grads
=
pt
.
grad
(
outs
.
sum
(),
wrt
=
[
x0
,
y0
,
rho
])
compare_jax_and_py
([
xs
],
[
grads_wrt_xs
],
[
np
.
arange
(
10
)])
compare_jax_and_py
(
[
x0
,
y0
,
rho
],
grads
,
[
np
.
arange
(
2
),
np
.
array
([
0.5
,
0.5
,
0.5
]),
np
.
array
(
0.95
)],
jax_mode
=
get_mode
(
"JAX"
),
)
def
test_scan_update
():
def
test_scan_update
():
...
@@ -323,13 +335,41 @@ def test_default_mode_excludes_incompatible_rewrites():
...
@@ -323,13 +335,41 @@ def test_default_mode_excludes_incompatible_rewrites():
def
test_dynamic_sequence_length
():
def
test_dynamic_sequence_length
():
x
=
pt
.
tensor
(
"x"
,
shape
=
(
None
,))
class
IncWithoutStaticShape
(
Op
):
out
,
_
=
scan
(
lambda
x
:
x
+
1
,
sequences
=
[
x
])
def
make_node
(
self
,
x
):
x
=
pt
.
as_tensor_variable
(
x
)
return
Apply
(
self
,
[
x
],
[
pt
.
tensor
(
shape
=
(
None
,)
*
x
.
type
.
ndim
)])
def
perform
(
self
,
node
,
inputs
,
outputs
):
outputs
[
0
][
0
]
=
inputs
[
0
]
+
1
@jax_funcify.register
(
IncWithoutStaticShape
)
def
_
(
op
,
**
kwargs
):
return
lambda
x
:
x
+
1
inc_without_static_shape
=
IncWithoutStaticShape
()
x
=
pt
.
tensor
(
"x"
,
shape
=
(
None
,
3
))
out
,
_
=
scan
(
lambda
x
:
inc_without_static_shape
(
x
),
outputs_info
=
[
None
],
sequences
=
[
x
]
)
f
=
function
([
x
],
out
,
mode
=
get_mode
(
"JAX"
)
.
excluding
(
"scan"
))
f
=
function
([
x
],
out
,
mode
=
get_mode
(
"JAX"
)
.
excluding
(
"scan"
))
assert
sum
(
isinstance
(
node
.
op
,
Scan
)
for
node
in
f
.
maker
.
fgraph
.
apply_nodes
)
==
1
assert
sum
(
isinstance
(
node
.
op
,
Scan
)
for
node
in
f
.
maker
.
fgraph
.
apply_nodes
)
==
1
np
.
testing
.
assert_allclose
(
f
([]),
[])
np
.
testing
.
assert_allclose
(
f
([[
1
,
2
,
3
]]),
np
.
array
([[
2
,
3
,
4
]]))
np
.
testing
.
assert_allclose
(
f
([
1
,
2
,
3
]),
np
.
array
([
2
,
3
,
4
]))
with
pytest
.
raises
(
ValueError
):
f
(
np
.
zeros
((
0
,
3
)))
# But should be fine with static shape
out2
,
_
=
scan
(
lambda
x
:
pt
.
specify_shape
(
inc_without_static_shape
(
x
),
x
.
shape
),
outputs_info
=
[
None
],
sequences
=
[
x
],
)
f2
=
function
([
x
],
out2
,
mode
=
get_mode
(
"JAX"
)
.
excluding
(
"scan"
))
np
.
testing
.
assert_allclose
(
f2
([[
1
,
2
,
3
]]),
np
.
array
([[
2
,
3
,
4
]]))
np
.
testing
.
assert_allclose
(
f2
(
np
.
zeros
((
0
,
3
))),
np
.
empty
((
0
,
3
)))
def
SEIR_model_logp
():
def
SEIR_model_logp
():
...
@@ -499,9 +539,6 @@ def cyclical_reduction():
...
@@ -499,9 +539,6 @@ def cyclical_reduction():
@pytest.mark.parametrize
(
"mode"
,
(
"0forward"
,
"1backward"
,
"2both"
))
@pytest.mark.parametrize
(
"mode"
,
(
"0forward"
,
"1backward"
,
"2both"
))
@pytest.mark.parametrize
(
"model"
,
[
cyclical_reduction
,
SEIR_model_logp
])
@pytest.mark.parametrize
(
"model"
,
[
cyclical_reduction
,
SEIR_model_logp
])
def
test_scan_benchmark
(
model
,
mode
,
gradient_backend
,
benchmark
):
def
test_scan_benchmark
(
model
,
mode
,
gradient_backend
,
benchmark
):
if
gradient_backend
==
"PYTENSOR"
and
mode
in
(
"1backward"
,
"2both"
):
pytest
.
skip
(
"PYTENSOR backend does not support backward mode yet"
)
model_dict
=
model
()
model_dict
=
model
()
graph_inputs
=
model_dict
[
"graph_inputs"
]
graph_inputs
=
model_dict
[
"graph_inputs"
]
differentiable_vars
=
model_dict
[
"differentiable_vars"
]
differentiable_vars
=
model_dict
[
"differentiable_vars"
]
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论