Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
abedb7fb
提交
abedb7fb
authored
10月 31, 2025
作者:
ricardoV94
提交者:
Ricardo Vieira
11月 08, 2025
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Start using new API in tests that don't involve shared updates
上级
78293400
全部展开
隐藏空白字符变更
内嵌
并排
正在显示
6 个修改的文件
包含
73 行增加
和
35 行删除
+73
-35
test_scan.py
tests/link/jax/test_scan.py
+40
-20
test_scan.py
tests/link/numba/test_scan.py
+29
-13
test_basic.py
tests/scan/test_basic.py
+0
-0
test_rewriting.py
tests/scan/test_rewriting.py
+0
-0
test_rewriting.py
tests/tensor/linalg/test_rewriting.py
+2
-1
test_blockwise.py
tests/tensor/test_blockwise.py
+2
-1
没有找到文件。
tests/link/jax/test_scan.py
浏览文件 @
abedb7fb
...
@@ -23,10 +23,11 @@ jax = pytest.importorskip("jax")
...
@@ -23,10 +23,11 @@ jax = pytest.importorskip("jax")
@pytest.mark.parametrize
(
"view"
,
[
None
,
(
-
1
,),
slice
(
-
2
,
None
,
None
)])
@pytest.mark.parametrize
(
"view"
,
[
None
,
(
-
1
,),
slice
(
-
2
,
None
,
None
)])
def
test_scan_sit_sot
(
view
):
def
test_scan_sit_sot
(
view
):
x0
=
pt
.
scalar
(
"x0"
,
dtype
=
"float64"
)
x0
=
pt
.
scalar
(
"x0"
,
dtype
=
"float64"
)
xs
,
_
=
scan
(
xs
=
scan
(
lambda
xtm1
:
xtm1
+
1
,
lambda
xtm1
:
xtm1
+
1
,
outputs_info
=
[
x0
],
outputs_info
=
[
x0
],
n_steps
=
10
,
n_steps
=
10
,
return_updates
=
False
,
)
)
if
view
:
if
view
:
xs
=
xs
[
view
]
xs
=
xs
[
view
]
...
@@ -37,10 +38,11 @@ def test_scan_sit_sot(view):
...
@@ -37,10 +38,11 @@ def test_scan_sit_sot(view):
@pytest.mark.parametrize
(
"view"
,
[
None
,
(
-
1
,),
slice
(
-
4
,
-
1
,
None
)])
@pytest.mark.parametrize
(
"view"
,
[
None
,
(
-
1
,),
slice
(
-
4
,
-
1
,
None
)])
def
test_scan_mit_sot
(
view
):
def
test_scan_mit_sot
(
view
):
x0
=
pt
.
vector
(
"x0"
,
dtype
=
"float64"
,
shape
=
(
3
,))
x0
=
pt
.
vector
(
"x0"
,
dtype
=
"float64"
,
shape
=
(
3
,))
xs
,
_
=
scan
(
xs
=
scan
(
lambda
xtm3
,
xtm1
:
xtm3
+
xtm1
+
1
,
lambda
xtm3
,
xtm1
:
xtm3
+
xtm1
+
1
,
outputs_info
=
[{
"initial"
:
x0
,
"taps"
:
[
-
3
,
-
1
]}],
outputs_info
=
[{
"initial"
:
x0
,
"taps"
:
[
-
3
,
-
1
]}],
n_steps
=
10
,
n_steps
=
10
,
return_updates
=
False
,
)
)
if
view
:
if
view
:
xs
=
xs
[
view
]
xs
=
xs
[
view
]
...
@@ -57,13 +59,14 @@ def test_scan_multiple_mit_sot(view_x, view_y):
...
@@ -57,13 +59,14 @@ def test_scan_multiple_mit_sot(view_x, view_y):
def
step
(
xtm3
,
xtm1
,
ytm4
,
ytm2
):
def
step
(
xtm3
,
xtm1
,
ytm4
,
ytm2
):
return
xtm3
+
ytm4
+
1
,
xtm1
+
ytm2
+
2
return
xtm3
+
ytm4
+
1
,
xtm1
+
ytm2
+
2
[
xs
,
ys
]
,
_
=
scan
(
[
xs
,
ys
]
=
scan
(
fn
=
step
,
fn
=
step
,
outputs_info
=
[
outputs_info
=
[
{
"initial"
:
x0
,
"taps"
:
[
-
3
,
-
1
]},
{
"initial"
:
x0
,
"taps"
:
[
-
3
,
-
1
]},
{
"initial"
:
y0
,
"taps"
:
[
-
4
,
-
2
]},
{
"initial"
:
y0
,
"taps"
:
[
-
4
,
-
2
]},
],
],
n_steps
=
10
,
n_steps
=
10
,
return_updates
=
False
,
)
)
if
view_x
:
if
view_x
:
xs
=
xs
[
view_x
]
xs
=
xs
[
view_x
]
...
@@ -80,10 +83,8 @@ def test_scan_nit_sot(view):
...
@@ -80,10 +83,8 @@ def test_scan_nit_sot(view):
xs
=
pt
.
vector
(
"x0"
,
dtype
=
"float64"
,
shape
=
(
10
,))
xs
=
pt
.
vector
(
"x0"
,
dtype
=
"float64"
,
shape
=
(
10
,))
ys
,
_
=
scan
(
ys
=
scan
(
lambda
x
:
pt
.
exp
(
x
),
lambda
x
:
pt
.
exp
(
x
),
outputs_info
=
[
None
],
sequences
=
[
xs
],
return_updates
=
False
outputs_info
=
[
None
],
sequences
=
[
xs
],
)
)
if
view
:
if
view
:
ys
=
ys
[
view
]
ys
=
ys
[
view
]
...
@@ -106,11 +107,12 @@ def test_scan_mit_mot():
...
@@ -106,11 +107,12 @@ def test_scan_mit_mot():
rho
=
pt
.
scalar
(
"rho"
,
dtype
=
"float64"
)
rho
=
pt
.
scalar
(
"rho"
,
dtype
=
"float64"
)
x0
=
pt
.
vector
(
"xs"
,
shape
=
(
2
,))
x0
=
pt
.
vector
(
"xs"
,
shape
=
(
2
,))
y0
=
pt
.
vector
(
"ys"
,
shape
=
(
3
,))
y0
=
pt
.
vector
(
"ys"
,
shape
=
(
3
,))
[
outs
,
_
]
,
_
=
scan
(
[
outs
,
_
]
=
scan
(
step
,
step
,
outputs_info
=
[
x0
,
{
"initial"
:
y0
,
"taps"
:
[
-
3
,
-
1
]}],
outputs_info
=
[
x0
,
{
"initial"
:
y0
,
"taps"
:
[
-
3
,
-
1
]}],
non_sequences
=
[
rho
],
non_sequences
=
[
rho
],
n_steps
=
10
,
n_steps
=
10
,
return_updates
=
False
,
)
)
grads
=
pt
.
grad
(
outs
.
sum
(),
wrt
=
[
x0
,
y0
,
rho
])
grads
=
pt
.
grad
(
outs
.
sum
(),
wrt
=
[
x0
,
y0
,
rho
])
compare_jax_and_py
(
compare_jax_and_py
(
...
@@ -191,10 +193,11 @@ def test_scan_rng_update():
...
@@ -191,10 +193,11 @@ def test_scan_rng_update():
@pytest.mark.xfail
(
raises
=
NotImplementedError
)
@pytest.mark.xfail
(
raises
=
NotImplementedError
)
def
test_scan_while
():
def
test_scan_while
():
xs
,
_
=
scan
(
xs
=
scan
(
lambda
x
:
(
x
+
1
,
until
(
x
<
10
)),
lambda
x
:
(
x
+
1
,
until
(
x
<
10
)),
outputs_info
=
[
pt
.
zeros
(())],
outputs_info
=
[
pt
.
zeros
(())],
n_steps
=
100
,
n_steps
=
100
,
return_updates
=
False
,
)
)
compare_jax_and_py
([],
[
xs
],
[])
compare_jax_and_py
([],
[
xs
],
[])
...
@@ -210,7 +213,7 @@ def test_scan_mitsot_with_nonseq():
...
@@ -210,7 +213,7 @@ def test_scan_mitsot_with_nonseq():
res
.
name
=
"y_t"
res
.
name
=
"y_t"
return
res
return
res
y_scan_pt
,
_
=
scan
(
y_scan_pt
=
scan
(
fn
=
input_step_fn
,
fn
=
input_step_fn
,
outputs_info
=
[
outputs_info
=
[
{
{
...
@@ -223,6 +226,7 @@ def test_scan_mitsot_with_nonseq():
...
@@ -223,6 +226,7 @@ def test_scan_mitsot_with_nonseq():
non_sequences
=
[
a_pt
],
non_sequences
=
[
a_pt
],
n_steps
=
10
,
n_steps
=
10
,
name
=
"y_scan"
,
name
=
"y_scan"
,
return_updates
=
False
,
)
)
y_scan_pt
.
name
=
"y"
y_scan_pt
.
name
=
"y"
y_scan_pt
.
owner
.
inputs
[
0
]
.
name
=
"y_all"
y_scan_pt
.
owner
.
inputs
[
0
]
.
name
=
"y_all"
...
@@ -241,11 +245,12 @@ def test_nd_scan_sit_sot(x0_func, A_func):
...
@@ -241,11 +245,12 @@ def test_nd_scan_sit_sot(x0_func, A_func):
k
=
3
k
=
3
# Must specify mode = JAX for the inner func to avoid a GEMM Op in the JAX graph
# Must specify mode = JAX for the inner func to avoid a GEMM Op in the JAX graph
xs
,
_
=
scan
(
xs
=
scan
(
lambda
X
,
A
:
A
@
X
,
lambda
X
,
A
:
A
@
X
,
non_sequences
=
[
A
],
non_sequences
=
[
A
],
outputs_info
=
[
x0
],
outputs_info
=
[
x0
],
n_steps
=
n_steps
,
n_steps
=
n_steps
,
return_updates
=
False
,
)
)
x0_val
=
(
x0_val
=
(
...
@@ -267,11 +272,12 @@ def test_nd_scan_sit_sot_with_seq():
...
@@ -267,11 +272,12 @@ def test_nd_scan_sit_sot_with_seq():
A
=
pt
.
matrix
(
"A"
,
shape
=
(
k
,
k
))
A
=
pt
.
matrix
(
"A"
,
shape
=
(
k
,
k
))
# Must specify mode = JAX for the inner func to avoid a GEMM Op in the JAX graph
# Must specify mode = JAX for the inner func to avoid a GEMM Op in the JAX graph
xs
,
_
=
scan
(
xs
=
scan
(
lambda
X
,
A
:
A
@
X
,
lambda
X
,
A
:
A
@
X
,
non_sequences
=
[
A
],
non_sequences
=
[
A
],
sequences
=
[
x
],
sequences
=
[
x
],
n_steps
=
n_steps
,
n_steps
=
n_steps
,
return_updates
=
False
,
)
)
x_val
=
np
.
arange
(
n_steps
*
k
,
dtype
=
config
.
floatX
)
.
reshape
(
n_steps
,
k
)
x_val
=
np
.
arange
(
n_steps
*
k
,
dtype
=
config
.
floatX
)
.
reshape
(
n_steps
,
k
)
...
@@ -287,11 +293,12 @@ def test_nd_scan_mit_sot():
...
@@ -287,11 +293,12 @@ def test_nd_scan_mit_sot():
B
=
pt
.
matrix
(
"B"
,
shape
=
(
3
,
3
))
B
=
pt
.
matrix
(
"B"
,
shape
=
(
3
,
3
))
# Must specify mode = JAX for the inner func to avoid a GEMM Op in the JAX graph
# Must specify mode = JAX for the inner func to avoid a GEMM Op in the JAX graph
xs
,
_
=
scan
(
xs
=
scan
(
lambda
xtm3
,
xtm1
,
A
,
B
:
A
@
xtm3
+
B
@
xtm1
,
lambda
xtm3
,
xtm1
,
A
,
B
:
A
@
xtm3
+
B
@
xtm1
,
outputs_info
=
[{
"initial"
:
x0
,
"taps"
:
[
-
3
,
-
1
]}],
outputs_info
=
[{
"initial"
:
x0
,
"taps"
:
[
-
3
,
-
1
]}],
non_sequences
=
[
A
,
B
],
non_sequences
=
[
A
,
B
],
n_steps
=
10
,
n_steps
=
10
,
return_updates
=
False
,
)
)
x0_val
=
np
.
arange
(
9
,
dtype
=
config
.
floatX
)
.
reshape
(
3
,
3
)
x0_val
=
np
.
arange
(
9
,
dtype
=
config
.
floatX
)
.
reshape
(
3
,
3
)
...
@@ -310,12 +317,13 @@ def test_nd_scan_sit_sot_with_carry():
...
@@ -310,12 +317,13 @@ def test_nd_scan_sit_sot_with_carry():
return
A
@
x
,
x
.
sum
()
return
A
@
x
,
x
.
sum
()
# Must specify mode = JAX for the inner func to avoid a GEMM Op in the JAX graph
# Must specify mode = JAX for the inner func to avoid a GEMM Op in the JAX graph
xs
,
_
=
scan
(
xs
=
scan
(
step
,
step
,
outputs_info
=
[
x0
,
None
],
outputs_info
=
[
x0
,
None
],
non_sequences
=
[
A
],
non_sequences
=
[
A
],
n_steps
=
10
,
n_steps
=
10
,
mode
=
get_mode
(
"JAX"
),
mode
=
get_mode
(
"JAX"
),
return_updates
=
False
,
)
)
x0_val
=
np
.
arange
(
3
,
dtype
=
config
.
floatX
)
x0_val
=
np
.
arange
(
3
,
dtype
=
config
.
floatX
)
...
@@ -329,7 +337,13 @@ def test_default_mode_excludes_incompatible_rewrites():
...
@@ -329,7 +337,13 @@ def test_default_mode_excludes_incompatible_rewrites():
# See issue #426
# See issue #426
A
=
matrix
(
"A"
)
A
=
matrix
(
"A"
)
B
=
matrix
(
"B"
)
B
=
matrix
(
"B"
)
out
,
_
=
scan
(
lambda
a
,
b
:
a
@
b
,
outputs_info
=
[
A
],
non_sequences
=
[
B
],
n_steps
=
2
)
out
=
scan
(
lambda
a
,
b
:
a
@
b
,
outputs_info
=
[
A
],
non_sequences
=
[
B
],
n_steps
=
2
,
return_updates
=
False
,
)
compare_jax_and_py
([
A
,
B
],
[
out
],
[
np
.
eye
(
3
),
np
.
eye
(
3
)],
jax_mode
=
"JAX"
)
compare_jax_and_py
([
A
,
B
],
[
out
],
[
np
.
eye
(
3
),
np
.
eye
(
3
)],
jax_mode
=
"JAX"
)
...
@@ -353,8 +367,11 @@ def test_dynamic_sequence_length():
...
@@ -353,8 +367,11 @@ def test_dynamic_sequence_length():
x
=
pt
.
tensor
(
"x"
,
shape
=
(
None
,
3
))
x
=
pt
.
tensor
(
"x"
,
shape
=
(
None
,
3
))
out
,
_
=
scan
(
out
=
scan
(
lambda
x
:
inc_without_static_shape
(
x
),
outputs_info
=
[
None
],
sequences
=
[
x
]
lambda
x
:
inc_without_static_shape
(
x
),
outputs_info
=
[
None
],
sequences
=
[
x
],
return_updates
=
False
,
)
)
f
=
function
([
x
],
out
,
mode
=
get_mode
(
"JAX"
)
.
excluding
(
"scan"
))
f
=
function
([
x
],
out
,
mode
=
get_mode
(
"JAX"
)
.
excluding
(
"scan"
))
assert
sum
(
isinstance
(
node
.
op
,
Scan
)
for
node
in
f
.
maker
.
fgraph
.
apply_nodes
)
==
1
assert
sum
(
isinstance
(
node
.
op
,
Scan
)
for
node
in
f
.
maker
.
fgraph
.
apply_nodes
)
==
1
...
@@ -364,10 +381,11 @@ def test_dynamic_sequence_length():
...
@@ -364,10 +381,11 @@ def test_dynamic_sequence_length():
np
.
testing
.
assert_allclose
(
f
(
np
.
zeros
((
0
,
3
))),
np
.
empty
((
0
,
3
)))
np
.
testing
.
assert_allclose
(
f
(
np
.
zeros
((
0
,
3
))),
np
.
empty
((
0
,
3
)))
# With known static shape we should always manage, regardless of the internal implementation
# With known static shape we should always manage, regardless of the internal implementation
out2
,
_
=
scan
(
out2
=
scan
(
lambda
x
:
pt
.
specify_shape
(
inc_without_static_shape
(
x
),
x
.
shape
),
lambda
x
:
pt
.
specify_shape
(
inc_without_static_shape
(
x
),
x
.
shape
),
outputs_info
=
[
None
],
outputs_info
=
[
None
],
sequences
=
[
x
],
sequences
=
[
x
],
return_updates
=
False
,
)
)
f2
=
function
([
x
],
out2
,
mode
=
get_mode
(
"JAX"
)
.
excluding
(
"scan"
))
f2
=
function
([
x
],
out2
,
mode
=
get_mode
(
"JAX"
)
.
excluding
(
"scan"
))
np
.
testing
.
assert_allclose
(
f2
([[
1
,
2
,
3
]]),
np
.
array
([[
2
,
3
,
4
]]))
np
.
testing
.
assert_allclose
(
f2
([[
1
,
2
,
3
]]),
np
.
array
([[
2
,
3
,
4
]]))
...
@@ -418,11 +436,12 @@ def SEIR_model_logp():
...
@@ -418,11 +436,12 @@ def SEIR_model_logp():
it1
=
it0
+
ct0
-
dt0
it1
=
it0
+
ct0
-
dt0
return
st1
,
et1
,
it1
,
logp_c1
,
logp_d1
return
st1
,
et1
,
it1
,
logp_c1
,
logp_d1
(
st
,
et
,
it
,
logp_c_all
,
logp_d_all
)
,
_
=
scan
(
(
st
,
et
,
it
,
logp_c_all
,
logp_d_all
)
=
scan
(
fn
=
seir_one_step
,
fn
=
seir_one_step
,
sequences
=
[
C_t
,
D_t
],
sequences
=
[
C_t
,
D_t
],
outputs_info
=
[
st0
,
et0
,
it0
,
None
,
None
],
outputs_info
=
[
st0
,
et0
,
it0
,
None
,
None
],
non_sequences
=
[
beta
,
gamma
,
delta
],
non_sequences
=
[
beta
,
gamma
,
delta
],
return_updates
=
False
,
)
)
st
.
name
=
"S_t"
st
.
name
=
"S_t"
et
.
name
=
"E_t"
et
.
name
=
"E_t"
...
@@ -511,11 +530,12 @@ def cyclical_reduction():
...
@@ -511,11 +530,12 @@ def cyclical_reduction():
max_iter
=
100
max_iter
=
100
tol
=
1e-7
tol
=
1e-7
(
*
_
,
A1_hat
,
norm
,
_n_steps
)
,
_
=
scan
(
(
*
_
,
A1_hat
,
norm
,
_n_steps
)
=
scan
(
step
,
step
,
outputs_info
=
[
A
,
B
,
C
,
B
,
norm
,
step_num
],
outputs_info
=
[
A
,
B
,
C
,
B
,
norm
,
step_num
],
non_sequences
=
[
tol
],
non_sequences
=
[
tol
],
n_steps
=
max_iter
,
n_steps
=
max_iter
,
return_updates
=
False
,
)
)
A1_hat
=
A1_hat
[
-
1
]
A1_hat
=
A1_hat
[
-
1
]
...
...
tests/link/numba/test_scan.py
浏览文件 @
abedb7fb
...
@@ -206,11 +206,12 @@ def test_scan_multiple_output(benchmark):
...
@@ -206,11 +206,12 @@ def test_scan_multiple_output(benchmark):
it1
=
it0
+
ct0
-
dt0
it1
=
it0
+
ct0
-
dt0
return
st1
,
et1
,
it1
,
logp_c1
,
logp_d1
return
st1
,
et1
,
it1
,
logp_c1
,
logp_d1
(
st
,
et
,
it
,
logp_c_all
,
logp_d_all
)
,
_
=
scan
(
(
st
,
et
,
it
,
logp_c_all
,
logp_d_all
)
=
scan
(
fn
=
seir_one_step
,
fn
=
seir_one_step
,
sequences
=
[
pt_C
,
pt_D
],
sequences
=
[
pt_C
,
pt_D
],
outputs_info
=
[
st0
,
et0
,
it0
,
logp_c
,
logp_d
],
outputs_info
=
[
st0
,
et0
,
it0
,
logp_c
,
logp_d
],
non_sequences
=
[
beta
,
gamma
,
delta
],
non_sequences
=
[
beta
,
gamma
,
delta
],
return_updates
=
False
,
)
)
st
.
name
=
"S_t"
st
.
name
=
"S_t"
et
.
name
=
"E_t"
et
.
name
=
"E_t"
...
@@ -268,7 +269,7 @@ def test_scan_tap_output():
...
@@ -268,7 +269,7 @@ def test_scan_tap_output():
y_t
.
name
=
"y_t"
y_t
.
name
=
"y_t"
return
x_t
,
y_t
,
pt
.
fill
((
10
,),
z_t
)
return
x_t
,
y_t
,
pt
.
fill
((
10
,),
z_t
)
scan_res
,
_
=
scan
(
scan_res
=
scan
(
fn
=
input_step_fn
,
fn
=
input_step_fn
,
sequences
=
[
sequences
=
[
{
{
...
@@ -297,6 +298,7 @@ def test_scan_tap_output():
...
@@ -297,6 +298,7 @@ def test_scan_tap_output():
n_steps
=
5
,
n_steps
=
5
,
name
=
"yz_scan"
,
name
=
"yz_scan"
,
strict
=
True
,
strict
=
True
,
return_updates
=
False
,
)
)
test_input_vals
=
[
test_input_vals
=
[
...
@@ -312,11 +314,12 @@ def test_scan_while():
...
@@ -312,11 +314,12 @@ def test_scan_while():
return
previous_power
*
2
,
until
(
previous_power
*
2
>
max_value
)
return
previous_power
*
2
,
until
(
previous_power
*
2
>
max_value
)
max_value
=
pt
.
scalar
()
max_value
=
pt
.
scalar
()
values
,
_
=
scan
(
values
=
scan
(
power_of_2
,
power_of_2
,
outputs_info
=
pt
.
constant
(
1.0
),
outputs_info
=
pt
.
constant
(
1.0
),
non_sequences
=
max_value
,
non_sequences
=
max_value
,
n_steps
=
1024
,
n_steps
=
1024
,
return_updates
=
False
,
)
)
test_input_vals
=
[
test_input_vals
=
[
...
@@ -331,11 +334,12 @@ def test_scan_multiple_none_output():
...
@@ -331,11 +334,12 @@ def test_scan_multiple_none_output():
def
power_step
(
prior_result
,
x
):
def
power_step
(
prior_result
,
x
):
return
prior_result
*
x
,
prior_result
*
x
*
x
,
prior_result
*
x
*
x
*
x
return
prior_result
*
x
,
prior_result
*
x
*
x
,
prior_result
*
x
*
x
*
x
result
,
_
=
scan
(
result
=
scan
(
power_step
,
power_step
,
non_sequences
=
[
A
],
non_sequences
=
[
A
],
outputs_info
=
[
pt
.
ones_like
(
A
),
None
,
None
],
outputs_info
=
[
pt
.
ones_like
(
A
),
None
,
None
],
n_steps
=
3
,
n_steps
=
3
,
return_updates
=
False
,
)
)
test_input_vals
=
(
np
.
array
([
1.0
,
2.0
]),)
test_input_vals
=
(
np
.
array
([
1.0
,
2.0
]),)
compare_numba_and_py
([
A
],
result
,
test_input_vals
)
compare_numba_and_py
([
A
],
result
,
test_input_vals
)
...
@@ -343,8 +347,12 @@ def test_scan_multiple_none_output():
...
@@ -343,8 +347,12 @@ def test_scan_multiple_none_output():
def
test_grad_sitsot
():
def
test_grad_sitsot
():
def
get_sum_of_grad
(
inp
):
def
get_sum_of_grad
(
inp
):
scan_outputs
,
_updates
=
scan
(
scan_outputs
=
scan
(
fn
=
lambda
x
:
x
*
2
,
outputs_info
=
[
inp
],
n_steps
=
5
,
mode
=
"NUMBA"
fn
=
lambda
x
:
x
*
2
,
outputs_info
=
[
inp
],
n_steps
=
5
,
mode
=
"NUMBA"
,
return_updates
=
False
,
)
)
return
grad
(
scan_outputs
.
sum
(),
inp
)
.
sum
()
return
grad
(
scan_outputs
.
sum
(),
inp
)
.
sum
()
...
@@ -362,8 +370,11 @@ def test_mitmots_basic():
...
@@ -362,8 +370,11 @@ def test_mitmots_basic():
def
inner_fct
(
seq
,
state_old
,
state_current
):
def
inner_fct
(
seq
,
state_old
,
state_current
):
return
state_old
*
2
+
state_current
+
seq
return
state_old
*
2
+
state_current
+
seq
out
,
_
=
scan
(
out
=
scan
(
inner_fct
,
sequences
=
seq
,
outputs_info
=
{
"initial"
:
init_x
,
"taps"
:
[
-
2
,
-
1
]}
inner_fct
,
sequences
=
seq
,
outputs_info
=
{
"initial"
:
init_x
,
"taps"
:
[
-
2
,
-
1
]},
return_updates
=
False
,
)
)
g_outs
=
grad
(
out
.
sum
(),
[
seq
,
init_x
])
g_outs
=
grad
(
out
.
sum
(),
[
seq
,
init_x
])
...
@@ -383,10 +394,11 @@ def test_mitmots_basic():
...
@@ -383,10 +394,11 @@ def test_mitmots_basic():
def
test_inner_graph_optimized
():
def
test_inner_graph_optimized
():
"""Test that inner graph of Scan is optimized"""
"""Test that inner graph of Scan is optimized"""
xs
=
vector
(
"xs"
)
xs
=
vector
(
"xs"
)
seq
,
_
=
scan
(
seq
=
scan
(
fn
=
lambda
x
:
log
(
1
+
x
),
fn
=
lambda
x
:
log
(
1
+
x
),
sequences
=
[
xs
],
sequences
=
[
xs
],
mode
=
get_mode
(
"NUMBA"
),
mode
=
get_mode
(
"NUMBA"
),
return_updates
=
False
,
)
)
# Disable scan pushout, in which case the whole scan is replaced by an Elemwise
# Disable scan pushout, in which case the whole scan is replaced by an Elemwise
...
@@ -421,13 +433,14 @@ def test_vector_taps_benchmark(benchmark):
...
@@ -421,13 +433,14 @@ def test_vector_taps_benchmark(benchmark):
sitsot2
=
(
sitsot1
+
mitsot3
)
/
np
.
sqrt
(
2
)
sitsot2
=
(
sitsot1
+
mitsot3
)
/
np
.
sqrt
(
2
)
return
mitsot3
,
sitsot2
return
mitsot3
,
sitsot2
outs
,
_
=
scan
(
outs
=
scan
(
fn
=
step
,
fn
=
step
,
sequences
=
[
seq1
,
seq2
],
sequences
=
[
seq1
,
seq2
],
outputs_info
=
[
outputs_info
=
[
dict
(
initial
=
mitsot_init
,
taps
=
[
-
2
,
-
1
]),
dict
(
initial
=
mitsot_init
,
taps
=
[
-
2
,
-
1
]),
dict
(
initial
=
sitsot_init
,
taps
=
[
-
1
]),
dict
(
initial
=
sitsot_init
,
taps
=
[
-
1
]),
],
],
return_updates
=
False
,
)
)
rng
=
np
.
random
.
default_rng
(
474
)
rng
=
np
.
random
.
default_rng
(
474
)
...
@@ -468,7 +481,7 @@ def test_inplace_taps(n_steps_constant):
...
@@ -468,7 +481,7 @@ def test_inplace_taps(n_steps_constant):
y
=
ytm1
+
1
+
ytm2
+
a
y
=
ytm1
+
1
+
ytm2
+
a
return
z
,
x
,
z
+
x
+
y
,
y
return
z
,
x
,
z
+
x
+
y
,
y
[
zs
,
xs
,
ws
,
ys
]
,
_
=
scan
(
[
zs
,
xs
,
ws
,
ys
]
=
scan
(
fn
=
step
,
fn
=
step
,
outputs_info
=
[
outputs_info
=
[
dict
(
initial
=
z0
,
taps
=
[
-
3
,
-
1
]),
dict
(
initial
=
z0
,
taps
=
[
-
3
,
-
1
]),
...
@@ -478,6 +491,7 @@ def test_inplace_taps(n_steps_constant):
...
@@ -478,6 +491,7 @@ def test_inplace_taps(n_steps_constant):
],
],
non_sequences
=
[
a
],
non_sequences
=
[
a
],
n_steps
=
n_steps
,
n_steps
=
n_steps
,
return_updates
=
False
,
)
)
numba_fn
,
_
=
compare_numba_and_py
(
numba_fn
,
_
=
compare_numba_and_py
(
[
n_steps
]
*
(
not
n_steps_constant
)
+
[
a
,
x0
,
y0
,
z0
],
[
n_steps
]
*
(
not
n_steps_constant
)
+
[
a
,
x0
,
y0
,
z0
],
...
@@ -529,10 +543,11 @@ def test_inplace_taps(n_steps_constant):
...
@@ -529,10 +543,11 @@ def test_inplace_taps(n_steps_constant):
class
TestScanSITSOTBuffer
:
class
TestScanSITSOTBuffer
:
def
buffer_tester
(
self
,
n_steps
,
op_size
,
buffer_size
,
benchmark
=
None
):
def
buffer_tester
(
self
,
n_steps
,
op_size
,
buffer_size
,
benchmark
=
None
):
x0
=
pt
.
vector
(
shape
=
(
op_size
,),
dtype
=
"float64"
)
x0
=
pt
.
vector
(
shape
=
(
op_size
,),
dtype
=
"float64"
)
xs
,
_
=
pytensor
.
scan
(
xs
=
pytensor
.
scan
(
fn
=
lambda
xtm1
:
(
xtm1
+
1
),
fn
=
lambda
xtm1
:
(
xtm1
+
1
),
outputs_info
=
[
x0
],
outputs_info
=
[
x0
],
n_steps
=
n_steps
-
1
,
# 1- makes it easier to align/misalign
n_steps
=
n_steps
-
1
,
# 1- makes it easier to align/misalign
return_updates
=
False
,
)
)
if
buffer_size
==
"unit"
:
if
buffer_size
==
"unit"
:
xs_kept
=
xs
[
-
1
]
# Only last state is used
xs_kept
=
xs
[
-
1
]
# Only last state is used
...
@@ -588,12 +603,13 @@ class TestScanMITSOTBuffer:
...
@@ -588,12 +603,13 @@ class TestScanMITSOTBuffer:
init_x
=
pt
.
vector
(
"init_x"
,
shape
=
(
2
,))
init_x
=
pt
.
vector
(
"init_x"
,
shape
=
(
2
,))
n_steps
=
pt
.
iscalar
(
"n_steps"
)
n_steps
=
pt
.
iscalar
(
"n_steps"
)
output
,
_
=
scan
(
output
=
scan
(
f_pow2
,
f_pow2
,
sequences
=
[],
sequences
=
[],
outputs_info
=
[{
"initial"
:
init_x
,
"taps"
:
[
-
2
,
-
1
]}],
outputs_info
=
[{
"initial"
:
init_x
,
"taps"
:
[
-
2
,
-
1
]}],
non_sequences
=
[],
non_sequences
=
[],
n_steps
=
n_steps_val
if
constant_n_steps
else
n_steps
,
n_steps
=
n_steps_val
if
constant_n_steps
else
n_steps
,
return_updates
=
False
,
)
)
init_x_val
=
np
.
array
([
1.0
,
2.0
],
dtype
=
init_x
.
type
.
dtype
)
init_x_val
=
np
.
array
([
1.0
,
2.0
],
dtype
=
init_x
.
type
.
dtype
)
...
...
tests/scan/test_basic.py
浏览文件 @
abedb7fb
差异被折叠。
点击展开。
tests/scan/test_rewriting.py
浏览文件 @
abedb7fb
差异被折叠。
点击展开。
tests/tensor/linalg/test_rewriting.py
浏览文件 @
abedb7fb
...
@@ -170,11 +170,12 @@ def test_lu_decomposition_reused_scan(assume_a, counter, transposed):
...
@@ -170,11 +170,12 @@ def test_lu_decomposition_reused_scan(assume_a, counter, transposed):
A
=
tensor
(
"A"
,
shape
=
(
3
,
3
))
A
=
tensor
(
"A"
,
shape
=
(
3
,
3
))
x0
=
tensor
(
"b"
,
shape
=
(
3
,
4
))
x0
=
tensor
(
"b"
,
shape
=
(
3
,
4
))
xs
,
_
=
scan
(
xs
=
scan
(
lambda
xtm1
,
A
:
solve
(
A
,
xtm1
,
assume_a
=
assume_a
,
transposed
=
transposed
),
lambda
xtm1
,
A
:
solve
(
A
,
xtm1
,
assume_a
=
assume_a
,
transposed
=
transposed
),
outputs_info
=
[
x0
],
outputs_info
=
[
x0
],
non_sequences
=
[
A
],
non_sequences
=
[
A
],
n_steps
=
10
,
n_steps
=
10
,
return_updates
=
False
,
)
)
fn_no_opt
=
function
(
fn_no_opt
=
function
(
...
...
tests/tensor/test_blockwise.py
浏览文件 @
abedb7fb
...
@@ -694,10 +694,11 @@ def test_blockwise_grad_core_type():
...
@@ -694,10 +694,11 @@ def test_blockwise_grad_core_type():
def
test_scan_gradient_core_type
():
def
test_scan_gradient_core_type
():
n_steps
=
3
n_steps
=
3
seq
=
tensor
(
"seq"
,
shape
=
(
n_steps
,
1
),
dtype
=
"float64"
)
seq
=
tensor
(
"seq"
,
shape
=
(
n_steps
,
1
),
dtype
=
"float64"
)
out
,
_
=
scan
(
out
=
scan
(
lambda
s
:
s
,
lambda
s
:
s
,
sequences
=
[
seq
],
sequences
=
[
seq
],
n_steps
=
n_steps
,
n_steps
=
n_steps
,
return_updates
=
False
,
)
)
vec_seq
=
tensor
(
"vec_seq"
,
shape
=
(
None
,
n_steps
,
1
),
dtype
=
"float64"
)
vec_seq
=
tensor
(
"vec_seq"
,
shape
=
(
None
,
n_steps
,
1
),
dtype
=
"float64"
)
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论