Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
97797975
提交
97797975
authored
10月 13, 2025
作者:
Ricardo Vieira
提交者:
Ricardo Vieira
10月 27, 2025
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Benchmark scan in JAX backend
Co-authored-by:
Jesse Grabowski
<
48652735+jessegrabowski@users.noreply.github.com
>
上级
545e58f7
隐藏空白字符变更
内嵌
并排
正在显示
1 个修改的文件
包含
239 行增加
和
92 行删除
+239
-92
test_scan.py
tests/link/jax/test_scan.py
+239
-92
没有找到文件。
tests/link/jax/test_scan.py
浏览文件 @
97797975
...
@@ -4,7 +4,7 @@ import numpy as np
...
@@ -4,7 +4,7 @@ import numpy as np
import
pytest
import
pytest
import
pytensor.tensor
as
pt
import
pytensor.tensor
as
pt
from
pytensor
import
function
,
shared
from
pytensor
import
function
,
ifelse
,
shared
from
pytensor.compile
import
get_mode
from
pytensor.compile
import
get_mode
from
pytensor.configdefaults
import
config
from
pytensor.configdefaults
import
config
from
pytensor.scan
import
until
from
pytensor.scan
import
until
...
@@ -12,7 +12,7 @@ from pytensor.scan.basic import scan
...
@@ -12,7 +12,7 @@ from pytensor.scan.basic import scan
from
pytensor.scan.op
import
Scan
from
pytensor.scan.op
import
Scan
from
pytensor.tensor
import
random
from
pytensor.tensor
import
random
from
pytensor.tensor.math
import
gammaln
,
log
from
pytensor.tensor.math
import
gammaln
,
log
from
pytensor.tensor.type
import
dmatrix
,
dvector
,
lscalar
,
matrix
,
scalar
,
vector
from
pytensor.tensor.type
import
dmatrix
,
dvector
,
matrix
,
scalar
,
vector
from
tests.link.jax.test_basic
import
compare_jax_and_py
from
tests.link.jax.test_basic
import
compare_jax_and_py
...
@@ -189,96 +189,6 @@ def test_scan_while():
...
@@ -189,96 +189,6 @@ def test_scan_while():
compare_jax_and_py
([],
[
xs
],
[])
compare_jax_and_py
([],
[
xs
],
[])
def test_scan_SEIR():
    """Test a scan implementation of a SEIR model.

    SEIR model definition:
    S[t+1] = S[t] - B[t]
    E[t+1] = E[t] + B[t] - C[t]
    I[t+1] = I[t] + C[t] - D[t]

    B[t] ~ Binom(S[t], beta)
    C[t] ~ Binom(E[t], gamma)
    D[t] ~ Binom(I[t], delta)
    """

    def binomln(n, k):
        # log of the binomial coefficient C(n, k), via log-gamma
        return gammaln(n + 1) - gammaln(k + 1) - gammaln(n - k + 1)

    def binom_log_prob(n, p, value):
        # log-probability of `value` under Binomial(n, p)
        return binomln(n, value) + value * log(p) + (n - value) * log(1 - p)

    # sequences (observed counts per time step)
    at_C = vector("C_t", dtype="int32", shape=(8,))
    at_D = vector("D_t", dtype="int32", shape=(8,))
    # outputs_info (initial conditions)
    st0 = lscalar("s_t0")
    et0 = lscalar("e_t0")
    it0 = lscalar("i_t0")
    logp_c = scalar("logp_c")
    logp_d = scalar("logp_d")
    # non_sequences (model parameters)
    beta = scalar("beta")
    gamma = scalar("gamma")
    delta = scalar("delta")

    # TODO: Use random streams when their JAX conversions are implemented.
    # trng = pytensor.tensor.random.RandomStream(1234)

    def seir_one_step(ct0, dt0, st0, et0, it0, logp_c, logp_d, beta, gamma, delta):
        # bt0 = trng.binomial(n=st0, p=beta)
        # Deterministic stand-in for the binomial draw (see TODO above)
        bt0 = st0 * beta
        bt0 = bt0.astype(st0.dtype)

        logp_c1 = binom_log_prob(et0, gamma, ct0).astype(logp_c.dtype)
        logp_d1 = binom_log_prob(it0, delta, dt0).astype(logp_d.dtype)

        st1 = st0 - bt0
        et1 = et0 + bt0 - ct0
        it1 = it0 + ct0 - dt0
        return st1, et1, it1, logp_c1, logp_d1

    (st, et, it, logp_c_all, logp_d_all), _ = scan(
        fn=seir_one_step,
        sequences=[at_C, at_D],
        outputs_info=[st0, et0, it0, logp_c, logp_d],
        non_sequences=[beta, gamma, delta],
    )
    st.name = "S_t"
    et.name = "E_t"
    it.name = "I_t"
    logp_c_all.name = "C_t_logp"
    logp_d_all.name = "D_t_logp"

    s0, e0, i0 = 100, 50, 25
    logp_c0 = np.array(0.0, dtype=config.floatX)
    logp_d0 = np.array(0.0, dtype=config.floatX)
    beta_val, gamma_val, delta_val = (
        np.array(val, dtype=config.floatX) for val in [0.277792, 0.135330, 0.108753]
    )
    C = np.array([3, 5, 8, 13, 21, 26, 10, 3], dtype=np.int32)
    D = np.array([1, 2, 3, 7, 9, 11, 5, 1], dtype=np.int32)

    test_input_vals = [
        C,
        D,
        s0,
        e0,
        i0,
        logp_c0,
        logp_d0,
        beta_val,
        gamma_val,
        delta_val,
    ]
    compare_jax_and_py(
        [at_C, at_D, st0, et0, it0, logp_c, logp_d, beta, gamma, delta],
        [st, et, it, logp_c_all, logp_d_all],
        test_input_vals,
        jax_mode="JAX",
    )
def
test_scan_mitsot_with_nonseq
():
def
test_scan_mitsot_with_nonseq
():
a_pt
=
scalar
(
"a"
)
a_pt
=
scalar
(
"a"
)
...
@@ -420,3 +330,240 @@ def test_dynamic_sequence_length():
...
@@ -420,3 +330,240 @@ def test_dynamic_sequence_length():
assert
sum
(
isinstance
(
node
.
op
,
Scan
)
for
node
in
f
.
maker
.
fgraph
.
apply_nodes
)
==
1
assert
sum
(
isinstance
(
node
.
op
,
Scan
)
for
node
in
f
.
maker
.
fgraph
.
apply_nodes
)
==
1
np
.
testing
.
assert_allclose
(
f
([]),
[])
np
.
testing
.
assert_allclose
(
f
([]),
[])
np
.
testing
.
assert_allclose
(
f
([
1
,
2
,
3
]),
np
.
array
([
2
,
3
,
4
]))
np
.
testing
.
assert_allclose
(
f
([
1
,
2
,
3
]),
np
.
array
([
2
,
3
,
4
]))
def SEIR_model_logp():
    """Setup a Scan implementation of a SEIR model.

    SEIR model definition:
    S[t+1] = S[t] - B[t]
    E[t+1] = E[t] + B[t] - C[t]
    I[t+1] = I[t] + C[t] - D[t]

    B[t] ~ Binom(S[t], beta)
    C[t] ~ Binom(E[t], gamma)
    D[t] ~ Binom(I[t], delta)

    Returns a dict with ``graph_inputs``, ``differentiable_vars``,
    ``test_input_vals`` and ``loss_graph`` for use in benchmarks.
    """

    def binomln(n, k):
        # log of the binomial coefficient C(n, k), via log-gamma
        return gammaln(n + 1) - gammaln(k + 1) - gammaln(n - k + 1)

    def binom_log_prob(n, p, value):
        # log-probability of `value` under Binomial(n, p)
        return binomln(n, value) + value * log(p) + (n - value) * log(1 - p)

    # sequences (observed counts per time step; long so the scan dominates runtime)
    C_t = vector("C_t", dtype="int32", shape=(1200,))
    D_t = vector("D_t", dtype="int32", shape=(1200,))
    # outputs_info (initial conditions)
    st0 = scalar("s_t0")
    et0 = scalar("e_t0")
    it0 = scalar("i_t0")
    # non_sequences (model parameters)
    beta = scalar("beta")
    gamma = scalar("gamma")
    delta = scalar("delta")

    def seir_one_step(ct0, dt0, st0, et0, it0, beta, gamma, delta):
        # bt0 = trng.binomial(n=st0, p=beta)
        # Deterministic stand-in for the binomial draw
        bt0 = st0 * beta
        bt0 = bt0.astype(st0.dtype)

        logp_c1 = binom_log_prob(et0, gamma, ct0)
        logp_d1 = binom_log_prob(it0, delta, dt0)

        st1 = st0 - bt0
        et1 = et0 + bt0 - ct0
        it1 = it0 + ct0 - dt0
        return st1, et1, it1, logp_c1, logp_d1

    # logp outputs use outputs_info=None: they are plain (non-recurrent) outputs
    (st, et, it, logp_c_all, logp_d_all), _ = scan(
        fn=seir_one_step,
        sequences=[C_t, D_t],
        outputs_info=[st0, et0, it0, None, None],
        non_sequences=[beta, gamma, delta],
    )
    st.name = "S_t"
    et.name = "E_t"
    it.name = "I_t"
    logp_c_all.name = "C_t_logp"
    logp_d_all.name = "D_t_logp"

    st0_val, et0_val, it0_val = np.array(100.0), np.array(50.0), np.array(25.0)
    beta_val, gamma_val, delta_val = (
        np.array(0.277792),
        np.array(0.135330),
        np.array(0.108753),
    )
    C_t_val = np.array([3, 5, 8, 13, 21, 26, 10, 3] * 150, dtype=np.int32)
    D_t_val = np.array([1, 2, 3, 7, 9, 11, 5, 1] * 150, dtype=np.int32)
    assert C_t_val.shape == D_t_val.shape == C_t.type.shape == D_t.type.shape

    test_input_vals = [
        C_t_val,
        D_t_val,
        st0_val,
        et0_val,
        it0_val,
        beta_val,
        gamma_val,
        delta_val,
    ]

    loss_graph = logp_c_all.sum() + logp_d_all.sum()

    return dict(
        graph_inputs=[C_t, D_t, st0, et0, it0, beta, gamma, delta],
        differentiable_vars=[st0, et0, it0, beta, gamma, delta],
        test_input_vals=test_input_vals,
        loss_graph=loss_graph,
    )
def cyclical_reduction():
    """Setup a Scan implementation of the cyclical reduction algorithm.

    This solves the matrix equation A @ X @ X + B @ X + C = 0 for X

    Adapted from https://github.com/jessegrabowski/gEconpy/blob/da495b22ac383cb6cb5dec15f305506aebef7302/gEconpy/solvers/cycle_reduction.py#L187

    Returns a dict with ``graph_inputs``, ``differentiable_vars``,
    ``test_input_vals`` and ``loss_graph`` for use in benchmarks.
    """

    def stabilize(x, jitter=1e-16):
        # Add a tiny diagonal jitter so the subsequent solve stays non-singular
        return x + jitter * pt.eye(x.shape[0])

    def step(A0, A1, A2, A1_hat, norm, step_num, tol):
        def cycle_step(A0, A1, A2, A1_hat, _norm, step_num):
            # One cycle-reduction iteration:
            # tmp = [A0; A2] @ A1^{-1} @ [A0, A2]  (block solve against A1)
            tmp = pt.dot(
                pt.vertical_stack(A0, A2),
                pt.linalg.solve(
                    stabilize(A1),
                    pt.horizontal_stack(A0, A2),
                    assume_a="gen",
                    check_finite=False,
                ),
            )

            # Index ranges selecting the four n x n blocks of tmp
            n = A0.shape[0]
            idx_0 = pt.arange(n)
            idx_1 = idx_0 + n

            A1 = A1 - tmp[idx_0, :][:, idx_1] - tmp[idx_1, :][:, idx_0]
            A0 = -tmp[idx_0, :][:, idx_0]
            A2 = -tmp[idx_1, :][:, idx_1]
            A1_hat = A1_hat - tmp[idx_1, :][:, idx_0]

            # Convergence measure: L1 norm of the updated A0
            A0_L1_norm = pt.linalg.norm(A0, ord=1)
            return A0, A1, A2, A1_hat, A0_L1_norm, step_num + 1

        # Once converged (norm < tol), keep passing the state through unchanged
        return ifelse(
            norm < tol,
            (A0, A1, A2, A1_hat, norm, step_num),
            cycle_step(A0, A1, A2, A1_hat, norm, step_num),
        )

    A = pt.matrix("A", shape=(20, 20))
    B = pt.matrix("B", shape=(20, 20))
    C = pt.matrix("C", shape=(20, 20))

    # Initial norm is large so the first iteration always runs
    norm = np.array(1e9, dtype="float64")
    step_num = pt.zeros((), dtype="int32")
    max_iter = 100
    tol = 1e-7

    (*_, A1_hat, norm, _n_steps), _ = scan(
        step,
        outputs_info=[A, B, C, B, norm, step_num],
        non_sequences=[tol],
        n_steps=max_iter,
    )
    # Only the final iterate is needed
    A1_hat = A1_hat[-1]

    # Solution of the quadratic matrix equation
    T = -pt.linalg.solve(stabilize(A1_hat), A, assume_a="gen", check_finite=False)

    rng = np.random.default_rng(sum(map(ord, "cycle_reduction")))
    n = A.type.shape[0]
    A_test = rng.standard_normal(size=(n, n))
    C_test = rng.standard_normal(size=(n, n))
    # B must be invertible, so we make it symmetric positive-definite
    B_rand = rng.standard_normal(size=(n, n))
    B_test = B_rand @ B_rand.T + np.eye(n) * 1e-3

    return dict(
        graph_inputs=[A, B, C],
        differentiable_vars=[A, B, C],
        test_input_vals=[A_test, B_test, C_test],
        loss_graph=pt.sum(T),
    )
@pytest.mark.parametrize
(
"gradient_backend"
,
[
"PYTENSOR"
,
"JAX"
])
@pytest.mark.parametrize
(
"mode"
,
(
"0forward"
,
"1backward"
,
"2both"
))
@pytest.mark.parametrize
(
"model"
,
[
cyclical_reduction
,
SEIR_model_logp
])
def
test_scan_benchmark
(
model
,
mode
,
gradient_backend
,
benchmark
):
if
gradient_backend
==
"PYTENSOR"
and
mode
in
(
"1backward"
,
"2both"
):
pytest
.
skip
(
"PYTENSOR backend does not support backward mode yet"
)
model_dict
=
model
()
graph_inputs
=
model_dict
[
"graph_inputs"
]
differentiable_vars
=
model_dict
[
"differentiable_vars"
]
loss_graph
=
model_dict
[
"loss_graph"
]
test_input_vals
=
model_dict
[
"test_input_vals"
]
if
gradient_backend
==
"PYTENSOR"
:
backward_loss
=
pt
.
grad
(
loss_graph
,
wrt
=
differentiable_vars
,
)
match
mode
:
# TODO: Restore original test separately
case
"0forward"
:
graph_outputs
=
[
loss_graph
]
case
"1backward"
:
graph_outputs
=
backward_loss
case
"2both"
:
graph_outputs
=
[
loss_graph
,
*
backward_loss
]
case
_
:
raise
ValueError
(
f
"Unknown mode: {mode}"
)
jax_fn
,
_
=
compare_jax_and_py
(
graph_inputs
,
graph_outputs
,
test_input_vals
,
jax_mode
=
"JAX"
,
)
jax_fn
.
trust_input
=
True
else
:
# gradient_backend == "JAX"
import
jax
loss_fn_tuple
=
function
(
graph_inputs
,
loss_graph
,
mode
=
"JAX"
)
.
vm
.
jit_fn
def
loss_fn
(
*
args
):
return
loss_fn_tuple
(
*
args
)[
0
]
match
mode
:
case
"0forward"
:
jax_fn
=
jax
.
jit
(
loss_fn_tuple
)
case
"1backward"
:
jax_fn
=
jax
.
jit
(
jax
.
grad
(
loss_fn
,
argnums
=
tuple
(
range
(
len
(
graph_inputs
))[
2
:]))
)
case
"2both"
:
value_and_grad_fn
=
jax
.
value_and_grad
(
loss_fn
,
argnums
=
tuple
(
range
(
len
(
graph_inputs
))[
2
:])
)
@jax.jit
def
jax_fn
(
*
args
):
loss
,
grads
=
value_and_grad_fn
(
*
args
)
return
loss
,
*
grads
case
_
:
raise
ValueError
(
f
"Unknown mode: {mode}"
)
def
block_until_ready
(
*
inputs
,
jax_fn
=
jax_fn
):
return
[
o
.
block_until_ready
()
for
o
in
jax_fn
(
*
inputs
)]
block_until_ready
(
*
test_input_vals
)
# Warmup
benchmark
.
pedantic
(
block_until_ready
,
test_input_vals
,
rounds
=
200
,
iterations
=
1
)
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论