Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
abedb7fb
提交
abedb7fb
authored
10月 31, 2025
作者:
ricardoV94
提交者:
Ricardo Vieira
11月 08, 2025
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Start using new API in tests that don't involve shared updates
上级
78293400
隐藏空白字符变更
内嵌
并排
正在显示
6 个修改的文件
包含
505 行增加
和
266 行删除
+505
-266
test_scan.py
tests/link/jax/test_scan.py
+40
-20
test_scan.py
tests/link/numba/test_scan.py
+29
-13
test_basic.py
tests/scan/test_basic.py
+270
-142
test_rewriting.py
tests/scan/test_rewriting.py
+162
-89
test_rewriting.py
tests/tensor/linalg/test_rewriting.py
+2
-1
test_blockwise.py
tests/tensor/test_blockwise.py
+2
-1
没有找到文件。
tests/link/jax/test_scan.py
浏览文件 @
abedb7fb
...
@@ -23,10 +23,11 @@ jax = pytest.importorskip("jax")
...
@@ -23,10 +23,11 @@ jax = pytest.importorskip("jax")
@pytest.mark.parametrize
(
"view"
,
[
None
,
(
-
1
,),
slice
(
-
2
,
None
,
None
)])
@pytest.mark.parametrize
(
"view"
,
[
None
,
(
-
1
,),
slice
(
-
2
,
None
,
None
)])
def
test_scan_sit_sot
(
view
):
def
test_scan_sit_sot
(
view
):
x0
=
pt
.
scalar
(
"x0"
,
dtype
=
"float64"
)
x0
=
pt
.
scalar
(
"x0"
,
dtype
=
"float64"
)
xs
,
_
=
scan
(
xs
=
scan
(
lambda
xtm1
:
xtm1
+
1
,
lambda
xtm1
:
xtm1
+
1
,
outputs_info
=
[
x0
],
outputs_info
=
[
x0
],
n_steps
=
10
,
n_steps
=
10
,
return_updates
=
False
,
)
)
if
view
:
if
view
:
xs
=
xs
[
view
]
xs
=
xs
[
view
]
...
@@ -37,10 +38,11 @@ def test_scan_sit_sot(view):
...
@@ -37,10 +38,11 @@ def test_scan_sit_sot(view):
@pytest.mark.parametrize
(
"view"
,
[
None
,
(
-
1
,),
slice
(
-
4
,
-
1
,
None
)])
@pytest.mark.parametrize
(
"view"
,
[
None
,
(
-
1
,),
slice
(
-
4
,
-
1
,
None
)])
def
test_scan_mit_sot
(
view
):
def
test_scan_mit_sot
(
view
):
x0
=
pt
.
vector
(
"x0"
,
dtype
=
"float64"
,
shape
=
(
3
,))
x0
=
pt
.
vector
(
"x0"
,
dtype
=
"float64"
,
shape
=
(
3
,))
xs
,
_
=
scan
(
xs
=
scan
(
lambda
xtm3
,
xtm1
:
xtm3
+
xtm1
+
1
,
lambda
xtm3
,
xtm1
:
xtm3
+
xtm1
+
1
,
outputs_info
=
[{
"initial"
:
x0
,
"taps"
:
[
-
3
,
-
1
]}],
outputs_info
=
[{
"initial"
:
x0
,
"taps"
:
[
-
3
,
-
1
]}],
n_steps
=
10
,
n_steps
=
10
,
return_updates
=
False
,
)
)
if
view
:
if
view
:
xs
=
xs
[
view
]
xs
=
xs
[
view
]
...
@@ -57,13 +59,14 @@ def test_scan_multiple_mit_sot(view_x, view_y):
...
@@ -57,13 +59,14 @@ def test_scan_multiple_mit_sot(view_x, view_y):
def
step
(
xtm3
,
xtm1
,
ytm4
,
ytm2
):
def
step
(
xtm3
,
xtm1
,
ytm4
,
ytm2
):
return
xtm3
+
ytm4
+
1
,
xtm1
+
ytm2
+
2
return
xtm3
+
ytm4
+
1
,
xtm1
+
ytm2
+
2
[
xs
,
ys
]
,
_
=
scan
(
[
xs
,
ys
]
=
scan
(
fn
=
step
,
fn
=
step
,
outputs_info
=
[
outputs_info
=
[
{
"initial"
:
x0
,
"taps"
:
[
-
3
,
-
1
]},
{
"initial"
:
x0
,
"taps"
:
[
-
3
,
-
1
]},
{
"initial"
:
y0
,
"taps"
:
[
-
4
,
-
2
]},
{
"initial"
:
y0
,
"taps"
:
[
-
4
,
-
2
]},
],
],
n_steps
=
10
,
n_steps
=
10
,
return_updates
=
False
,
)
)
if
view_x
:
if
view_x
:
xs
=
xs
[
view_x
]
xs
=
xs
[
view_x
]
...
@@ -80,10 +83,8 @@ def test_scan_nit_sot(view):
...
@@ -80,10 +83,8 @@ def test_scan_nit_sot(view):
xs
=
pt
.
vector
(
"x0"
,
dtype
=
"float64"
,
shape
=
(
10
,))
xs
=
pt
.
vector
(
"x0"
,
dtype
=
"float64"
,
shape
=
(
10
,))
ys
,
_
=
scan
(
ys
=
scan
(
lambda
x
:
pt
.
exp
(
x
),
lambda
x
:
pt
.
exp
(
x
),
outputs_info
=
[
None
],
sequences
=
[
xs
],
return_updates
=
False
outputs_info
=
[
None
],
sequences
=
[
xs
],
)
)
if
view
:
if
view
:
ys
=
ys
[
view
]
ys
=
ys
[
view
]
...
@@ -106,11 +107,12 @@ def test_scan_mit_mot():
...
@@ -106,11 +107,12 @@ def test_scan_mit_mot():
rho
=
pt
.
scalar
(
"rho"
,
dtype
=
"float64"
)
rho
=
pt
.
scalar
(
"rho"
,
dtype
=
"float64"
)
x0
=
pt
.
vector
(
"xs"
,
shape
=
(
2
,))
x0
=
pt
.
vector
(
"xs"
,
shape
=
(
2
,))
y0
=
pt
.
vector
(
"ys"
,
shape
=
(
3
,))
y0
=
pt
.
vector
(
"ys"
,
shape
=
(
3
,))
[
outs
,
_
]
,
_
=
scan
(
[
outs
,
_
]
=
scan
(
step
,
step
,
outputs_info
=
[
x0
,
{
"initial"
:
y0
,
"taps"
:
[
-
3
,
-
1
]}],
outputs_info
=
[
x0
,
{
"initial"
:
y0
,
"taps"
:
[
-
3
,
-
1
]}],
non_sequences
=
[
rho
],
non_sequences
=
[
rho
],
n_steps
=
10
,
n_steps
=
10
,
return_updates
=
False
,
)
)
grads
=
pt
.
grad
(
outs
.
sum
(),
wrt
=
[
x0
,
y0
,
rho
])
grads
=
pt
.
grad
(
outs
.
sum
(),
wrt
=
[
x0
,
y0
,
rho
])
compare_jax_and_py
(
compare_jax_and_py
(
...
@@ -191,10 +193,11 @@ def test_scan_rng_update():
...
@@ -191,10 +193,11 @@ def test_scan_rng_update():
@pytest.mark.xfail
(
raises
=
NotImplementedError
)
@pytest.mark.xfail
(
raises
=
NotImplementedError
)
def
test_scan_while
():
def
test_scan_while
():
xs
,
_
=
scan
(
xs
=
scan
(
lambda
x
:
(
x
+
1
,
until
(
x
<
10
)),
lambda
x
:
(
x
+
1
,
until
(
x
<
10
)),
outputs_info
=
[
pt
.
zeros
(())],
outputs_info
=
[
pt
.
zeros
(())],
n_steps
=
100
,
n_steps
=
100
,
return_updates
=
False
,
)
)
compare_jax_and_py
([],
[
xs
],
[])
compare_jax_and_py
([],
[
xs
],
[])
...
@@ -210,7 +213,7 @@ def test_scan_mitsot_with_nonseq():
...
@@ -210,7 +213,7 @@ def test_scan_mitsot_with_nonseq():
res
.
name
=
"y_t"
res
.
name
=
"y_t"
return
res
return
res
y_scan_pt
,
_
=
scan
(
y_scan_pt
=
scan
(
fn
=
input_step_fn
,
fn
=
input_step_fn
,
outputs_info
=
[
outputs_info
=
[
{
{
...
@@ -223,6 +226,7 @@ def test_scan_mitsot_with_nonseq():
...
@@ -223,6 +226,7 @@ def test_scan_mitsot_with_nonseq():
non_sequences
=
[
a_pt
],
non_sequences
=
[
a_pt
],
n_steps
=
10
,
n_steps
=
10
,
name
=
"y_scan"
,
name
=
"y_scan"
,
return_updates
=
False
,
)
)
y_scan_pt
.
name
=
"y"
y_scan_pt
.
name
=
"y"
y_scan_pt
.
owner
.
inputs
[
0
]
.
name
=
"y_all"
y_scan_pt
.
owner
.
inputs
[
0
]
.
name
=
"y_all"
...
@@ -241,11 +245,12 @@ def test_nd_scan_sit_sot(x0_func, A_func):
...
@@ -241,11 +245,12 @@ def test_nd_scan_sit_sot(x0_func, A_func):
k
=
3
k
=
3
# Must specify mode = JAX for the inner func to avoid a GEMM Op in the JAX graph
# Must specify mode = JAX for the inner func to avoid a GEMM Op in the JAX graph
xs
,
_
=
scan
(
xs
=
scan
(
lambda
X
,
A
:
A
@
X
,
lambda
X
,
A
:
A
@
X
,
non_sequences
=
[
A
],
non_sequences
=
[
A
],
outputs_info
=
[
x0
],
outputs_info
=
[
x0
],
n_steps
=
n_steps
,
n_steps
=
n_steps
,
return_updates
=
False
,
)
)
x0_val
=
(
x0_val
=
(
...
@@ -267,11 +272,12 @@ def test_nd_scan_sit_sot_with_seq():
...
@@ -267,11 +272,12 @@ def test_nd_scan_sit_sot_with_seq():
A
=
pt
.
matrix
(
"A"
,
shape
=
(
k
,
k
))
A
=
pt
.
matrix
(
"A"
,
shape
=
(
k
,
k
))
# Must specify mode = JAX for the inner func to avoid a GEMM Op in the JAX graph
# Must specify mode = JAX for the inner func to avoid a GEMM Op in the JAX graph
xs
,
_
=
scan
(
xs
=
scan
(
lambda
X
,
A
:
A
@
X
,
lambda
X
,
A
:
A
@
X
,
non_sequences
=
[
A
],
non_sequences
=
[
A
],
sequences
=
[
x
],
sequences
=
[
x
],
n_steps
=
n_steps
,
n_steps
=
n_steps
,
return_updates
=
False
,
)
)
x_val
=
np
.
arange
(
n_steps
*
k
,
dtype
=
config
.
floatX
)
.
reshape
(
n_steps
,
k
)
x_val
=
np
.
arange
(
n_steps
*
k
,
dtype
=
config
.
floatX
)
.
reshape
(
n_steps
,
k
)
...
@@ -287,11 +293,12 @@ def test_nd_scan_mit_sot():
...
@@ -287,11 +293,12 @@ def test_nd_scan_mit_sot():
B
=
pt
.
matrix
(
"B"
,
shape
=
(
3
,
3
))
B
=
pt
.
matrix
(
"B"
,
shape
=
(
3
,
3
))
# Must specify mode = JAX for the inner func to avoid a GEMM Op in the JAX graph
# Must specify mode = JAX for the inner func to avoid a GEMM Op in the JAX graph
xs
,
_
=
scan
(
xs
=
scan
(
lambda
xtm3
,
xtm1
,
A
,
B
:
A
@
xtm3
+
B
@
xtm1
,
lambda
xtm3
,
xtm1
,
A
,
B
:
A
@
xtm3
+
B
@
xtm1
,
outputs_info
=
[{
"initial"
:
x0
,
"taps"
:
[
-
3
,
-
1
]}],
outputs_info
=
[{
"initial"
:
x0
,
"taps"
:
[
-
3
,
-
1
]}],
non_sequences
=
[
A
,
B
],
non_sequences
=
[
A
,
B
],
n_steps
=
10
,
n_steps
=
10
,
return_updates
=
False
,
)
)
x0_val
=
np
.
arange
(
9
,
dtype
=
config
.
floatX
)
.
reshape
(
3
,
3
)
x0_val
=
np
.
arange
(
9
,
dtype
=
config
.
floatX
)
.
reshape
(
3
,
3
)
...
@@ -310,12 +317,13 @@ def test_nd_scan_sit_sot_with_carry():
...
@@ -310,12 +317,13 @@ def test_nd_scan_sit_sot_with_carry():
return
A
@
x
,
x
.
sum
()
return
A
@
x
,
x
.
sum
()
# Must specify mode = JAX for the inner func to avoid a GEMM Op in the JAX graph
# Must specify mode = JAX for the inner func to avoid a GEMM Op in the JAX graph
xs
,
_
=
scan
(
xs
=
scan
(
step
,
step
,
outputs_info
=
[
x0
,
None
],
outputs_info
=
[
x0
,
None
],
non_sequences
=
[
A
],
non_sequences
=
[
A
],
n_steps
=
10
,
n_steps
=
10
,
mode
=
get_mode
(
"JAX"
),
mode
=
get_mode
(
"JAX"
),
return_updates
=
False
,
)
)
x0_val
=
np
.
arange
(
3
,
dtype
=
config
.
floatX
)
x0_val
=
np
.
arange
(
3
,
dtype
=
config
.
floatX
)
...
@@ -329,7 +337,13 @@ def test_default_mode_excludes_incompatible_rewrites():
...
@@ -329,7 +337,13 @@ def test_default_mode_excludes_incompatible_rewrites():
# See issue #426
# See issue #426
A
=
matrix
(
"A"
)
A
=
matrix
(
"A"
)
B
=
matrix
(
"B"
)
B
=
matrix
(
"B"
)
out
,
_
=
scan
(
lambda
a
,
b
:
a
@
b
,
outputs_info
=
[
A
],
non_sequences
=
[
B
],
n_steps
=
2
)
out
=
scan
(
lambda
a
,
b
:
a
@
b
,
outputs_info
=
[
A
],
non_sequences
=
[
B
],
n_steps
=
2
,
return_updates
=
False
,
)
compare_jax_and_py
([
A
,
B
],
[
out
],
[
np
.
eye
(
3
),
np
.
eye
(
3
)],
jax_mode
=
"JAX"
)
compare_jax_and_py
([
A
,
B
],
[
out
],
[
np
.
eye
(
3
),
np
.
eye
(
3
)],
jax_mode
=
"JAX"
)
...
@@ -353,8 +367,11 @@ def test_dynamic_sequence_length():
...
@@ -353,8 +367,11 @@ def test_dynamic_sequence_length():
x
=
pt
.
tensor
(
"x"
,
shape
=
(
None
,
3
))
x
=
pt
.
tensor
(
"x"
,
shape
=
(
None
,
3
))
out
,
_
=
scan
(
out
=
scan
(
lambda
x
:
inc_without_static_shape
(
x
),
outputs_info
=
[
None
],
sequences
=
[
x
]
lambda
x
:
inc_without_static_shape
(
x
),
outputs_info
=
[
None
],
sequences
=
[
x
],
return_updates
=
False
,
)
)
f
=
function
([
x
],
out
,
mode
=
get_mode
(
"JAX"
)
.
excluding
(
"scan"
))
f
=
function
([
x
],
out
,
mode
=
get_mode
(
"JAX"
)
.
excluding
(
"scan"
))
assert
sum
(
isinstance
(
node
.
op
,
Scan
)
for
node
in
f
.
maker
.
fgraph
.
apply_nodes
)
==
1
assert
sum
(
isinstance
(
node
.
op
,
Scan
)
for
node
in
f
.
maker
.
fgraph
.
apply_nodes
)
==
1
...
@@ -364,10 +381,11 @@ def test_dynamic_sequence_length():
...
@@ -364,10 +381,11 @@ def test_dynamic_sequence_length():
np
.
testing
.
assert_allclose
(
f
(
np
.
zeros
((
0
,
3
))),
np
.
empty
((
0
,
3
)))
np
.
testing
.
assert_allclose
(
f
(
np
.
zeros
((
0
,
3
))),
np
.
empty
((
0
,
3
)))
# With known static shape we should always manage, regardless of the internal implementation
# With known static shape we should always manage, regardless of the internal implementation
out2
,
_
=
scan
(
out2
=
scan
(
lambda
x
:
pt
.
specify_shape
(
inc_without_static_shape
(
x
),
x
.
shape
),
lambda
x
:
pt
.
specify_shape
(
inc_without_static_shape
(
x
),
x
.
shape
),
outputs_info
=
[
None
],
outputs_info
=
[
None
],
sequences
=
[
x
],
sequences
=
[
x
],
return_updates
=
False
,
)
)
f2
=
function
([
x
],
out2
,
mode
=
get_mode
(
"JAX"
)
.
excluding
(
"scan"
))
f2
=
function
([
x
],
out2
,
mode
=
get_mode
(
"JAX"
)
.
excluding
(
"scan"
))
np
.
testing
.
assert_allclose
(
f2
([[
1
,
2
,
3
]]),
np
.
array
([[
2
,
3
,
4
]]))
np
.
testing
.
assert_allclose
(
f2
([[
1
,
2
,
3
]]),
np
.
array
([[
2
,
3
,
4
]]))
...
@@ -418,11 +436,12 @@ def SEIR_model_logp():
...
@@ -418,11 +436,12 @@ def SEIR_model_logp():
it1
=
it0
+
ct0
-
dt0
it1
=
it0
+
ct0
-
dt0
return
st1
,
et1
,
it1
,
logp_c1
,
logp_d1
return
st1
,
et1
,
it1
,
logp_c1
,
logp_d1
(
st
,
et
,
it
,
logp_c_all
,
logp_d_all
)
,
_
=
scan
(
(
st
,
et
,
it
,
logp_c_all
,
logp_d_all
)
=
scan
(
fn
=
seir_one_step
,
fn
=
seir_one_step
,
sequences
=
[
C_t
,
D_t
],
sequences
=
[
C_t
,
D_t
],
outputs_info
=
[
st0
,
et0
,
it0
,
None
,
None
],
outputs_info
=
[
st0
,
et0
,
it0
,
None
,
None
],
non_sequences
=
[
beta
,
gamma
,
delta
],
non_sequences
=
[
beta
,
gamma
,
delta
],
return_updates
=
False
,
)
)
st
.
name
=
"S_t"
st
.
name
=
"S_t"
et
.
name
=
"E_t"
et
.
name
=
"E_t"
...
@@ -511,11 +530,12 @@ def cyclical_reduction():
...
@@ -511,11 +530,12 @@ def cyclical_reduction():
max_iter
=
100
max_iter
=
100
tol
=
1e-7
tol
=
1e-7
(
*
_
,
A1_hat
,
norm
,
_n_steps
)
,
_
=
scan
(
(
*
_
,
A1_hat
,
norm
,
_n_steps
)
=
scan
(
step
,
step
,
outputs_info
=
[
A
,
B
,
C
,
B
,
norm
,
step_num
],
outputs_info
=
[
A
,
B
,
C
,
B
,
norm
,
step_num
],
non_sequences
=
[
tol
],
non_sequences
=
[
tol
],
n_steps
=
max_iter
,
n_steps
=
max_iter
,
return_updates
=
False
,
)
)
A1_hat
=
A1_hat
[
-
1
]
A1_hat
=
A1_hat
[
-
1
]
...
...
tests/link/numba/test_scan.py
浏览文件 @
abedb7fb
...
@@ -206,11 +206,12 @@ def test_scan_multiple_output(benchmark):
...
@@ -206,11 +206,12 @@ def test_scan_multiple_output(benchmark):
it1
=
it0
+
ct0
-
dt0
it1
=
it0
+
ct0
-
dt0
return
st1
,
et1
,
it1
,
logp_c1
,
logp_d1
return
st1
,
et1
,
it1
,
logp_c1
,
logp_d1
(
st
,
et
,
it
,
logp_c_all
,
logp_d_all
)
,
_
=
scan
(
(
st
,
et
,
it
,
logp_c_all
,
logp_d_all
)
=
scan
(
fn
=
seir_one_step
,
fn
=
seir_one_step
,
sequences
=
[
pt_C
,
pt_D
],
sequences
=
[
pt_C
,
pt_D
],
outputs_info
=
[
st0
,
et0
,
it0
,
logp_c
,
logp_d
],
outputs_info
=
[
st0
,
et0
,
it0
,
logp_c
,
logp_d
],
non_sequences
=
[
beta
,
gamma
,
delta
],
non_sequences
=
[
beta
,
gamma
,
delta
],
return_updates
=
False
,
)
)
st
.
name
=
"S_t"
st
.
name
=
"S_t"
et
.
name
=
"E_t"
et
.
name
=
"E_t"
...
@@ -268,7 +269,7 @@ def test_scan_tap_output():
...
@@ -268,7 +269,7 @@ def test_scan_tap_output():
y_t
.
name
=
"y_t"
y_t
.
name
=
"y_t"
return
x_t
,
y_t
,
pt
.
fill
((
10
,),
z_t
)
return
x_t
,
y_t
,
pt
.
fill
((
10
,),
z_t
)
scan_res
,
_
=
scan
(
scan_res
=
scan
(
fn
=
input_step_fn
,
fn
=
input_step_fn
,
sequences
=
[
sequences
=
[
{
{
...
@@ -297,6 +298,7 @@ def test_scan_tap_output():
...
@@ -297,6 +298,7 @@ def test_scan_tap_output():
n_steps
=
5
,
n_steps
=
5
,
name
=
"yz_scan"
,
name
=
"yz_scan"
,
strict
=
True
,
strict
=
True
,
return_updates
=
False
,
)
)
test_input_vals
=
[
test_input_vals
=
[
...
@@ -312,11 +314,12 @@ def test_scan_while():
...
@@ -312,11 +314,12 @@ def test_scan_while():
return
previous_power
*
2
,
until
(
previous_power
*
2
>
max_value
)
return
previous_power
*
2
,
until
(
previous_power
*
2
>
max_value
)
max_value
=
pt
.
scalar
()
max_value
=
pt
.
scalar
()
values
,
_
=
scan
(
values
=
scan
(
power_of_2
,
power_of_2
,
outputs_info
=
pt
.
constant
(
1.0
),
outputs_info
=
pt
.
constant
(
1.0
),
non_sequences
=
max_value
,
non_sequences
=
max_value
,
n_steps
=
1024
,
n_steps
=
1024
,
return_updates
=
False
,
)
)
test_input_vals
=
[
test_input_vals
=
[
...
@@ -331,11 +334,12 @@ def test_scan_multiple_none_output():
...
@@ -331,11 +334,12 @@ def test_scan_multiple_none_output():
def
power_step
(
prior_result
,
x
):
def
power_step
(
prior_result
,
x
):
return
prior_result
*
x
,
prior_result
*
x
*
x
,
prior_result
*
x
*
x
*
x
return
prior_result
*
x
,
prior_result
*
x
*
x
,
prior_result
*
x
*
x
*
x
result
,
_
=
scan
(
result
=
scan
(
power_step
,
power_step
,
non_sequences
=
[
A
],
non_sequences
=
[
A
],
outputs_info
=
[
pt
.
ones_like
(
A
),
None
,
None
],
outputs_info
=
[
pt
.
ones_like
(
A
),
None
,
None
],
n_steps
=
3
,
n_steps
=
3
,
return_updates
=
False
,
)
)
test_input_vals
=
(
np
.
array
([
1.0
,
2.0
]),)
test_input_vals
=
(
np
.
array
([
1.0
,
2.0
]),)
compare_numba_and_py
([
A
],
result
,
test_input_vals
)
compare_numba_and_py
([
A
],
result
,
test_input_vals
)
...
@@ -343,8 +347,12 @@ def test_scan_multiple_none_output():
...
@@ -343,8 +347,12 @@ def test_scan_multiple_none_output():
def
test_grad_sitsot
():
def
test_grad_sitsot
():
def
get_sum_of_grad
(
inp
):
def
get_sum_of_grad
(
inp
):
scan_outputs
,
_updates
=
scan
(
scan_outputs
=
scan
(
fn
=
lambda
x
:
x
*
2
,
outputs_info
=
[
inp
],
n_steps
=
5
,
mode
=
"NUMBA"
fn
=
lambda
x
:
x
*
2
,
outputs_info
=
[
inp
],
n_steps
=
5
,
mode
=
"NUMBA"
,
return_updates
=
False
,
)
)
return
grad
(
scan_outputs
.
sum
(),
inp
)
.
sum
()
return
grad
(
scan_outputs
.
sum
(),
inp
)
.
sum
()
...
@@ -362,8 +370,11 @@ def test_mitmots_basic():
...
@@ -362,8 +370,11 @@ def test_mitmots_basic():
def
inner_fct
(
seq
,
state_old
,
state_current
):
def
inner_fct
(
seq
,
state_old
,
state_current
):
return
state_old
*
2
+
state_current
+
seq
return
state_old
*
2
+
state_current
+
seq
out
,
_
=
scan
(
out
=
scan
(
inner_fct
,
sequences
=
seq
,
outputs_info
=
{
"initial"
:
init_x
,
"taps"
:
[
-
2
,
-
1
]}
inner_fct
,
sequences
=
seq
,
outputs_info
=
{
"initial"
:
init_x
,
"taps"
:
[
-
2
,
-
1
]},
return_updates
=
False
,
)
)
g_outs
=
grad
(
out
.
sum
(),
[
seq
,
init_x
])
g_outs
=
grad
(
out
.
sum
(),
[
seq
,
init_x
])
...
@@ -383,10 +394,11 @@ def test_mitmots_basic():
...
@@ -383,10 +394,11 @@ def test_mitmots_basic():
def
test_inner_graph_optimized
():
def
test_inner_graph_optimized
():
"""Test that inner graph of Scan is optimized"""
"""Test that inner graph of Scan is optimized"""
xs
=
vector
(
"xs"
)
xs
=
vector
(
"xs"
)
seq
,
_
=
scan
(
seq
=
scan
(
fn
=
lambda
x
:
log
(
1
+
x
),
fn
=
lambda
x
:
log
(
1
+
x
),
sequences
=
[
xs
],
sequences
=
[
xs
],
mode
=
get_mode
(
"NUMBA"
),
mode
=
get_mode
(
"NUMBA"
),
return_updates
=
False
,
)
)
# Disable scan pushout, in which case the whole scan is replaced by an Elemwise
# Disable scan pushout, in which case the whole scan is replaced by an Elemwise
...
@@ -421,13 +433,14 @@ def test_vector_taps_benchmark(benchmark):
...
@@ -421,13 +433,14 @@ def test_vector_taps_benchmark(benchmark):
sitsot2
=
(
sitsot1
+
mitsot3
)
/
np
.
sqrt
(
2
)
sitsot2
=
(
sitsot1
+
mitsot3
)
/
np
.
sqrt
(
2
)
return
mitsot3
,
sitsot2
return
mitsot3
,
sitsot2
outs
,
_
=
scan
(
outs
=
scan
(
fn
=
step
,
fn
=
step
,
sequences
=
[
seq1
,
seq2
],
sequences
=
[
seq1
,
seq2
],
outputs_info
=
[
outputs_info
=
[
dict
(
initial
=
mitsot_init
,
taps
=
[
-
2
,
-
1
]),
dict
(
initial
=
mitsot_init
,
taps
=
[
-
2
,
-
1
]),
dict
(
initial
=
sitsot_init
,
taps
=
[
-
1
]),
dict
(
initial
=
sitsot_init
,
taps
=
[
-
1
]),
],
],
return_updates
=
False
,
)
)
rng
=
np
.
random
.
default_rng
(
474
)
rng
=
np
.
random
.
default_rng
(
474
)
...
@@ -468,7 +481,7 @@ def test_inplace_taps(n_steps_constant):
...
@@ -468,7 +481,7 @@ def test_inplace_taps(n_steps_constant):
y
=
ytm1
+
1
+
ytm2
+
a
y
=
ytm1
+
1
+
ytm2
+
a
return
z
,
x
,
z
+
x
+
y
,
y
return
z
,
x
,
z
+
x
+
y
,
y
[
zs
,
xs
,
ws
,
ys
]
,
_
=
scan
(
[
zs
,
xs
,
ws
,
ys
]
=
scan
(
fn
=
step
,
fn
=
step
,
outputs_info
=
[
outputs_info
=
[
dict
(
initial
=
z0
,
taps
=
[
-
3
,
-
1
]),
dict
(
initial
=
z0
,
taps
=
[
-
3
,
-
1
]),
...
@@ -478,6 +491,7 @@ def test_inplace_taps(n_steps_constant):
...
@@ -478,6 +491,7 @@ def test_inplace_taps(n_steps_constant):
],
],
non_sequences
=
[
a
],
non_sequences
=
[
a
],
n_steps
=
n_steps
,
n_steps
=
n_steps
,
return_updates
=
False
,
)
)
numba_fn
,
_
=
compare_numba_and_py
(
numba_fn
,
_
=
compare_numba_and_py
(
[
n_steps
]
*
(
not
n_steps_constant
)
+
[
a
,
x0
,
y0
,
z0
],
[
n_steps
]
*
(
not
n_steps_constant
)
+
[
a
,
x0
,
y0
,
z0
],
...
@@ -529,10 +543,11 @@ def test_inplace_taps(n_steps_constant):
...
@@ -529,10 +543,11 @@ def test_inplace_taps(n_steps_constant):
class
TestScanSITSOTBuffer
:
class
TestScanSITSOTBuffer
:
def
buffer_tester
(
self
,
n_steps
,
op_size
,
buffer_size
,
benchmark
=
None
):
def
buffer_tester
(
self
,
n_steps
,
op_size
,
buffer_size
,
benchmark
=
None
):
x0
=
pt
.
vector
(
shape
=
(
op_size
,),
dtype
=
"float64"
)
x0
=
pt
.
vector
(
shape
=
(
op_size
,),
dtype
=
"float64"
)
xs
,
_
=
pytensor
.
scan
(
xs
=
pytensor
.
scan
(
fn
=
lambda
xtm1
:
(
xtm1
+
1
),
fn
=
lambda
xtm1
:
(
xtm1
+
1
),
outputs_info
=
[
x0
],
outputs_info
=
[
x0
],
n_steps
=
n_steps
-
1
,
# 1- makes it easier to align/misalign
n_steps
=
n_steps
-
1
,
# 1- makes it easier to align/misalign
return_updates
=
False
,
)
)
if
buffer_size
==
"unit"
:
if
buffer_size
==
"unit"
:
xs_kept
=
xs
[
-
1
]
# Only last state is used
xs_kept
=
xs
[
-
1
]
# Only last state is used
...
@@ -588,12 +603,13 @@ class TestScanMITSOTBuffer:
...
@@ -588,12 +603,13 @@ class TestScanMITSOTBuffer:
init_x
=
pt
.
vector
(
"init_x"
,
shape
=
(
2
,))
init_x
=
pt
.
vector
(
"init_x"
,
shape
=
(
2
,))
n_steps
=
pt
.
iscalar
(
"n_steps"
)
n_steps
=
pt
.
iscalar
(
"n_steps"
)
output
,
_
=
scan
(
output
=
scan
(
f_pow2
,
f_pow2
,
sequences
=
[],
sequences
=
[],
outputs_info
=
[{
"initial"
:
init_x
,
"taps"
:
[
-
2
,
-
1
]}],
outputs_info
=
[{
"initial"
:
init_x
,
"taps"
:
[
-
2
,
-
1
]}],
non_sequences
=
[],
non_sequences
=
[],
n_steps
=
n_steps_val
if
constant_n_steps
else
n_steps
,
n_steps
=
n_steps_val
if
constant_n_steps
else
n_steps
,
return_updates
=
False
,
)
)
init_x_val
=
np
.
array
([
1.0
,
2.0
],
dtype
=
init_x
.
type
.
dtype
)
init_x_val
=
np
.
array
([
1.0
,
2.0
],
dtype
=
init_x
.
type
.
dtype
)
...
...
tests/scan/test_basic.py
浏览文件 @
abedb7fb
...
@@ -294,7 +294,7 @@ class TestScan:
...
@@ -294,7 +294,7 @@ class TestScan:
def
test_clone
(
self
):
def
test_clone
(
self
):
a
=
vector
()
a
=
vector
()
output
,
_
=
scan
(
fn
=
lambda
x
:
x
**
2
,
sequences
=
[
a
]
)
output
=
scan
(
fn
=
lambda
x
:
x
**
2
,
sequences
=
[
a
],
return_updates
=
False
)
scan_op
=
output
.
owner
.
op
scan_op
=
output
.
owner
.
op
assert
isinstance
(
scan_op
,
Scan
)
assert
isinstance
(
scan_op
,
Scan
)
...
@@ -320,7 +320,7 @@ class TestScan:
...
@@ -320,7 +320,7 @@ class TestScan:
state
=
scalar
(
"state"
)
state
=
scalar
(
"state"
)
n_steps
=
iscalar
(
"nsteps"
)
n_steps
=
iscalar
(
"nsteps"
)
output
,
updates
=
scan
(
output
=
scan
(
f_pow2
,
f_pow2
,
[],
[],
state
,
state
,
...
@@ -328,10 +328,9 @@ class TestScan:
...
@@ -328,10 +328,9 @@ class TestScan:
n_steps
=
n_steps
,
n_steps
=
n_steps
,
truncate_gradient
=-
1
,
truncate_gradient
=-
1
,
go_backwards
=
False
,
go_backwards
=
False
,
return_updates
=
False
,
)
)
_my_f
=
function
(
_my_f
=
function
([
state
,
n_steps
],
output
,
allow_input_downcast
=
True
)
[
state
,
n_steps
],
output
,
updates
=
updates
,
allow_input_downcast
=
True
)
origdir
=
Path
.
cwd
()
origdir
=
Path
.
cwd
()
tmpdir
=
None
tmpdir
=
None
...
@@ -368,11 +367,9 @@ class TestScan:
...
@@ -368,11 +367,9 @@ class TestScan:
state
=
scalar
(
"state"
)
state
=
scalar
(
"state"
)
n_steps
=
iscalar
(
"nsteps"
)
n_steps
=
iscalar
(
"nsteps"
)
output
,
updates
=
scan
(
f_pow2
,
[],
state
,
[],
n_steps
=
n_steps
)
output
=
scan
(
f_pow2
,
[],
state
,
[],
n_steps
=
n_steps
,
return_updates
=
False
)
f
=
function
(
f
=
function
([
state
,
n_steps
],
output
,
allow_input_downcast
=
True
)
[
state
,
n_steps
],
output
,
updates
=
updates
,
allow_input_downcast
=
True
)
scan_node
=
[
scan_node
=
[
node
for
node
in
f
.
maker
.
fgraph
.
toposort
()
if
isinstance
(
node
.
op
,
Scan
)
node
for
node
in
f
.
maker
.
fgraph
.
toposort
()
if
isinstance
(
node
.
op
,
Scan
)
...
@@ -410,7 +407,9 @@ class TestScan:
...
@@ -410,7 +407,9 @@ class TestScan:
return
2
*
x_tm1
return
2
*
x_tm1
n_steps
=
iscalar
(
"n_steps"
)
n_steps
=
iscalar
(
"n_steps"
)
values
,
_
=
scan
(
f_pow
,
outputs_info
=
(
x_init
,),
n_steps
=
n_steps
)
values
=
scan
(
f_pow
,
outputs_info
=
(
x_init
,),
n_steps
=
n_steps
,
return_updates
=
False
)
update_fn
=
function
((
x_init
,
n_steps
),
values
,
mode
=
mode
)
update_fn
=
function
((
x_init
,
n_steps
),
values
,
mode
=
mode
)
...
@@ -443,7 +442,9 @@ class TestScan:
...
@@ -443,7 +442,9 @@ class TestScan:
return
2
*
x_i
return
2
*
x_i
with
config
.
change_flags
(
mode
=
mode
):
with
config
.
change_flags
(
mode
=
mode
):
values
,
_
=
scan
(
inner_fn
,
outputs_info
=
(
x_init
,),
sequences
=
x
)
values
=
scan
(
inner_fn
,
outputs_info
=
(
x_init
,),
sequences
=
x
,
return_updates
=
False
)
values_fn
=
function
((
x_init
,
x
),
values
)
values_fn
=
function
((
x_init
,
x
),
values
)
assert
isinstance
(
values
.
owner
.
inputs
[
0
]
.
owner
.
op
,
Scan
)
assert
isinstance
(
values
.
owner
.
inputs
[
0
]
.
owner
.
op
,
Scan
)
...
@@ -474,7 +475,7 @@ class TestScan:
...
@@ -474,7 +475,7 @@ class TestScan:
return
2
*
x_i
return
2
*
x_i
with
config
.
change_flags
(
mode
=
mode
):
with
config
.
change_flags
(
mode
=
mode
):
values
,
_
=
scan
(
inner_fn
,
sequences
=
x
)
values
=
scan
(
inner_fn
,
sequences
=
x
,
return_updates
=
False
)
values_fn
=
function
((
x
,),
values
)
values_fn
=
function
((
x
,),
values
)
assert
isinstance
(
values
.
owner
.
op
,
Scan
)
assert
isinstance
(
values
.
owner
.
op
,
Scan
)
...
@@ -491,7 +492,9 @@ class TestScan:
...
@@ -491,7 +492,9 @@ class TestScan:
# Compile the PyTensor function
# Compile the PyTensor function
n_steps
=
2
n_steps
=
2
inp
=
matrix
()
inp
=
matrix
()
broadcasted_inp
,
_
=
scan
(
lambda
x
:
x
,
non_sequences
=
[
inp
],
n_steps
=
n_steps
)
broadcasted_inp
=
scan
(
lambda
x
:
x
,
non_sequences
=
[
inp
],
n_steps
=
n_steps
,
return_updates
=
False
)
out
=
broadcasted_inp
.
sum
()
out
=
broadcasted_inp
.
sum
()
gr
=
grad
(
out
,
inp
)
gr
=
grad
(
out
,
inp
)
fun
=
function
([
inp
],
[
broadcasted_inp
,
gr
])
fun
=
function
([
inp
],
[
broadcasted_inp
,
gr
])
...
@@ -519,7 +522,7 @@ class TestScan:
...
@@ -519,7 +522,7 @@ class TestScan:
W_in
=
scalar
(
"win"
)
W_in
=
scalar
(
"win"
)
W
=
scalar
(
"w"
)
W
=
scalar
(
"w"
)
output
,
updates
=
scan
(
output
=
scan
(
f_rnn
,
f_rnn
,
u
,
u
,
x0
,
x0
,
...
@@ -527,11 +530,10 @@ class TestScan:
...
@@ -527,11 +530,10 @@ class TestScan:
n_steps
=
None
,
n_steps
=
None
,
truncate_gradient
=-
1
,
truncate_gradient
=-
1
,
go_backwards
=
False
,
go_backwards
=
False
,
return_updates
=
False
,
)
)
f2
=
function
(
f2
=
function
([
u
,
x0
,
W_in
,
W
],
output
,
allow_input_downcast
=
True
)
[
u
,
x0
,
W_in
,
W
],
output
,
updates
=
updates
,
allow_input_downcast
=
True
)
# get random initial values
# get random initial values
rng
=
np
.
random
.
default_rng
(
utt
.
fetch_seed
())
rng
=
np
.
random
.
default_rng
(
utt
.
fetch_seed
())
v_u
=
rng
.
uniform
(
-
5.0
,
5.0
,
size
=
(
4
,))
v_u
=
rng
.
uniform
(
-
5.0
,
5.0
,
size
=
(
4
,))
...
@@ -561,7 +563,7 @@ class TestScan:
...
@@ -561,7 +563,7 @@ class TestScan:
def
f_rnn_shared
(
u_t
,
x_tm1
,
tmp_W_in
,
tmp_W
):
def
f_rnn_shared
(
u_t
,
x_tm1
,
tmp_W_in
,
tmp_W
):
return
u_t
*
tmp_W_in
+
x_tm1
*
tmp_W
return
u_t
*
tmp_W_in
+
x_tm1
*
tmp_W
output
,
updates
=
scan
(
output
=
scan
(
f_rnn_shared
,
f_rnn_shared
,
u
,
u
,
x0
,
x0
,
...
@@ -569,8 +571,9 @@ class TestScan:
...
@@ -569,8 +571,9 @@ class TestScan:
n_steps
=
None
,
n_steps
=
None
,
truncate_gradient
=-
1
,
truncate_gradient
=-
1
,
go_backwards
=
False
,
go_backwards
=
False
,
return_updates
=
False
,
)
)
f3
=
function
([
u
,
x0
],
output
,
updates
=
updates
,
allow_input_downcast
=
True
)
f3
=
function
([
u
,
x0
],
output
,
allow_input_downcast
=
True
)
# get random initial values
# get random initial values
v_u
=
rng
.
uniform
(
-
5.0
,
5.0
,
size
=
(
4
,))
v_u
=
rng
.
uniform
(
-
5.0
,
5.0
,
size
=
(
4
,))
...
@@ -688,11 +691,14 @@ class TestScan:
...
@@ -688,11 +691,14 @@ class TestScan:
# this test refers to a bug reported by Nicolas
# this test refers to a bug reported by Nicolas
# Boulanger-Lewandowski June 6th
# Boulanger-Lewandowski June 6th
x
=
dvector
()
x
=
dvector
()
y
,
updates
=
scan
(
y
=
scan
(
lambda
x
:
[
x
],
sequences
=
dict
(
input
=
x
,
taps
=
[
-
1
]),
outputs_info
=
[
None
]
lambda
x
:
[
x
],
sequences
=
dict
(
input
=
x
,
taps
=
[
-
1
]),
outputs_info
=
[
None
],
return_updates
=
False
,
)
)
inp
=
np
.
arange
(
5
)
.
astype
(
"float64"
)
inp
=
np
.
arange
(
5
)
.
astype
(
"float64"
)
rval
=
function
([
x
],
y
,
updates
=
updates
)(
inp
)
rval
=
function
([
x
],
y
)(
inp
)
assert
np
.
all
(
rval
==
inp
[:
-
1
])
assert
np
.
all
(
rval
==
inp
[:
-
1
])
def
test_output_only
(
self
):
def
test_output_only
(
self
):
...
@@ -701,11 +707,18 @@ class TestScan:
...
@@ -701,11 +707,18 @@ class TestScan:
u
=
vector
(
"u"
)
u
=
vector
(
"u"
)
outputs
,
updates
=
scan
(
outputs
=
scan
(
f_rnn
,
u
,
[],
[],
n_steps
=
None
,
truncate_gradient
=-
1
,
go_backwards
=
False
f_rnn
,
u
,
[],
[],
n_steps
=
None
,
truncate_gradient
=-
1
,
go_backwards
=
False
,
return_updates
=
False
,
)
)
f2
=
function
([
u
],
outputs
,
updates
=
updates
,
allow_input_downcast
=
True
)
f2
=
function
([
u
],
outputs
,
allow_input_downcast
=
True
)
rng
=
np
.
random
.
default_rng
(
utt
.
fetch_seed
())
rng
=
np
.
random
.
default_rng
(
utt
.
fetch_seed
())
v_u
=
rng
.
uniform
(
-
5.0
,
5.0
,
size
=
(
5
,))
v_u
=
rng
.
uniform
(
-
5.0
,
5.0
,
size
=
(
5
,))
...
@@ -722,7 +735,7 @@ class TestScan:
...
@@ -722,7 +735,7 @@ class TestScan:
W_in
=
scalar
(
"win"
)
W_in
=
scalar
(
"win"
)
W
=
scalar
(
"w"
)
W
=
scalar
(
"w"
)
output
,
updates
=
scan
(
output
=
scan
(
f_rnn
,
f_rnn
,
u
,
u
,
x0
,
x0
,
...
@@ -730,11 +743,10 @@ class TestScan:
...
@@ -730,11 +743,10 @@ class TestScan:
n_steps
=
None
,
n_steps
=
None
,
truncate_gradient
=-
1
,
truncate_gradient
=-
1
,
go_backwards
=
True
,
go_backwards
=
True
,
return_updates
=
False
,
)
)
f2
=
function
(
f2
=
function
([
u
,
x0
,
W_in
,
W
],
output
,
allow_input_downcast
=
True
)
[
u
,
x0
,
W_in
,
W
],
output
,
updates
=
updates
,
allow_input_downcast
=
True
)
# get random initial values
# get random initial values
rng
=
np
.
random
.
default_rng
(
utt
.
fetch_seed
())
rng
=
np
.
random
.
default_rng
(
utt
.
fetch_seed
())
v_u
=
rng
.
uniform
(
-
5.0
,
5.0
,
size
=
(
4
,))
v_u
=
rng
.
uniform
(
-
5.0
,
5.0
,
size
=
(
4
,))
...
@@ -797,8 +809,8 @@ class TestScan:
...
@@ -797,8 +809,8 @@ class TestScan:
def
test_hash
(
self
):
def
test_hash
(
self
):
x
=
vector
()
x
=
vector
()
y
=
vector
()
y
=
vector
()
scan1
,
_updates
=
scan
(
lambda
_x
:
_x
+
1
,
x
)
scan1
=
scan
(
lambda
_x
:
_x
+
1
,
x
,
return_updates
=
False
)
scan2
,
_updates
=
scan
(
lambda
_x
:
_x
+
1
,
y
)
scan2
=
scan
(
lambda
_x
:
_x
+
1
,
y
,
return_updates
=
False
)
assert
scan1
.
owner
.
op
==
scan2
.
owner
.
op
assert
scan1
.
owner
.
op
==
scan2
.
owner
.
op
assert
hash
(
scan1
.
owner
.
op
)
==
hash
(
scan2
.
owner
.
op
)
assert
hash
(
scan1
.
owner
.
op
)
==
hash
(
scan2
.
owner
.
op
)
...
@@ -809,9 +821,24 @@ class TestScan:
...
@@ -809,9 +821,24 @@ class TestScan:
y
=
vector
(
"y"
)
y
=
vector
(
"y"
)
c
=
scalar
(
"c"
)
c
=
scalar
(
"c"
)
scan_a
,
_
=
scan
(
lambda
x
,
y
,
c
:
x
+
y
+
c
,
sequences
=
[
x
,
y
],
non_sequences
=
[
c
])
scan_a
=
scan
(
scan_b
,
_
=
scan
(
lambda
x
,
y
,
c
:
x
+
y
+
c
,
sequences
=
[
x
,
y
],
non_sequences
=
[
c
])
lambda
x
,
y
,
c
:
x
+
y
+
c
,
scan_c
,
_
=
scan
(
lambda
x
,
y
,
c
:
x
+
y
+
c
,
sequences
=
[
y
,
x
],
non_sequences
=
[
c
])
sequences
=
[
x
,
y
],
non_sequences
=
[
c
],
return_updates
=
False
,
)
scan_b
=
scan
(
lambda
x
,
y
,
c
:
x
+
y
+
c
,
sequences
=
[
x
,
y
],
non_sequences
=
[
c
],
return_updates
=
False
,
)
scan_c
=
scan
(
lambda
x
,
y
,
c
:
x
+
y
+
c
,
sequences
=
[
y
,
x
],
non_sequences
=
[
c
],
return_updates
=
False
,
)
assert
scan_b
is
not
scan_a
assert
scan_b
is
not
scan_a
assert
scan_c
is
not
scan_a
assert
scan_c
is
not
scan_a
...
@@ -1006,7 +1033,7 @@ class TestScan:
...
@@ -1006,7 +1033,7 @@ class TestScan:
def
lambda_fn
(
x_t
):
def
lambda_fn
(
x_t
):
return
x_t
+
1
,
until
(
x_t
>
3
)
return
x_t
+
1
,
until
(
x_t
>
3
)
o
,
_
=
scan
(
lambda_fn
,
x
)
o
=
scan
(
lambda_fn
,
x
,
return_updates
=
False
)
f
=
function
([
x
],
o
)
f
=
function
([
x
],
o
)
vx
=
np
.
zeros
((
50
,),
dtype
=
config
.
floatX
)
vx
=
np
.
zeros
((
50
,),
dtype
=
config
.
floatX
)
vx
[
23
]
=
4
vx
[
23
]
=
4
...
@@ -1019,7 +1046,7 @@ class TestScan:
...
@@ -1019,7 +1046,7 @@ class TestScan:
def
lambda_fn
(
x_t
):
def
lambda_fn
(
x_t
):
return
x_t
+
1
,
until
(
x_t
>
3
)
return
x_t
+
1
,
until
(
x_t
>
3
)
o
,
_
=
scan
(
lambda_fn
,
x
)
o
=
scan
(
lambda_fn
,
x
,
return_updates
=
False
)
f
=
function
([
x
],
o
.
shape
[
0
],
mode
=
mode_with_opt
)
f
=
function
([
x
],
o
.
shape
[
0
],
mode
=
mode_with_opt
)
vx
=
np
.
zeros
((
50
,),
dtype
=
config
.
floatX
)
vx
=
np
.
zeros
((
50
,),
dtype
=
config
.
floatX
)
...
@@ -1029,11 +1056,12 @@ class TestScan:
...
@@ -1029,11 +1056,12 @@ class TestScan:
def
test_infer_shape_nsteps_smaller_seq_length
(
self
):
def
test_infer_shape_nsteps_smaller_seq_length
(
self
):
x
=
vector
(
"x"
)
x
=
vector
(
"x"
)
[
o1
,
o2
]
,
_
=
scan
(
[
o1
,
o2
]
=
scan
(
lambda
x
,
y
:
(
x
+
1
,
y
+
x
),
lambda
x
,
y
:
(
x
+
1
,
y
+
x
),
sequences
=
x
,
sequences
=
x
,
outputs_info
=
[
None
,
x
[
0
]],
outputs_info
=
[
None
,
x
[
0
]],
n_steps
=
20
,
n_steps
=
20
,
return_updates
=
False
,
)
)
f
=
function
([
x
],
[
o1
.
shape
[
0
],
o2
.
shape
[
0
]],
mode
=
mode_with_opt
)
f
=
function
([
x
],
[
o1
.
shape
[
0
],
o2
.
shape
[
0
]],
mode
=
mode_with_opt
)
...
@@ -1071,17 +1099,18 @@ class TestScan:
...
@@ -1071,17 +1099,18 @@ class TestScan:
mode
=
MonitorMode
(
post_func
=
detect_large_outputs
)
mode
=
MonitorMode
(
post_func
=
detect_large_outputs
)
# Symbolic description of the result
# Symbolic description of the result
result
,
updates
=
scan
(
result
=
scan
(
fn
=
lambda
prior_result
,
A
:
prior_result
*
A
,
fn
=
lambda
prior_result
,
A
:
prior_result
*
A
,
outputs_info
=
pt
.
ones_like
(
A
),
outputs_info
=
pt
.
ones_like
(
A
),
non_sequences
=
A
,
non_sequences
=
A
,
n_steps
=
k
,
n_steps
=
k
,
mode
=
mode
,
mode
=
mode
,
return_updates
=
False
,
)
)
final_result
=
result
[
-
1
]
final_result
=
result
[
-
1
]
f
=
function
(
inputs
=
[
A
,
k
],
outputs
=
final_result
,
updates
=
updates
)
f
=
function
(
inputs
=
[
A
,
k
],
outputs
=
final_result
)
f
(
np
.
asarray
([
2
,
3
,
0.1
,
0
,
1
],
dtype
=
config
.
floatX
),
4
)
f
(
np
.
asarray
([
2
,
3
,
0.1
,
0
,
1
],
dtype
=
config
.
floatX
),
4
)
# There should be 3 outputs greater than 10: prior_result[0] at step 3,
# There should be 3 outputs greater than 10: prior_result[0] at step 3,
...
@@ -1103,10 +1132,11 @@ class TestScan:
...
@@ -1103,10 +1132,11 @@ class TestScan:
y
.
name
=
"y"
y
.
name
=
"y"
gy
=
grad
(
y
,
x
)
gy
=
grad
(
y
,
x
)
gy
.
name
=
"gy"
gy
.
name
=
"gy"
hy
,
_updates
=
scan
(
hy
=
scan
(
lambda
i
,
gy
,
x
:
grad
(
gy
[
i
]
*
fc2
,
x
),
lambda
i
,
gy
,
x
:
grad
(
gy
[
i
]
*
fc2
,
x
),
sequences
=
pt
.
arange
(
gy
.
shape
[
0
]),
sequences
=
pt
.
arange
(
gy
.
shape
[
0
]),
non_sequences
=
[
gy
,
x
],
non_sequences
=
[
gy
,
x
],
return_updates
=
False
,
)
)
f
=
function
([
x
,
A
],
hy
,
allow_input_downcast
=
True
)
f
=
function
([
x
,
A
],
hy
,
allow_input_downcast
=
True
)
...
@@ -1123,8 +1153,13 @@ class TestScan:
...
@@ -1123,8 +1153,13 @@ class TestScan:
def
test_sequence_is_scan
(
self
,
mode
):
def
test_sequence_is_scan
(
self
,
mode
):
"""Make sure that a `Scan` can be used as a sequence input to another `Scan`."""
"""Make sure that a `Scan` can be used as a sequence input to another `Scan`."""
x0
=
scalar
(
"x0"
)
x0
=
scalar
(
"x0"
)
scan_1
,
_
=
scan
(
lambda
x
:
x
+
1
,
outputs_info
=
{
"initial"
:
x0
},
n_steps
=
10
)
scan_1
=
scan
(
scan_2
,
_
=
scan
(
lambda
x
:
x
+
1
,
sequences
=
[
scan_1
])
lambda
x
:
x
+
1
,
outputs_info
=
{
"initial"
:
x0
},
n_steps
=
10
,
return_updates
=
False
,
)
scan_2
=
scan
(
lambda
x
:
x
+
1
,
sequences
=
[
scan_1
],
return_updates
=
False
)
with
config
.
change_flags
(
mode
=
mode
):
with
config
.
change_flags
(
mode
=
mode
):
scan_2_fn
=
function
([
x0
],
scan_2
)
scan_2_fn
=
function
([
x0
],
scan_2
)
...
@@ -1185,7 +1220,7 @@ class TestScan:
...
@@ -1185,7 +1220,7 @@ class TestScan:
def
test_blockwise_scan
(
self
):
def
test_blockwise_scan
(
self
):
x
=
pt
.
tensor
(
"x"
,
shape
=
())
x
=
pt
.
tensor
(
"x"
,
shape
=
())
out
,
_
=
scan
(
lambda
x
:
x
+
1
,
outputs_info
=
[
x
],
n_steps
=
10
)
out
=
scan
(
lambda
x
:
x
+
1
,
outputs_info
=
[
x
],
n_steps
=
10
,
return_updates
=
False
)
x_vec
=
pt
.
tensor
(
"x_vec"
,
shape
=
(
None
,))
x_vec
=
pt
.
tensor
(
"x_vec"
,
shape
=
(
None
,))
out_vec
=
vectorize_graph
(
out
,
{
x
:
x_vec
})
out_vec
=
vectorize_graph
(
out
,
{
x
:
x_vec
})
...
@@ -1203,13 +1238,14 @@ class TestScan:
...
@@ -1203,13 +1238,14 @@ class TestScan:
a0
=
shared
(
np
.
arange
(
2
))
a0
=
shared
(
np
.
arange
(
2
))
b0
=
shared
(
np
.
arange
(
2
))
b0
=
shared
(
np
.
arange
(
2
))
(
a
,
_b
)
,
_
=
scan
(
(
a
,
_b
)
=
scan
(
fn
,
fn
,
outputs_info
=
[
outputs_info
=
[
{
"initial"
:
a0
,
"taps"
:
[
-
2
,
-
1
]},
{
"initial"
:
a0
,
"taps"
:
[
-
2
,
-
1
]},
{
"initial"
:
b0
,
"taps"
:
[
-
2
,
-
1
]},
{
"initial"
:
b0
,
"taps"
:
[
-
2
,
-
1
]},
],
],
n_steps
=
2
,
n_steps
=
2
,
return_updates
=
False
,
)
)
grad
(
a
[
-
1
],
a0
)
grad
(
a
[
-
1
],
a0
)
...
@@ -1241,8 +1277,11 @@ class TestScan:
...
@@ -1241,8 +1277,11 @@ class TestScan:
state_next
=
state_old
*
2
+
state_current
+
seq
state_next
=
state_old
*
2
+
state_current
+
seq
return
state_next
return
state_next
out
,
_
=
scan
(
out
=
scan
(
inner_fct
,
sequences
=
seq
,
outputs_info
=
{
"initial"
:
x
,
"taps"
:
[
-
2
,
-
1
]}
inner_fct
,
sequences
=
seq
,
outputs_info
=
{
"initial"
:
x
,
"taps"
:
[
-
2
,
-
1
]},
return_updates
=
False
,
)
)
g_out
=
grad
(
out
.
sum
(),
[
seq
,
x
])
g_out
=
grad
(
out
.
sum
(),
[
seq
,
x
])
...
@@ -1302,12 +1341,13 @@ class TestScan:
...
@@ -1302,12 +1341,13 @@ class TestScan:
new_y
=
pt
.
switch
(
cond
,
y
,
sigmoid
(
x
))
new_y
=
pt
.
switch
(
cond
,
y
,
sigmoid
(
x
))
return
new_cond
,
new_x
,
new_y
return
new_cond
,
new_x
,
new_y
values
,
_
=
scan
(
values
=
scan
(
inner_fn
,
inner_fn
,
outputs_info
=
[
c
,
x
,
y
],
outputs_info
=
[
c
,
x
,
y
],
n_steps
=
10
,
n_steps
=
10
,
truncate_gradient
=-
1
,
truncate_gradient
=-
1
,
go_backwards
=
False
,
go_backwards
=
False
,
return_updates
=
False
,
)
)
gX
,
gY
=
grad
(
values
[
1
]
.
sum
(),
[
x
,
y
])
gX
,
gY
=
grad
(
values
[
1
]
.
sum
(),
[
x
,
y
])
f
=
function
([
c
,
x
,
y
],
[
gX
,
gY
],
allow_input_downcast
=
True
)
f
=
function
([
c
,
x
,
y
],
[
gX
,
gY
],
allow_input_downcast
=
True
)
...
@@ -1762,11 +1802,12 @@ class TestScan:
...
@@ -1762,11 +1802,12 @@ class TestScan:
outputs_info
=
[
None
,
dict
(
initial
=
out_init
,
taps
=
[
-
3
])]
outputs_info
=
[
None
,
dict
(
initial
=
out_init
,
taps
=
[
-
3
])]
scan_outputs
,
_
=
scan
(
scan_outputs
=
scan
(
fn
=
inner_fct
,
fn
=
inner_fct
,
sequences
=
seq
,
sequences
=
seq
,
outputs_info
=
outputs_info
,
outputs_info
=
outputs_info
,
non_sequences
=
non_seq
,
non_sequences
=
non_seq
,
return_updates
=
False
,
)
)
# Attempt to take various gradients
# Attempt to take various gradients
...
@@ -1834,7 +1875,9 @@ class TestScan:
...
@@ -1834,7 +1875,9 @@ class TestScan:
dict
(
initial
=
out_init
[
3
],
taps
=
[
-
2
,
-
1
]),
dict
(
initial
=
out_init
[
3
],
taps
=
[
-
2
,
-
1
]),
]
]
scan_outputs
,
_
=
scan
(
fn
=
inner_fct
,
outputs_info
=
outputs_info
,
n_steps
=
10
)
scan_outputs
=
scan
(
fn
=
inner_fct
,
outputs_info
=
outputs_info
,
n_steps
=
10
,
return_updates
=
False
)
grad
(
scan_outputs
[
0
]
.
sum
(),
out_init
[
1
])
grad
(
scan_outputs
[
0
]
.
sum
(),
out_init
[
1
])
...
@@ -1857,11 +1900,12 @@ class TestScan:
...
@@ -1857,11 +1900,12 @@ class TestScan:
x
=
scalar
(
"x"
)
x
=
scalar
(
"x"
)
_max_coefficients_supported
=
1000
_max_coefficients_supported
=
1000
full_range
=
pt
.
arange
(
_max_coefficients_supported
)
full_range
=
pt
.
arange
(
_max_coefficients_supported
)
components
,
_updates
=
scan
(
components
=
scan
(
fn
=
lambda
coeff
,
power
,
free_var
:
coeff
*
(
free_var
**
power
),
fn
=
lambda
coeff
,
power
,
free_var
:
coeff
*
(
free_var
**
power
),
outputs_info
=
None
,
outputs_info
=
None
,
sequences
=
[
c
,
full_range
],
sequences
=
[
c
,
full_range
],
non_sequences
=
x
,
non_sequences
=
x
,
return_updates
=
False
,
)
)
P
=
components
.
sum
()
P
=
components
.
sum
()
dP
=
grad
(
P
,
x
)
dP
=
grad
(
P
,
x
)
...
@@ -1877,11 +1921,12 @@ class TestScan:
...
@@ -1877,11 +1921,12 @@ class TestScan:
x
=
scalar
(
"x"
)
x
=
scalar
(
"x"
)
_max_coefficients_supported
=
1000
_max_coefficients_supported
=
1000
full_range
=
pt
.
arange
(
_max_coefficients_supported
)
full_range
=
pt
.
arange
(
_max_coefficients_supported
)
components
,
_updates
=
scan
(
components
=
scan
(
fn
=
lambda
coeff
,
power
,
free_var
:
coeff
*
(
free_var
**
power
),
fn
=
lambda
coeff
,
power
,
free_var
:
coeff
*
(
free_var
**
power
),
outputs_info
=
None
,
outputs_info
=
None
,
sequences
=
[
c
,
full_range
],
sequences
=
[
c
,
full_range
],
non_sequences
=
x
,
non_sequences
=
x
,
return_updates
=
False
,
)
)
P
=
components
.
sum
()
P
=
components
.
sum
()
dP
=
grad
(
P
,
x
)
.
sum
()
dP
=
grad
(
P
,
x
)
.
sum
()
...
@@ -1968,8 +2013,13 @@ class TestScan:
...
@@ -1968,8 +2013,13 @@ class TestScan:
_W
=
specify_shape
(
W
,
v_W
.
shape
)
_W
=
specify_shape
(
W
,
v_W
.
shape
)
_W
.
name
=
"_W"
_W
.
name
=
"_W"
o
,
_
=
scan
(
o
=
scan
(
rnn_fn
,
sequences
=
_u
,
outputs_info
=
_h0
,
non_sequences
=
_W
,
name
=
"rnn_fn"
rnn_fn
,
sequences
=
_u
,
outputs_info
=
_h0
,
non_sequences
=
_W
,
name
=
"rnn_fn"
,
return_updates
=
False
,
)
)
o
=
o
[
-
1
]
o
=
o
[
-
1
]
eu
=
matrix
(
"eu"
)
eu
=
matrix
(
"eu"
)
...
@@ -1983,25 +2033,28 @@ class TestScan:
...
@@ -1983,25 +2033,28 @@ class TestScan:
[
u
,
h0
,
W
,
eu
,
eh0
,
eW
],
[
nwo_u
,
nwo_h0
,
nwo_W
],
on_unused_input
=
"ignore"
[
u
,
h0
,
W
,
eu
,
eh0
,
eW
],
[
nwo_u
,
nwo_h0
,
nwo_W
],
on_unused_input
=
"ignore"
)
)
n2o_u
,
_
=
scan
(
n2o_u
=
scan
(
lambda
i
,
o
,
u
,
h0
,
W
,
eu
:
(
grad
(
o
[
i
],
u
)
*
eu
)
.
sum
(),
lambda
i
,
o
,
u
,
h0
,
W
,
eu
:
(
grad
(
o
[
i
],
u
)
*
eu
)
.
sum
(),
sequences
=
pt
.
arange
(
o
.
shape
[
0
]),
sequences
=
pt
.
arange
(
o
.
shape
[
0
]),
non_sequences
=
[
o
,
u
,
h0
,
W
,
eu
],
non_sequences
=
[
o
,
u
,
h0
,
W
,
eu
],
name
=
"jacobU"
,
name
=
"jacobU"
,
return_updates
=
False
,
)
)
n2o_h0
,
_
=
scan
(
n2o_h0
=
scan
(
lambda
i
,
o
,
u
,
h0
,
W
,
eh0
:
(
grad
(
o
[
i
],
h0
)
*
eh0
)
.
sum
(),
lambda
i
,
o
,
u
,
h0
,
W
,
eh0
:
(
grad
(
o
[
i
],
h0
)
*
eh0
)
.
sum
(),
sequences
=
pt
.
arange
(
o
.
shape
[
0
]),
sequences
=
pt
.
arange
(
o
.
shape
[
0
]),
non_sequences
=
[
o
,
u
,
h0
,
W
,
eh0
],
non_sequences
=
[
o
,
u
,
h0
,
W
,
eh0
],
name
=
"jacobh"
,
name
=
"jacobh"
,
return_updates
=
False
,
)
)
n2o_W
,
_
=
scan
(
n2o_W
=
scan
(
lambda
i
,
o
,
u
,
h0
,
W
,
eW
:
(
grad
(
o
[
i
],
W
)
*
eW
)
.
sum
(),
lambda
i
,
o
,
u
,
h0
,
W
,
eW
:
(
grad
(
o
[
i
],
W
)
*
eW
)
.
sum
(),
sequences
=
pt
.
arange
(
o
.
shape
[
0
]),
sequences
=
pt
.
arange
(
o
.
shape
[
0
]),
non_sequences
=
[
o
,
u
,
h0
,
W
,
eW
],
non_sequences
=
[
o
,
u
,
h0
,
W
,
eW
],
name
=
"jacobW"
,
name
=
"jacobW"
,
return_updates
=
False
,
)
)
fn_test
=
function
(
fn_test
=
function
(
...
@@ -2132,10 +2185,11 @@ class TestScan:
...
@@ -2132,10 +2185,11 @@ class TestScan:
transfer
=
sigmoid
transfer
=
sigmoid
hidden_rec
,
_
=
scan
(
hidden_rec
=
scan
(
lambda
x
,
h_tm1
:
transfer
(
dot
(
h_tm1
,
W2
)
+
x
),
lambda
x
,
h_tm1
:
transfer
(
dot
(
h_tm1
,
W2
)
+
x
),
sequences
=
hidden
,
sequences
=
hidden
,
outputs_info
=
[
pt
.
zeros_like
(
hidden
[
0
])],
outputs_info
=
[
pt
.
zeros_like
(
hidden
[
0
])],
return_updates
=
False
,
)
)
hidden_rec
.
reshape
(
hidden_rec
.
reshape
(
...
@@ -2168,12 +2222,13 @@ class TestScan:
...
@@ -2168,12 +2222,13 @@ class TestScan:
def
step
(
s
,
xtm2
,
xtm1
,
z
):
def
step
(
s
,
xtm2
,
xtm1
,
z
):
return
s
*
((
xtm2
*
0
+
xtm1
)
**
2
)
*
(
z
/
2
)
return
s
*
((
xtm2
*
0
+
xtm1
)
**
2
)
*
(
z
/
2
)
xs
,
_
=
scan
(
xs
=
scan
(
step
,
step
,
sequences
=
[
seq
],
sequences
=
[
seq
],
outputs_info
=
[{
"initial"
:
x0
,
"taps"
:
(
-
2
,
-
1
)}],
outputs_info
=
[{
"initial"
:
x0
,
"taps"
:
(
-
2
,
-
1
)}],
non_sequences
=
[
z
],
non_sequences
=
[
z
],
n_steps
=
2
,
n_steps
=
2
,
return_updates
=
False
,
)
)
last_x
=
xs
[
-
1
]
last_x
=
xs
[
-
1
]
...
@@ -2254,11 +2309,12 @@ class TestScan:
...
@@ -2254,11 +2309,12 @@ class TestScan:
raise
ValueError
(
f
"Invalid case: {case}"
)
raise
ValueError
(
f
"Invalid case: {case}"
)
seq
=
vector
(
"seq"
)
seq
=
vector
(
"seq"
)
xs
,
_
=
scan
(
xs
=
scan
(
step
,
step
,
sequences
=
[
seq
],
sequences
=
[
seq
],
non_sequences
=
non_sequences
,
non_sequences
=
non_sequences
,
strict
=
strict
,
strict
=
strict
,
return_updates
=
False
,
)
)
x0
=
xs
[
0
]
x0
=
xs
[
0
]
...
@@ -2298,7 +2354,7 @@ def test_cvm_exception_handling(mode):
...
@@ -2298,7 +2354,7 @@ def test_cvm_exception_handling(mode):
def
scan_fn
():
def
scan_fn
():
return
myop
(
pt
.
as_tensor
(
1
))
return
myop
(
pt
.
as_tensor
(
1
))
res
,
_
=
scan
(
scan_fn
,
n_steps
=
4
,
mode
=
mod
e
)
res
=
scan
(
scan_fn
,
n_steps
=
4
,
mode
=
mode
,
return_updates
=
Fals
e
)
res_fn
=
function
([],
res
,
mode
=
mode
)
res_fn
=
function
([],
res
,
mode
=
mode
)
...
@@ -2328,14 +2384,14 @@ def test_cython_performance(benchmark):
...
@@ -2328,14 +2384,14 @@ def test_cython_performance(benchmark):
py_res
=
f_py
()
py_res
=
f_py
()
s_r
=
pt
.
as_tensor_variable
(
r
,
dtype
=
config
.
floatX
)
s_r
=
pt
.
as_tensor_variable
(
r
,
dtype
=
config
.
floatX
)
s_y
,
updates
=
scan
(
s_y
=
scan
(
fn
=
lambda
ri
,
rii
,
M
:
ri
+
M
*
rii
,
fn
=
lambda
ri
,
rii
,
M
:
ri
+
M
*
rii
,
sequences
=
[
s_r
[
1
:]],
sequences
=
[
s_r
[
1
:]],
non_sequences
=
[
pt
.
as_tensor_variable
(
M
,
dtype
=
config
.
floatX
)],
non_sequences
=
[
pt
.
as_tensor_variable
(
M
,
dtype
=
config
.
floatX
)],
outputs_info
=
s_r
[
0
],
outputs_info
=
s_r
[
0
],
mode
=
Mode
(
linker
=
"cvm"
,
optimizer
=
"fast_run"
),
mode
=
Mode
(
linker
=
"cvm"
,
optimizer
=
"fast_run"
),
return_updates
=
False
,
)
)
assert
not
updates
f_cvm
=
function
([],
s_y
,
mode
=
"FAST_RUN"
)
f_cvm
=
function
([],
s_y
,
mode
=
"FAST_RUN"
)
f_cvm
.
trust_input
=
True
f_cvm
.
trust_input
=
True
...
@@ -2357,9 +2413,7 @@ def test_compute_test_values():
...
@@ -2357,9 +2413,7 @@ def test_compute_test_values():
y
=
shared
(
np
.
arange
(
3
,
dtype
=
config
.
floatX
),
name
=
"y"
)
y
=
shared
(
np
.
arange
(
3
,
dtype
=
config
.
floatX
),
name
=
"y"
)
z
,
updates
=
scan
(
fn
=
lambda
u
,
v
:
u
+
v
,
sequences
=
[
x
,
y
])
z
=
scan
(
fn
=
lambda
u
,
v
:
u
+
v
,
sequences
=
[
x
,
y
],
return_updates
=
False
)
assert
not
updates
z_grad
=
grad
(
z
.
sum
(),
x
)
z_grad
=
grad
(
z
.
sum
(),
x
)
...
@@ -2368,9 +2422,9 @@ def test_compute_test_values():
...
@@ -2368,9 +2422,9 @@ def test_compute_test_values():
# Use `non_sequences` this time
# Use `non_sequences` this time
y
=
shared
(
np
.
arange
(
9
,
dtype
=
config
.
floatX
)
.
reshape
(
3
,
3
),
name
=
"y"
)
y
=
shared
(
np
.
arange
(
9
,
dtype
=
config
.
floatX
)
.
reshape
(
3
,
3
),
name
=
"y"
)
z
,
updates
=
scan
(
fn
=
lambda
u
,
v
:
u
+
v
,
sequences
=
[
x
],
non_sequences
=
[
y
])
z
=
scan
(
fn
=
lambda
u
,
v
:
u
+
v
,
sequences
=
[
x
],
non_sequences
=
[
y
],
return_updates
=
False
assert
not
updates
)
z_grad
=
grad
(
z
.
sum
(),
x
)
z_grad
=
grad
(
z
.
sum
(),
x
)
...
@@ -2399,20 +2453,22 @@ def test_compute_test_value_grad():
...
@@ -2399,20 +2453,22 @@ def test_compute_test_value_grad():
def
loss_ti
(
ti
,
sum_ti
,
mi
,
W
):
def
loss_ti
(
ti
,
sum_ti
,
mi
,
W
):
return
W
.
sum
()
.
sum
()
.
sum
()
+
sum_ti
return
W
.
sum
()
.
sum
()
.
sum
()
+
sum_ti
result_ti
,
_
=
scan
(
result_ti
=
scan
(
fn
=
loss_ti
,
fn
=
loss_ti
,
outputs_info
=
outputs_ti
,
outputs_info
=
outputs_ti
,
sequences
=
pt
.
arange
(
W
.
shape
[
1
],
dtype
=
"int32"
),
sequences
=
pt
.
arange
(
W
.
shape
[
1
],
dtype
=
"int32"
),
non_sequences
=
[
mi
,
W
],
non_sequences
=
[
mi
,
W
],
return_updates
=
False
,
)
)
lossmi
=
result_ti
[
-
1
]
lossmi
=
result_ti
[
-
1
]
return
sum_mi
+
lossmi
return
sum_mi
+
lossmi
result_mi
,
_
=
scan
(
result_mi
=
scan
(
fn
=
loss_mi
,
fn
=
loss_mi
,
outputs_info
=
outputs_mi
,
outputs_info
=
outputs_mi
,
sequences
=
pt
.
arange
(
W
.
shape
[
0
],
dtype
=
"int32"
),
sequences
=
pt
.
arange
(
W
.
shape
[
0
],
dtype
=
"int32"
),
non_sequences
=
[
W
],
non_sequences
=
[
W
],
return_updates
=
False
,
)
)
loss
=
result_mi
[
-
1
]
loss
=
result_mi
[
-
1
]
...
@@ -2436,11 +2492,12 @@ def test_compute_test_value_grad_cast():
...
@@ -2436,11 +2492,12 @@ def test_compute_test_value_grad_cast():
name
=
"w"
,
name
=
"w"
,
)
)
outputs
,
_
=
scan
(
outputs
=
scan
(
lambda
i
,
h
,
w
:
(
dot
(
h
[
i
],
w
),
i
),
lambda
i
,
h
,
w
:
(
dot
(
h
[
i
],
w
),
i
),
outputs_info
=
[
None
,
0
],
outputs_info
=
[
None
,
0
],
non_sequences
=
[
h
,
w
],
non_sequences
=
[
h
,
w
],
n_steps
=
3
,
n_steps
=
3
,
return_updates
=
False
,
)
)
grad
(
outputs
[
0
]
.
sum
(),
w
)
grad
(
outputs
[
0
]
.
sum
(),
w
)
...
@@ -2449,11 +2506,12 @@ def test_compute_test_value_grad_cast():
...
@@ -2449,11 +2506,12 @@ def test_compute_test_value_grad_cast():
def
test_constant_folding_n_steps
():
def
test_constant_folding_n_steps
():
# The following code used to crash at revision 2060b8f, in the constant
# The following code used to crash at revision 2060b8f, in the constant
# folding optimization step.
# folding optimization step.
res
,
_
=
scan
(
res
=
scan
(
lambda
x
:
x
*
2
,
lambda
x
:
x
*
2
,
outputs_info
=
pt
.
ones
(()),
outputs_info
=
pt
.
ones
(()),
# The constant `n_steps` was causing the crash.
# The constant `n_steps` was causing the crash.
n_steps
=
10
,
n_steps
=
10
,
return_updates
=
False
,
)
)
with
config
.
change_flags
(
on_opt_error
=
"raise"
):
with
config
.
change_flags
(
on_opt_error
=
"raise"
):
function
([],
res
)()
function
([],
res
)()
...
@@ -2478,10 +2536,11 @@ def test_outputs_taps_check():
...
@@ -2478,10 +2536,11 @@ def test_outputs_taps_check():
def
test_inconsistent_broadcast_error
():
def
test_inconsistent_broadcast_error
():
x
=
tensor3
()
x
=
tensor3
()
initial_x
=
pt
.
constant
(
np
.
zeros
((
1
,
10
)))
initial_x
=
pt
.
constant
(
np
.
zeros
((
1
,
10
)))
y
,
_updates
=
scan
(
y
=
scan
(
fn
=
lambda
x
,
prev_x
:
x
+
prev_x
,
fn
=
lambda
x
,
prev_x
:
x
+
prev_x
,
sequences
=
x
,
sequences
=
x
,
outputs_info
=
[
dict
(
initial
=
initial_x
)],
outputs_info
=
[
dict
(
initial
=
initial_x
)],
return_updates
=
False
,
)
)
# Error, because the broadcast patterns are inconsistent.
# Error, because the broadcast patterns are inconsistent.
with
pytest
.
raises
(
TypeError
):
with
pytest
.
raises
(
TypeError
):
...
@@ -2509,10 +2568,11 @@ class TestGradUntil:
...
@@ -2509,10 +2568,11 @@ class TestGradUntil:
self
.
numpy_gradient
=
2
*
np
.
concatenate
([
self
.
seq
[:
7
],
z
],
axis
=
0
)
self
.
numpy_gradient
=
2
*
np
.
concatenate
([
self
.
seq
[:
7
],
z
],
axis
=
0
)
def
test_grad_until
(
self
):
def
test_grad_until
(
self
):
r
,
_
=
scan
(
r
=
scan
(
lambda
x
,
u
:
(
x
*
x
,
until
(
x
>
u
)),
lambda
x
,
u
:
(
x
*
x
,
until
(
x
>
u
)),
sequences
=
self
.
x
,
sequences
=
self
.
x
,
non_sequences
=
[
self
.
threshold
],
non_sequences
=
[
self
.
threshold
],
return_updates
=
False
,
)
)
g
=
grad
(
r
.
sum
(),
self
.
x
)
g
=
grad
(
r
.
sum
(),
self
.
x
)
f
=
function
([
self
.
x
,
self
.
threshold
],
[
r
,
g
])
f
=
function
([
self
.
x
,
self
.
threshold
],
[
r
,
g
])
...
@@ -2528,10 +2588,11 @@ class TestGradUntil:
...
@@ -2528,10 +2588,11 @@ class TestGradUntil:
X
=
matrix
(
name
=
"x"
)
X
=
matrix
(
name
=
"x"
)
arr
=
tile_array
(
self
.
seq
)
arr
=
tile_array
(
self
.
seq
)
r
,
_
=
scan
(
r
=
scan
(
lambda
x
,
u
:
(
x
*
x
,
until
(
pt_all
(
x
>
u
))),
lambda
x
,
u
:
(
x
*
x
,
until
(
pt_all
(
x
>
u
))),
sequences
=
X
,
sequences
=
X
,
non_sequences
=
[
self
.
threshold
],
non_sequences
=
[
self
.
threshold
],
return_updates
=
False
,
)
)
g
=
grad
(
r
.
sum
(),
X
)
g
=
grad
(
r
.
sum
(),
X
)
f
=
function
([
X
,
self
.
threshold
],
[
r
,
g
])
f
=
function
([
X
,
self
.
threshold
],
[
r
,
g
])
...
@@ -2542,11 +2603,12 @@ class TestGradUntil:
...
@@ -2542,11 +2603,12 @@ class TestGradUntil:
def
test_grad_until_and_truncate
(
self
):
def
test_grad_until_and_truncate
(
self
):
n
=
3
n
=
3
r
,
_
=
scan
(
r
=
scan
(
lambda
x
,
u
:
(
x
*
x
,
until
(
x
>
u
)),
lambda
x
,
u
:
(
x
*
x
,
until
(
x
>
u
)),
sequences
=
self
.
x
,
sequences
=
self
.
x
,
non_sequences
=
[
self
.
threshold
],
non_sequences
=
[
self
.
threshold
],
truncate_gradient
=
n
,
truncate_gradient
=
n
,
return_updates
=
False
,
)
)
g
=
grad
(
r
.
sum
(),
self
.
x
)
g
=
grad
(
r
.
sum
(),
self
.
x
)
f
=
function
([
self
.
x
,
self
.
threshold
],
[
r
,
g
])
f
=
function
([
self
.
x
,
self
.
threshold
],
[
r
,
g
])
...
@@ -2558,11 +2620,12 @@ class TestGradUntil:
...
@@ -2558,11 +2620,12 @@ class TestGradUntil:
def
test_grad_until_and_truncate_sequence_taps
(
self
):
def
test_grad_until_and_truncate_sequence_taps
(
self
):
n
=
3
n
=
3
r
,
_
=
scan
(
r
=
scan
(
lambda
x
,
y
,
u
:
(
x
*
y
,
until
(
y
>
u
)),
lambda
x
,
y
,
u
:
(
x
*
y
,
until
(
y
>
u
)),
sequences
=
dict
(
input
=
self
.
x
,
taps
=
[
-
2
,
0
]),
sequences
=
dict
(
input
=
self
.
x
,
taps
=
[
-
2
,
0
]),
non_sequences
=
[
self
.
threshold
],
non_sequences
=
[
self
.
threshold
],
truncate_gradient
=
n
,
truncate_gradient
=
n
,
return_updates
=
False
,
)
)
g
=
grad
(
r
.
sum
(),
self
.
x
)
g
=
grad
(
r
.
sum
(),
self
.
x
)
f
=
function
([
self
.
x
,
self
.
threshold
],
[
r
,
g
])
f
=
function
([
self
.
x
,
self
.
threshold
],
[
r
,
g
])
...
@@ -2581,8 +2644,12 @@ def test_mintap_onestep():
...
@@ -2581,8 +2644,12 @@ def test_mintap_onestep():
new_sum
=
prev_sum
+
seq_t
new_sum
=
prev_sum
+
seq_t
return
new_sum
return
new_sum
rs
,
_updates
=
scan
(
rs
=
scan
(
fn
=
accum
,
sequences
=
{
"input"
:
seq
,
"taps"
:
[
2
]},
outputs_info
=
0
,
n_steps
=
1
fn
=
accum
,
sequences
=
{
"input"
:
seq
,
"taps"
:
[
2
]},
outputs_info
=
0
,
n_steps
=
1
,
return_updates
=
False
,
)
)
f
=
function
(
inputs
=
[
seq
],
outputs
=
rs
)
f
=
function
(
inputs
=
[
seq
],
outputs
=
rs
)
...
@@ -2667,7 +2734,12 @@ def test_inner_get_vector_length():
...
@@ -2667,7 +2734,12 @@ def test_inner_get_vector_length():
def
test_profile_info
():
def
test_profile_info
():
from
pytensor.scan.utils
import
ScanProfileStats
from
pytensor.scan.utils
import
ScanProfileStats
z
,
_updates
=
scan
(
fn
=
lambda
u
:
u
+
1
,
sequences
=
[
pt
.
arange
(
10
)],
profile
=
True
)
z
=
scan
(
fn
=
lambda
u
:
u
+
1
,
sequences
=
[
pt
.
arange
(
10
)],
profile
=
True
,
return_updates
=
False
,
)
assert
isinstance
(
z
.
owner
.
op
,
Scan
)
assert
isinstance
(
z
.
owner
.
op
,
Scan
)
fn
=
z
.
owner
.
op
.
fn
fn
=
z
.
owner
.
op
.
fn
...
@@ -2676,8 +2748,11 @@ def test_profile_info():
...
@@ -2676,8 +2748,11 @@ def test_profile_info():
assert
fn
.
profile
.
name
==
"scan_fn"
assert
fn
.
profile
.
name
==
"scan_fn"
# Set the `ScanProfileStats` name
# Set the `ScanProfileStats` name
z
,
_updates
=
scan
(
z
=
scan
(
fn
=
lambda
u
:
u
+
1
,
sequences
=
[
pt
.
arange
(
10
)],
profile
=
"profile_name"
fn
=
lambda
u
:
u
+
1
,
sequences
=
[
pt
.
arange
(
10
)],
profile
=
"profile_name"
,
return_updates
=
False
,
)
)
assert
isinstance
(
z
.
owner
.
op
,
Scan
)
assert
isinstance
(
z
.
owner
.
op
,
Scan
)
...
@@ -2688,7 +2763,12 @@ def test_profile_info():
...
@@ -2688,7 +2763,12 @@ def test_profile_info():
# Use an existing profile object
# Use an existing profile object
profile
=
fn
.
profile
profile
=
fn
.
profile
z
,
_updates
=
scan
(
fn
=
lambda
u
:
u
+
1
,
sequences
=
[
pt
.
arange
(
10
)],
profile
=
profile
)
z
=
scan
(
fn
=
lambda
u
:
u
+
1
,
sequences
=
[
pt
.
arange
(
10
)],
profile
=
profile
,
return_updates
=
False
,
)
assert
isinstance
(
z
.
owner
.
op
,
Scan
)
assert
isinstance
(
z
.
owner
.
op
,
Scan
)
fn
=
z
.
owner
.
op
.
fn
fn
=
z
.
owner
.
op
.
fn
...
@@ -2819,7 +2899,7 @@ class TestExamples:
...
@@ -2819,7 +2899,7 @@ class TestExamples:
y_tm1
+
dot
(
x_tm1
,
W_out
),
y_tm1
+
dot
(
x_tm1
,
W_out
),
]
]
outputs
,
updates
=
scan
(
outputs
=
scan
(
f_rnn_cmpl
,
f_rnn_cmpl
,
[
u1
,
u2
],
[
u1
,
u2
],
[
None
,
None
,
x0
,
dict
(
initial
=
y0
,
taps
=
[
-
1
,
-
3
])],
[
None
,
None
,
x0
,
dict
(
initial
=
y0
,
taps
=
[
-
1
,
-
3
])],
...
@@ -2827,11 +2907,10 @@ class TestExamples:
...
@@ -2827,11 +2907,10 @@ class TestExamples:
n_steps
=
None
,
n_steps
=
None
,
truncate_gradient
=-
1
,
truncate_gradient
=-
1
,
go_backwards
=
False
,
go_backwards
=
False
,
return_updates
=
False
,
)
)
f4
=
function
(
f4
=
function
([
u1
,
u2
,
x0
,
y0
,
W_in1
],
outputs
,
allow_input_downcast
=
True
)
[
u1
,
u2
,
x0
,
y0
,
W_in1
],
outputs
,
updates
=
updates
,
allow_input_downcast
=
True
)
# compute the values in numpy
# compute the values in numpy
v_x
=
np
.
zeros
((
3
,
2
),
dtype
=
config
.
floatX
)
v_x
=
np
.
zeros
((
3
,
2
),
dtype
=
config
.
floatX
)
...
@@ -2857,8 +2936,12 @@ class TestExamples:
...
@@ -2857,8 +2936,12 @@ class TestExamples:
def
scanStep
(
prev
,
seq
,
f1
):
def
scanStep
(
prev
,
seq
,
f1
):
return
prev
+
f1
*
seq
return
prev
+
f1
*
seq
scanned
,
_
=
scan
(
scanned
=
scan
(
fn
=
scanStep
,
sequences
=
[
seq
],
outputs_info
=
[
to_scan
],
non_sequences
=
[
f1
]
fn
=
scanStep
,
sequences
=
[
seq
],
outputs_info
=
[
to_scan
],
non_sequences
=
[
f1
],
return_updates
=
False
,
)
)
function
(
inputs
=
[
to_scan
,
seq
,
f1
],
outputs
=
scanned
,
allow_input_downcast
=
True
)
function
(
inputs
=
[
to_scan
,
seq
,
f1
],
outputs
=
scanned
,
allow_input_downcast
=
True
)
...
@@ -2879,8 +2962,12 @@ class TestExamples:
...
@@ -2879,8 +2962,12 @@ class TestExamples:
expr
=
dot
(
h_tm1
,
W
)
+
x_t
expr
=
dot
(
h_tm1
,
W
)
+
x_t
return
expr
return
expr
expr
,
_
=
scan
(
expr
=
scan
(
fn
=
one_step
,
sequences
=
[
inpt
],
outputs_info
=
[
initial
],
non_sequences
=
[
W
]
fn
=
one_step
,
sequences
=
[
inpt
],
outputs_info
=
[
initial
],
non_sequences
=
[
W
],
return_updates
=
False
,
)
)
v1
=
shared
(
np
.
ones
(
5
,
dtype
=
config
.
floatX
))
v1
=
shared
(
np
.
ones
(
5
,
dtype
=
config
.
floatX
))
...
@@ -2917,11 +3004,12 @@ class TestExamples:
...
@@ -2917,11 +3004,12 @@ class TestExamples:
x
=
scalar
()
x
=
scalar
()
seq
=
vector
()
seq
=
vector
()
outputs_info
=
[
x
,
pt
.
zeros_like
(
x
)]
outputs_info
=
[
x
,
pt
.
zeros_like
(
x
)]
(
out1
,
out2
)
,
_updates
=
scan
(
(
out1
,
out2
)
=
scan
(
lambda
a
,
b
,
c
:
(
a
+
b
,
b
+
c
),
lambda
a
,
b
,
c
:
(
a
+
b
,
b
+
c
),
sequences
=
seq
,
sequences
=
seq
,
outputs_info
=
outputs_info
,
outputs_info
=
outputs_info
,
mode
=
mode
,
mode
=
mode
,
return_updates
=
False
,
)
)
# Obtain a reference to the scan outputs before the subtensor and
# Obtain a reference to the scan outputs before the subtensor and
...
@@ -2956,8 +3044,11 @@ class TestExamples:
...
@@ -2956,8 +3044,11 @@ class TestExamples:
x
=
dcol
()
x
=
dcol
()
seq
=
dcol
()
seq
=
dcol
()
outputs_info
=
[
x
,
pt
.
zeros_like
(
x
)]
outputs_info
=
[
x
,
pt
.
zeros_like
(
x
)]
(
out1
,
out2
),
_updates
=
scan
(
(
out1
,
out2
)
=
scan
(
lambda
a
,
b
,
c
:
(
a
+
b
,
a
+
c
),
sequences
=
seq
,
outputs_info
=
outputs_info
lambda
a
,
b
,
c
:
(
a
+
b
,
a
+
c
),
sequences
=
seq
,
outputs_info
=
outputs_info
,
return_updates
=
False
,
)
)
# Obtain a reference to the scan outputs before the subtensor and
# Obtain a reference to the scan outputs before the subtensor and
...
@@ -3096,7 +3187,9 @@ class TestExamples:
...
@@ -3096,7 +3187,9 @@ class TestExamples:
seq
=
matrix
()
seq
=
matrix
()
initial_value
=
shared
(
np
.
zeros
((
4
,
1
),
dtype
=
config
.
floatX
))
initial_value
=
shared
(
np
.
zeros
((
4
,
1
),
dtype
=
config
.
floatX
))
outputs_info
=
[{
"initial"
:
initial_value
,
"taps"
:
[
-
4
]},
None
]
outputs_info
=
[{
"initial"
:
initial_value
,
"taps"
:
[
-
4
]},
None
]
results
,
_updates
=
scan
(
fn
=
onestep
,
sequences
=
seq
,
outputs_info
=
outputs_info
)
results
=
scan
(
fn
=
onestep
,
sequences
=
seq
,
outputs_info
=
outputs_info
,
return_updates
=
False
)
f
=
function
([
seq
],
results
[
1
])
f
=
function
([
seq
],
results
[
1
])
assert
np
.
all
(
exp_out
==
f
(
inp
))
assert
np
.
all
(
exp_out
==
f
(
inp
))
...
@@ -3119,7 +3212,9 @@ class TestExamples:
...
@@ -3119,7 +3212,9 @@ class TestExamples:
seq
=
matrix
()
seq
=
matrix
()
initial_value
=
shared
(
np
.
zeros
((
4
,
1
),
dtype
=
config
.
floatX
))
initial_value
=
shared
(
np
.
zeros
((
4
,
1
),
dtype
=
config
.
floatX
))
outputs_info
=
[{
"initial"
:
initial_value
,
"taps"
:
[
-
4
]},
None
]
outputs_info
=
[{
"initial"
:
initial_value
,
"taps"
:
[
-
4
]},
None
]
results
,
_
=
scan
(
fn
=
onestep
,
sequences
=
seq
,
outputs_info
=
outputs_info
)
results
=
scan
(
fn
=
onestep
,
sequences
=
seq
,
outputs_info
=
outputs_info
,
return_updates
=
False
)
sharedvar
=
shared
(
np
.
zeros
((
1
,
1
),
dtype
=
config
.
floatX
))
sharedvar
=
shared
(
np
.
zeros
((
1
,
1
),
dtype
=
config
.
floatX
))
updates
=
{
sharedvar
:
results
[
0
][
-
1
:]}
updates
=
{
sharedvar
:
results
[
0
][
-
1
:]}
...
@@ -3164,7 +3259,7 @@ class TestExamples:
...
@@ -3164,7 +3259,7 @@ class TestExamples:
init
=
matrix
()
init
=
matrix
()
outputs_info
=
[
None
,
None
,
None
,
None
,
dict
(
initial
=
init
,
taps
=
[
-
3
,
-
2
,
-
1
])]
outputs_info
=
[
None
,
None
,
None
,
None
,
dict
(
initial
=
init
,
taps
=
[
-
3
,
-
2
,
-
1
])]
out
,
_
=
scan
(
inner_fn
,
outputs_info
=
outputs_info
,
n_steps
=
3
)
out
=
scan
(
inner_fn
,
outputs_info
=
outputs_info
,
n_steps
=
3
,
return_updates
=
False
)
fct
=
function
([
init
],
out
)
fct
=
function
([
init
],
out
)
# Compare obtained outputs with expected outputs
# Compare obtained outputs with expected outputs
...
@@ -3197,21 +3292,23 @@ class TestExamples:
...
@@ -3197,21 +3292,23 @@ class TestExamples:
def
loss_inner
(
sum_inner
,
W
):
def
loss_inner
(
sum_inner
,
W
):
return
sum_inner
+
(
W
**
2
)
.
sum
()
return
sum_inner
+
(
W
**
2
)
.
sum
()
result_inner
,
_
=
scan
(
result_inner
=
scan
(
fn
=
loss_inner
,
fn
=
loss_inner
,
outputs_info
=
pt
.
as_tensor_variable
(
np
.
asarray
(
0
,
dtype
=
np
.
float32
)),
outputs_info
=
pt
.
as_tensor_variable
(
np
.
asarray
(
0
,
dtype
=
np
.
float32
)),
non_sequences
=
[
W
],
non_sequences
=
[
W
],
n_steps
=
1
,
n_steps
=
1
,
return_updates
=
False
,
)
)
return
sum_outer
+
result_inner
[
-
1
]
return
sum_outer
+
result_inner
[
-
1
]
# Also test return_list for that case.
# Also test return_list for that case.
result_outer
,
_
=
scan
(
result_outer
=
scan
(
fn
=
loss_outer
,
fn
=
loss_outer
,
outputs_info
=
pt
.
as_tensor_variable
(
np
.
asarray
(
0
,
dtype
=
np
.
float32
)),
outputs_info
=
pt
.
as_tensor_variable
(
np
.
asarray
(
0
,
dtype
=
np
.
float32
)),
non_sequences
=
[
W
],
non_sequences
=
[
W
],
n_steps
=
n_steps
,
n_steps
=
n_steps
,
return_list
=
True
,
return_list
=
True
,
return_updates
=
False
,
)
)
cost
=
result_outer
[
0
][
-
1
]
cost
=
result_outer
[
0
][
-
1
]
...
@@ -3230,7 +3327,9 @@ class TestExamples:
...
@@ -3230,7 +3327,9 @@ class TestExamples:
x0
=
vector
(
"X"
)
x0
=
vector
(
"X"
)
y0
=
vector
(
"y0"
)
y0
=
vector
(
"y0"
)
z0
=
vector
(
"Z"
)
z0
=
vector
(
"Z"
)
[
x
,
y
,
z
],
_
=
scan
(
inner_fn
,
outputs_info
=
[
x0
,
y0
,
z0
],
n_steps
=
10
)
[
x
,
y
,
z
]
=
scan
(
inner_fn
,
outputs_info
=
[
x0
,
y0
,
z0
],
n_steps
=
10
,
return_updates
=
False
)
cost
=
(
x
+
y
+
z
)
.
sum
()
cost
=
(
x
+
y
+
z
)
.
sum
()
grad
(
cost
,
x0
)
# defined
grad
(
cost
,
x0
)
# defined
...
@@ -3247,7 +3346,12 @@ class TestExamples:
...
@@ -3247,7 +3346,12 @@ class TestExamples:
m
=
matrix
(
"m"
)
m
=
matrix
(
"m"
)
u0
=
pt
.
zeros
((
7
,))
u0
=
pt
.
zeros
((
7
,))
[
_u
,
m2
],
_
=
scan
(
lambda
_
,
u
:
[
u
,
v
],
sequences
=
m
,
outputs_info
=
[
u0
,
None
])
[
_u
,
m2
]
=
scan
(
lambda
_
,
u
:
[
u
,
v
],
sequences
=
m
,
outputs_info
=
[
u0
,
None
],
return_updates
=
False
,
)
# This used to raise an exception with older versions because for a
# This used to raise an exception with older versions because for a
# disconnected gradient a non disconnected type was returned
# disconnected gradient a non disconnected type was returned
grad
((
m
*
m2
)
.
sum
(),
v
)
grad
((
m
*
m2
)
.
sum
(),
v
)
...
@@ -3257,8 +3361,11 @@ class TestExamples:
...
@@ -3257,8 +3361,11 @@ class TestExamples:
m
=
matrix
(
"m"
)
m
=
matrix
(
"m"
)
u0
=
pt
.
zeros
((
7
,))
u0
=
pt
.
zeros
((
7
,))
[
_u
,
m2
],
_
=
scan
(
[
_u
,
m2
]
=
scan
(
lambda
x
,
u
:
[
x
+
u
,
u
+
v
],
sequences
=
m
,
outputs_info
=
[
u0
,
None
]
lambda
x
,
u
:
[
x
+
u
,
u
+
v
],
sequences
=
m
,
outputs_info
=
[
u0
,
None
],
return_updates
=
False
,
)
)
# This used to raise an exception with older versions because
# This used to raise an exception with older versions because
# scan could not detect the connection between `m2` and `x`
# scan could not detect the connection between `m2` and `x`
...
@@ -3278,7 +3385,7 @@ class TestExamples:
...
@@ -3278,7 +3385,7 @@ class TestExamples:
out2
=
out1
+
1
out2
=
out1
+
1
return
out1
,
out2
return
out1
,
out2
[
_out1
,
out2
]
,
_
=
scan
(
step
,
sequences
=
v
)
[
_out1
,
out2
]
=
scan
(
step
,
sequences
=
v
,
return_updates
=
False
)
gv
=
grad
(
out2
.
sum
(),
[
v
])
gv
=
grad
(
out2
.
sum
(),
[
v
])
f
=
function
([
v
],
gv
)
f
=
function
([
v
],
gv
)
...
@@ -3289,7 +3396,13 @@ class TestExamples:
...
@@ -3289,7 +3396,13 @@ class TestExamples:
def
test_grad_bug_disconnected_input
(
self
):
def
test_grad_bug_disconnected_input
(
self
):
W
=
shared
(
np
.
zeros
((
3
,
3
)),
name
=
"W"
)
W
=
shared
(
np
.
zeros
((
3
,
3
)),
name
=
"W"
)
v
=
ivector
(
name
=
"v"
)
v
=
ivector
(
name
=
"v"
)
y
,
_
=
scan
(
lambda
i
,
W
:
W
[
i
],
sequences
=
v
,
outputs_info
=
None
,
non_sequences
=
W
)
y
=
scan
(
lambda
i
,
W
:
W
[
i
],
sequences
=
v
,
outputs_info
=
None
,
non_sequences
=
W
,
return_updates
=
False
,
)
# This used to raise an exception
# This used to raise an exception
f
=
function
([
v
],
grad
(
y
.
sum
(),
W
))
f
=
function
([
v
],
grad
(
y
.
sum
(),
W
))
...
@@ -3299,10 +3412,8 @@ class TestExamples:
...
@@ -3299,10 +3412,8 @@ class TestExamples:
w
=
shared
(
np
.
array
(
0
,
dtype
=
"float32"
),
name
=
"w"
)
w
=
shared
(
np
.
array
(
0
,
dtype
=
"float32"
),
name
=
"w"
)
init
=
fscalar
(
"init"
)
init
=
fscalar
(
"init"
)
out
,
_
=
scan
(
out
=
scan
(
fn
=
lambda
prev
:
w
,
fn
=
lambda
prev
:
w
,
outputs_info
=
init
,
n_steps
=
2
,
return_updates
=
False
outputs_info
=
init
,
n_steps
=
2
,
)
)
grad
(
out
[
-
1
],
w
)
grad
(
out
[
-
1
],
w
)
...
@@ -3326,7 +3437,7 @@ class TestExamples:
...
@@ -3326,7 +3437,7 @@ class TestExamples:
def
f_rnn_shared
(
u_tm2
,
x_tm1
,
x_tm2
):
def
f_rnn_shared
(
u_tm2
,
x_tm1
,
x_tm2
):
return
u_tm2
*
W_in
+
x_tm1
*
W
+
x_tm2
return
u_tm2
*
W_in
+
x_tm1
*
W
+
x_tm2
outputs
,
updates
=
scan
(
outputs
=
scan
(
f_rnn_shared
,
f_rnn_shared
,
dict
(
input
=
u
,
taps
=-
2
),
dict
(
input
=
u
,
taps
=-
2
),
dict
(
initial
=
x0
,
taps
=
[
-
1
,
-
2
]),
dict
(
initial
=
x0
,
taps
=
[
-
1
,
-
2
]),
...
@@ -3334,9 +3445,10 @@ class TestExamples:
...
@@ -3334,9 +3445,10 @@ class TestExamples:
n_steps
=
None
,
n_steps
=
None
,
truncate_gradient
=-
1
,
truncate_gradient
=-
1
,
go_backwards
=
False
,
go_backwards
=
False
,
return_updates
=
False
,
)
)
f7
=
function
([
u
,
x0
],
outputs
,
updates
=
updates
,
allow_input_downcast
=
True
)
f7
=
function
([
u
,
x0
],
outputs
,
allow_input_downcast
=
True
)
pytensor_out
=
f7
(
vu
,
vx0
)
pytensor_out
=
f7
(
vu
,
vx0
)
# compute output in numpy
# compute output in numpy
...
@@ -3372,7 +3484,7 @@ class TestExamples:
...
@@ -3372,7 +3484,7 @@ class TestExamples:
def
f_rnn_shared
(
u_tm2
,
u_tp2
,
x_tm1
,
x_tm2
):
def
f_rnn_shared
(
u_tm2
,
u_tp2
,
x_tm1
,
x_tm2
):
return
(
u_tm2
+
u_tp2
)
*
W_in
+
x_tm1
*
W
+
x_tm2
return
(
u_tm2
+
u_tp2
)
*
W_in
+
x_tm1
*
W
+
x_tm2
output
,
updates
=
scan
(
output
=
scan
(
f_rnn_shared
,
f_rnn_shared
,
dict
(
input
=
u
,
taps
=
[
-
2
,
2
]),
dict
(
input
=
u
,
taps
=
[
-
2
,
2
]),
dict
(
initial
=
x0
,
taps
=
[
-
1
,
-
2
]),
dict
(
initial
=
x0
,
taps
=
[
-
1
,
-
2
]),
...
@@ -3380,9 +3492,10 @@ class TestExamples:
...
@@ -3380,9 +3492,10 @@ class TestExamples:
n_steps
=
None
,
n_steps
=
None
,
truncate_gradient
=-
1
,
truncate_gradient
=-
1
,
go_backwards
=
False
,
go_backwards
=
False
,
return_updates
=
False
,
)
)
f8
=
function
([
u
,
x0
],
output
,
updates
=
updates
,
allow_input_downcast
=
True
)
f8
=
function
([
u
,
x0
],
output
,
allow_input_downcast
=
True
)
pytensor_out
=
f8
(
vu
,
vx0
)
pytensor_out
=
f8
(
vu
,
vx0
)
# compute output in numpy
# compute output in numpy
numpy_out
=
np
.
zeros
(
2
)
numpy_out
=
np
.
zeros
(
2
)
...
@@ -3404,7 +3517,7 @@ class TestExamples:
...
@@ -3404,7 +3517,7 @@ class TestExamples:
state
=
scalar
(
"state"
)
state
=
scalar
(
"state"
)
n_steps
=
iscalar
(
"nsteps"
)
n_steps
=
iscalar
(
"nsteps"
)
# Test return_list at the same time.
# Test return_list at the same time.
output
,
updates
=
scan
(
output
=
scan
(
f_pow2
,
f_pow2
,
[],
[],
state
,
state
,
...
@@ -3413,10 +3526,9 @@ class TestExamples:
...
@@ -3413,10 +3526,9 @@ class TestExamples:
truncate_gradient
=-
1
,
truncate_gradient
=-
1
,
return_list
=
True
,
return_list
=
True
,
go_backwards
=
False
,
go_backwards
=
False
,
return_updates
=
False
,
)
)
my_f
=
function
(
my_f
=
function
([
state
,
n_steps
],
output
,
allow_input_downcast
=
True
)
[
state
,
n_steps
],
output
,
updates
=
updates
,
allow_input_downcast
=
True
)
rng
=
np
.
random
.
default_rng
(
utt
.
fetch_seed
())
rng
=
np
.
random
.
default_rng
(
utt
.
fetch_seed
())
state
=
rng
.
uniform
()
state
=
rng
.
uniform
()
...
@@ -3446,10 +3558,11 @@ class TestExamples:
...
@@ -3446,10 +3558,11 @@ class TestExamples:
pre_h
=
dot
(
x
,
W_x
)
pre_h
=
dot
(
x
,
W_x
)
return
pre_h
return
pre_h
value
,
_scan_updates
=
scan
(
value
=
scan
(
_active
,
_active
,
sequences
=
X
,
sequences
=
X
,
outputs_info
=
[
pt
.
alloc
(
floatx
(
0.0
),
1
,
out_size
)],
outputs_info
=
[
pt
.
alloc
(
floatx
(
0.0
),
1
,
out_size
)],
return_updates
=
False
,
)
)
cost
=
mean
(
value
)
cost
=
mean
(
value
)
gW_x
=
grad
(
cost
,
W_x
)
gW_x
=
grad
(
cost
,
W_x
)
...
@@ -3467,7 +3580,7 @@ class TestExamples:
...
@@ -3467,7 +3580,7 @@ class TestExamples:
condition
=
until
(
new_value
>
max_value
)
condition
=
until
(
new_value
>
max_value
)
return
[
new_value
,
new_step
],
condition
return
[
new_value
,
new_step
],
condition
rs
,
_updates
=
scan
(
fn
=
accum
,
outputs_info
=
[
0
,
0
],
n_steps
=
n_steps
)
rs
=
scan
(
fn
=
accum
,
outputs_info
=
[
0
,
0
],
n_steps
=
n_steps
,
return_updates
=
False
)
f
=
function
(
inputs
=
[
max_value
,
n_steps
],
outputs
=
rs
)
f
=
function
(
inputs
=
[
max_value
,
n_steps
],
outputs
=
rs
)
...
@@ -3487,33 +3600,37 @@ class TestExamples:
...
@@ -3487,33 +3600,37 @@ class TestExamples:
# Generate the components of the polynomial
# Generate the components of the polynomial
full_range
=
pt
.
arange
(
max_coefficients_supported
)
full_range
=
pt
.
arange
(
max_coefficients_supported
)
components
,
_updates
=
scan
(
components
=
scan
(
fn
=
lambda
coeff
,
power
,
free_var
:
coeff
*
(
free_var
**
power
),
fn
=
lambda
coeff
,
power
,
free_var
:
coeff
*
(
free_var
**
power
),
sequences
=
[
coefficients
,
full_range
],
sequences
=
[
coefficients
,
full_range
],
non_sequences
=
x
,
non_sequences
=
x
,
return_updates
=
False
,
)
)
polynomial1
=
components
.
sum
()
polynomial1
=
components
.
sum
()
polynomial2
,
_updates
=
scan
(
polynomial2
=
scan
(
fn
=
lambda
coeff
,
power
,
prev
,
free_var
:
prev
+
coeff
*
(
free_var
**
power
),
fn
=
lambda
coeff
,
power
,
prev
,
free_var
:
prev
+
coeff
*
(
free_var
**
power
),
outputs_info
=
pt
.
constant
(
0
,
dtype
=
"floatX"
),
outputs_info
=
pt
.
constant
(
0
,
dtype
=
"floatX"
),
sequences
=
[
coefficients
,
full_range
],
sequences
=
[
coefficients
,
full_range
],
non_sequences
=
x
,
non_sequences
=
x
,
return_updates
=
False
,
)
)
# python int
# python int
polynomial3
,
_updates
=
scan
(
polynomial3
=
scan
(
fn
=
lambda
coeff
,
power
,
prev
,
free_var
:
prev
+
coeff
*
(
free_var
**
power
),
fn
=
lambda
coeff
,
power
,
prev
,
free_var
:
prev
+
coeff
*
(
free_var
**
power
),
outputs_info
=
0
,
outputs_info
=
0
,
sequences
=
[
coefficients
,
full_range
],
sequences
=
[
coefficients
,
full_range
],
non_sequences
=
x
,
non_sequences
=
x
,
return_updates
=
False
,
)
)
# python float
# python float
polynomial4
,
_updates
=
scan
(
polynomial4
=
scan
(
fn
=
lambda
coeff
,
power
,
prev
,
free_var
:
prev
+
coeff
*
(
free_var
**
power
),
fn
=
lambda
coeff
,
power
,
prev
,
free_var
:
prev
+
coeff
*
(
free_var
**
power
),
outputs_info
=
0.0
,
outputs_info
=
0.0
,
sequences
=
[
coefficients
,
full_range
],
sequences
=
[
coefficients
,
full_range
],
non_sequences
=
x
,
non_sequences
=
x
,
return_updates
=
False
,
)
)
calculate_polynomial
=
function
(
calculate_polynomial
=
function
(
...
@@ -3576,8 +3693,12 @@ class TestExamples:
...
@@ -3576,8 +3693,12 @@ class TestExamples:
# o = v + 1 # <-- this line works
# o = v + 1 # <-- this line works
return
o
return
o
OS
,
_updates
=
scan
(
OS
=
scan
(
fn
=
one_step
,
sequences
=
V
,
outputs_info
=
[
None
],
non_sequences
=
[
W
]
fn
=
one_step
,
sequences
=
V
,
outputs_info
=
[
None
],
non_sequences
=
[
W
],
return_updates
=
False
,
)
)
O
=
OS
.
sum
()
+
W
.
sum
()
O
=
OS
.
sum
()
+
W
.
sum
()
...
@@ -3591,11 +3712,12 @@ class TestExamples:
...
@@ -3591,11 +3712,12 @@ class TestExamples:
)
)
def
test_infershape_seq_shorter_nsteps
(
self
):
def
test_infershape_seq_shorter_nsteps
(
self
):
x
=
vector
(
"x"
)
x
=
vector
(
"x"
)
[
o1
,
o2
]
,
_
=
scan
(
[
o1
,
o2
]
=
scan
(
lambda
x
,
y
:
(
x
+
1
,
y
+
x
),
lambda
x
,
y
:
(
x
+
1
,
y
+
x
),
sequences
=
x
,
sequences
=
x
,
outputs_info
=
[
None
,
x
[
0
]],
outputs_info
=
[
None
,
x
[
0
]],
n_steps
=
20
,
n_steps
=
20
,
return_updates
=
False
,
)
)
f
=
function
([
x
],
[
o1
,
o2
],
mode
=
mode_with_opt
)
f
=
function
([
x
],
[
o1
,
o2
],
mode
=
mode_with_opt
)
...
@@ -3667,10 +3789,14 @@ class TestExamples:
...
@@ -3667,10 +3789,14 @@ class TestExamples:
condition
=
until
(
previous_val
>
5
)
condition
=
until
(
previous_val
>
5
)
return
new_val
,
condition
return
new_val
,
condition
out
,
_
updates
=
scan
(
inner_fct
,
outputs_info
=
x
,
n_steps
=
10
)
out
,
updates
=
scan
(
inner_fct
,
outputs_info
=
x
,
n_steps
=
10
)
g_out
=
grad
(
out
.
sum
(),
x
)
g_out
=
grad
(
out
.
sum
(),
x
)
fct
=
function
([
x
],
[
out
,
g_out
])
fct
=
function
(
[
x
],
[
out
,
g_out
],
updates
=
updates
,
)
for
i
in
range
(
-
5
,
5
):
for
i
in
range
(
-
5
,
5
):
output
,
g_output
=
fct
(
i
)
output
,
g_output
=
fct
(
i
)
...
@@ -3702,7 +3828,7 @@ class TestExamples:
...
@@ -3702,7 +3828,7 @@ class TestExamples:
)
)
return
next_sitsot_val
,
next_mitsot_val
,
nitsot_out
return
next_sitsot_val
,
next_mitsot_val
,
nitsot_out
out
,
_updates
=
scan
(
out
=
scan
(
fn
=
step
,
fn
=
step
,
sequences
=
seq
,
sequences
=
seq
,
outputs_info
=
[
outputs_info
=
[
...
@@ -3711,6 +3837,7 @@ class TestExamples:
...
@@ -3711,6 +3837,7 @@ class TestExamples:
None
,
None
,
],
],
n_steps
=
5
,
n_steps
=
5
,
return_updates
=
False
,
)
)
f
=
function
([
seq
,
sitsot_init
,
mitsot_init
],
out
[
2
]
.
shape
)
f
=
function
([
seq
,
sitsot_init
,
mitsot_init
],
out
[
2
]
.
shape
)
...
@@ -3746,7 +3873,7 @@ class TestExamples:
...
@@ -3746,7 +3873,7 @@ class TestExamples:
dot
(
x_tm1
,
W_out
),
dot
(
x_tm1
,
W_out
),
]
]
outputs
,
updates
=
scan
(
outputs
=
scan
(
f_rnn_cmpl
,
f_rnn_cmpl
,
[
u1
,
u2
],
[
u1
,
u2
],
[
x0
,
y0
],
[
x0
,
y0
],
...
@@ -3754,11 +3881,10 @@ class TestExamples:
...
@@ -3754,11 +3881,10 @@ class TestExamples:
n_steps
=
None
,
n_steps
=
None
,
truncate_gradient
=-
1
,
truncate_gradient
=-
1
,
go_backwards
=
False
,
go_backwards
=
False
,
return_updates
=
False
,
)
)
f4
=
function
(
f4
=
function
([
u1
,
u2
,
x0
,
y0
,
W_in1
],
outputs
,
allow_input_downcast
=
True
)
[
u1
,
u2
,
x0
,
y0
,
W_in1
],
outputs
,
updates
=
updates
,
allow_input_downcast
=
True
)
# compute the values in numpy
# compute the values in numpy
v_x
=
np
.
zeros
((
3
,
2
),
dtype
=
config
.
floatX
)
v_x
=
np
.
zeros
((
3
,
2
),
dtype
=
config
.
floatX
)
...
@@ -3802,7 +3928,7 @@ class TestExamples:
...
@@ -3802,7 +3928,7 @@ class TestExamples:
dot
(
u1_t
,
W_in1
),
dot
(
u1_t
,
W_in1
),
]
]
outputs
,
updates
=
scan
(
outputs
=
scan
(
f_rnn_cmpl
,
f_rnn_cmpl
,
[
u1
,
dict
(
input
=
u2
,
taps
=
[
-
1
,
0
,
1
])],
[
u1
,
dict
(
input
=
u2
,
taps
=
[
-
1
,
0
,
1
])],
[
x0
,
dict
(
initial
=
y0
,
taps
=
[
-
1
,
-
3
]),
None
],
[
x0
,
dict
(
initial
=
y0
,
taps
=
[
-
1
,
-
3
]),
None
],
...
@@ -3810,11 +3936,10 @@ class TestExamples:
...
@@ -3810,11 +3936,10 @@ class TestExamples:
n_steps
=
None
,
n_steps
=
None
,
truncate_gradient
=-
1
,
truncate_gradient
=-
1
,
go_backwards
=
False
,
go_backwards
=
False
,
return_updates
=
False
,
)
)
f
=
function
(
f
=
function
([
u1
,
u2
,
x0
,
y0
,
W_in1
],
outputs
,
allow_input_downcast
=
True
)
[
u1
,
u2
,
x0
,
y0
,
W_in1
],
outputs
,
updates
=
updates
,
allow_input_downcast
=
True
)
ny0
=
np
.
zeros
((
5
,
2
))
ny0
=
np
.
zeros
((
5
,
2
))
ny1
=
np
.
zeros
((
5
,))
ny1
=
np
.
zeros
((
5
,))
...
@@ -3904,13 +4029,14 @@ class TestExamples:
...
@@ -3904,13 +4029,14 @@ class TestExamples:
return
[
h_t
,
y_t
]
return
[
h_t
,
y_t
]
# hidden and outputs of the entire sequence
# hidden and outputs of the entire sequence
[
_h
,
y
]
,
_
=
scan
(
[
_h
,
y
]
=
scan
(
fn
=
one_step
,
fn
=
one_step
,
sequences
=
dict
(
input
=
x
),
sequences
=
dict
(
input
=
x
),
# corresponds to the return type of one_step
# corresponds to the return type of one_step
outputs_info
=
[
dict
(
initial
=
h0
,
taps
=
[
-
2
,
-
1
]),
None
],
outputs_info
=
[
dict
(
initial
=
h0
,
taps
=
[
-
2
,
-
1
]),
None
],
non_sequences
=
[
W_ih
,
W_hh
,
b_h
,
W_ho
,
b_o
],
non_sequences
=
[
W_ih
,
W_hh
,
b_h
,
W_ho
,
b_o
],
mode
=
mode
,
mode
=
mode
,
return_updates
=
False
,
)
)
# target values
# target values
...
@@ -4084,7 +4210,7 @@ def test_output_storage_reuse(linker_mode):
...
@@ -4084,7 +4210,7 @@ def test_output_storage_reuse(linker_mode):
outer-output arrays are initialized using the outer-input arrays, the
outer-output arrays are initialized using the outer-input arrays, the
shape difference needs to be handled correctly.
shape difference needs to be handled correctly.
"""
"""
s_in_y
,
_
=
scan
(
s_in_y
=
scan
(
fn
=
lambda
z
:
(
z
+
1
,
until
(
z
>
2
)),
fn
=
lambda
z
:
(
z
+
1
,
until
(
z
>
2
)),
outputs_info
=
[
outputs_info
=
[
{
"taps"
:
[
-
1
],
"initial"
:
pt
.
as_tensor
(
0.0
,
dtype
=
np
.
float64
)}
{
"taps"
:
[
-
1
],
"initial"
:
pt
.
as_tensor
(
0.0
,
dtype
=
np
.
float64
)}
...
@@ -4092,16 +4218,18 @@ def test_output_storage_reuse(linker_mode):
...
@@ -4092,16 +4218,18 @@ def test_output_storage_reuse(linker_mode):
mode
=
mode
,
mode
=
mode
,
n_steps
=
n
-
1
,
n_steps
=
n
-
1
,
allow_gc
=
False
,
allow_gc
=
False
,
return_updates
=
False
,
)
)
return
s_in_y
.
sum
()
return
s_in_y
.
sum
()
s_y
,
_updates
=
scan
(
s_y
=
scan
(
fn
=
fn
,
fn
=
fn
,
outputs_info
=
[
None
],
outputs_info
=
[
None
],
sequences
=
[
pt
.
as_tensor
([
3
,
2
,
1
],
dtype
=
np
.
int64
)],
sequences
=
[
pt
.
as_tensor
([
3
,
2
,
1
],
dtype
=
np
.
int64
)],
mode
=
mode
,
mode
=
mode
,
allow_gc
=
False
,
allow_gc
=
False
,
return_updates
=
False
,
)
)
f_cvm
=
function
([],
s_y
,
mode
=
mode
)
f_cvm
=
function
([],
s_y
,
mode
=
mode
)
...
@@ -4121,14 +4249,14 @@ def test_rng_outputs_info():
...
@@ -4121,14 +4249,14 @@ def test_rng_outputs_info():
)
.
owner
.
outputs
)
.
owner
.
outputs
return
next_x
,
next_rng
return
next_x
,
next_rng
[
xs
,
rng_final
]
,
updates
=
scan
(
[
xs
,
rng_final
]
=
scan
(
fn
=
step
,
fn
=
step
,
outputs_info
=
[
x0
,
rng_x0
],
outputs_info
=
[
x0
,
rng_x0
],
n_steps
=
10
,
n_steps
=
10
,
return_updates
=
False
,
)
)
assert
isinstance
(
xs
.
type
,
TensorType
)
assert
isinstance
(
xs
.
type
,
TensorType
)
assert
isinstance
(
rng_final
.
type
,
RandomGeneratorType
)
assert
isinstance
(
rng_final
.
type
,
RandomGeneratorType
)
assert
not
updates
fn
=
function
([
rng_init
],
[
xs
,
rng_final
])
fn
=
function
([
rng_init
],
[
xs
,
rng_final
])
xs_eval
,
rng_final_eval
=
fn
(
np
.
random
.
default_rng
(
0
))
xs_eval
,
rng_final_eval
=
fn
(
np
.
random
.
default_rng
(
0
))
...
...
tests/scan/test_rewriting.py
浏览文件 @
abedb7fb
...
@@ -47,38 +47,47 @@ class TestRemoveConstantsAndUnusedInputsScan:
...
@@ -47,38 +47,47 @@ class TestRemoveConstantsAndUnusedInputsScan:
"""Test the rewrite `remove_constants_and_unused_inputs_scan` for non-sequences."""
"""Test the rewrite `remove_constants_and_unused_inputs_scan` for non-sequences."""
W
=
matrix
(
name
=
"W"
)
W
=
matrix
(
name
=
"W"
)
v
=
ivector
(
name
=
"v"
)
v
=
ivector
(
name
=
"v"
)
y1
,
_
=
scan
(
y1
=
scan
(
lambda
i
,
W
:
W
[
i
],
sequences
=
v
,
outputs_info
=
None
,
non_sequences
=
[
W
]
lambda
i
,
W
:
W
[
i
],
sequences
=
v
,
outputs_info
=
None
,
non_sequences
=
[
W
],
return_updates
=
False
,
)
)
y2
,
_
=
scan
(
y2
=
scan
(
lambda
i
,
_
,
W
:
W
[
i
],
lambda
i
,
_
,
W
:
W
[
i
],
sequences
=
v
,
sequences
=
v
,
outputs_info
=
None
,
outputs_info
=
None
,
non_sequences
=
[
W
[
0
],
W
],
non_sequences
=
[
W
[
0
],
W
],
return_updates
=
False
,
)
)
y3
,
_
=
scan
(
y3
=
scan
(
lambda
i
,
W
,
_
:
W
[
i
],
lambda
i
,
W
,
_
:
W
[
i
],
sequences
=
v
,
sequences
=
v
,
outputs_info
=
None
,
outputs_info
=
None
,
non_sequences
=
[
W
,
W
[
0
]],
non_sequences
=
[
W
,
W
[
0
]],
return_updates
=
False
,
)
)
y4
,
_
=
scan
(
y4
=
scan
(
lambda
i
,
_
,
_2
,
W
:
W
[
i
],
lambda
i
,
_
,
_2
,
W
:
W
[
i
],
sequences
=
v
,
sequences
=
v
,
outputs_info
=
None
,
outputs_info
=
None
,
non_sequences
=
[
W
[
0
],
W
[
0
],
W
],
non_sequences
=
[
W
[
0
],
W
[
0
],
W
],
return_updates
=
False
,
)
)
y5
,
_
=
scan
(
y5
=
scan
(
lambda
i
,
_
,
W
,
_2
:
W
[
i
],
lambda
i
,
_
,
W
,
_2
:
W
[
i
],
sequences
=
v
,
sequences
=
v
,
outputs_info
=
None
,
outputs_info
=
None
,
non_sequences
=
[
W
[
0
],
W
,
W
[
0
]],
non_sequences
=
[
W
[
0
],
W
,
W
[
0
]],
return_updates
=
False
,
)
)
y6
,
_
=
scan
(
y6
=
scan
(
lambda
i
,
W
,
_
,
_2
:
W
[
i
],
lambda
i
,
W
,
_
,
_2
:
W
[
i
],
sequences
=
v
,
sequences
=
v
,
outputs_info
=
None
,
outputs_info
=
None
,
non_sequences
=
[
W
,
W
[
0
],
W
[
0
]],
non_sequences
=
[
W
,
W
[
0
],
W
[
0
]],
return_updates
=
False
,
)
)
# TODO: y7 have problem during run time. I think it should
# TODO: y7 have problem during run time. I think it should
# raise an error during the scan construction.
# raise an error during the scan construction.
...
@@ -112,47 +121,61 @@ class TestRemoveConstantsAndUnusedInputsScan:
...
@@ -112,47 +121,61 @@ class TestRemoveConstantsAndUnusedInputsScan:
W
=
matrix
(
name
=
"W"
)
W
=
matrix
(
name
=
"W"
)
v
=
ivector
(
name
=
"v"
)
v
=
ivector
(
name
=
"v"
)
vv
=
matrix
(
name
=
"vv"
)
vv
=
matrix
(
name
=
"vv"
)
y1
,
_
=
scan
(
y1
=
scan
(
lambda
i
,
W
:
W
[
i
],
sequences
=
v
,
outputs_info
=
None
,
non_sequences
=
[
W
]
lambda
i
,
W
:
W
[
i
],
sequences
=
v
,
outputs_info
=
None
,
non_sequences
=
[
W
],
return_updates
=
False
,
)
)
y2
,
_
=
scan
(
y2
=
scan
(
lambda
i
,
_
,
W
:
W
[
i
],
sequences
=
[
v
,
v
],
outputs_info
=
None
,
non_sequences
=
W
lambda
i
,
_
,
W
:
W
[
i
],
sequences
=
[
v
,
v
],
outputs_info
=
None
,
non_sequences
=
W
,
return_updates
=
False
,
)
)
y3
,
_
=
scan
(
y3
=
scan
(
lambda
i
,
_
,
W
:
W
[
i
],
lambda
i
,
_
,
W
:
W
[
i
],
sequences
=
[
v
,
vv
[
0
]],
sequences
=
[
v
,
vv
[
0
]],
outputs_info
=
None
,
outputs_info
=
None
,
non_sequences
=
W
,
non_sequences
=
W
,
return_updates
=
False
,
)
)
y4
,
_
=
scan
(
y4
=
scan
(
lambda
_
,
i
,
W
:
W
[
i
],
lambda
_
,
i
,
W
:
W
[
i
],
sequences
=
[
vv
[
0
],
v
],
sequences
=
[
vv
[
0
],
v
],
outputs_info
=
None
,
outputs_info
=
None
,
non_sequences
=
W
,
non_sequences
=
W
,
return_updates
=
False
,
)
)
y5
,
_
=
scan
(
y5
=
scan
(
lambda
_
,
i
,
_2
,
W
:
W
[
i
],
lambda
_
,
i
,
_2
,
W
:
W
[
i
],
sequences
=
[
vv
,
v
,
vv
[
0
]],
sequences
=
[
vv
,
v
,
vv
[
0
]],
outputs_info
=
None
,
outputs_info
=
None
,
non_sequences
=
W
,
non_sequences
=
W
,
return_updates
=
False
,
)
)
y6
,
_
=
scan
(
y6
=
scan
(
lambda
_
,
_2
,
i
,
W
:
W
[
i
],
lambda
_
,
_2
,
i
,
W
:
W
[
i
],
sequences
=
[
vv
[
0
],
vv
,
v
],
sequences
=
[
vv
[
0
],
vv
,
v
],
outputs_info
=
None
,
outputs_info
=
None
,
non_sequences
=
W
,
non_sequences
=
W
,
return_updates
=
False
,
)
)
y7
,
_
=
scan
(
y7
=
scan
(
lambda
i
,
_
,
_2
,
W
:
W
[
i
],
lambda
i
,
_
,
_2
,
W
:
W
[
i
],
sequences
=
[
v
,
vv
[
0
],
vv
[
0
]],
sequences
=
[
v
,
vv
[
0
],
vv
[
0
]],
outputs_info
=
None
,
outputs_info
=
None
,
non_sequences
=
W
,
non_sequences
=
W
,
return_updates
=
False
,
)
)
y8
,
_
=
scan
(
y8
=
scan
(
lambda
_
,
i
,
W
,
_2
,
_3
:
W
[
i
],
lambda
_
,
i
,
W
,
_2
,
_3
:
W
[
i
],
sequences
=
[
vv
[
0
],
v
],
sequences
=
[
vv
[
0
],
v
],
outputs_info
=
None
,
outputs_info
=
None
,
non_sequences
=
[
W
,
W
[
0
],
W
[
0
]],
non_sequences
=
[
W
,
W
[
0
],
W
[
0
]],
return_updates
=
False
,
)
)
W_val
=
np
.
random
.
normal
(
size
=
(
3
,
3
))
.
astype
(
config
.
floatX
)
W_val
=
np
.
random
.
normal
(
size
=
(
3
,
3
))
.
astype
(
config
.
floatX
)
...
@@ -195,7 +218,7 @@ class TestPushOutDot:
...
@@ -195,7 +218,7 @@ class TestPushOutDot:
def
lambda_fn
(
h
,
W1
,
W2
):
def
lambda_fn
(
h
,
W1
,
W2
):
return
dot
(
h
,
W1
+
W2
)
return
dot
(
h
,
W1
+
W2
)
o
,
_
=
scan
(
lambda_fn
,
non_sequences
=
[
h0
,
W1
,
W2
],
n_steps
=
5
)
o
=
scan
(
lambda_fn
,
non_sequences
=
[
h0
,
W1
,
W2
],
n_steps
=
5
,
return_updates
=
False
)
f
=
function
([
h0
,
W1
,
W2
],
o
,
mode
=
self
.
mode
)
f
=
function
([
h0
,
W1
,
W2
],
o
,
mode
=
self
.
mode
)
...
@@ -232,19 +255,24 @@ class TestPushOutDot:
...
@@ -232,19 +255,24 @@ class TestPushOutDot:
return
dot
(
W1
,
W2
),
until_condition
return
dot
(
W1
,
W2
),
until_condition
# Compile a function with the optimization
# Compile a function with the optimization
o
,
_
=
scan
(
o
=
scan
(
lambda_fn
,
sequences
=
[
step_indices
,
W1
],
non_sequences
=
[
W2
],
n_steps
=
5
lambda_fn
,
sequences
=
[
step_indices
,
W1
],
non_sequences
=
[
W2
],
n_steps
=
5
,
return_updates
=
False
,
)
)
f
=
function
([
W1
,
W2
,
step_indices
],
o
,
mode
=
self
.
mode
)
f
=
function
([
W1
,
W2
,
step_indices
],
o
,
mode
=
self
.
mode
)
# Compule an pytensor function without the optimization
# Compule an pytensor function without the optimization
o
,
_
=
scan
(
o
=
scan
(
lambda_fn
,
lambda_fn
,
sequences
=
[
step_indices
,
W1
],
sequences
=
[
step_indices
,
W1
],
non_sequences
=
[
W2
],
non_sequences
=
[
W2
],
n_steps
=
5
,
n_steps
=
5
,
mode
=
"FAST_COMPILE"
,
mode
=
"FAST_COMPILE"
,
return_updates
=
False
,
)
)
f_ref
=
function
([
W1
,
W2
,
step_indices
],
o
,
mode
=
self
.
mode
)
f_ref
=
function
([
W1
,
W2
,
step_indices
],
o
,
mode
=
self
.
mode
)
...
@@ -268,7 +296,13 @@ class TestPushOutDot:
...
@@ -268,7 +296,13 @@ class TestPushOutDot:
def
lambda_fn
(
h
,
W1
,
W2
):
def
lambda_fn
(
h
,
W1
,
W2
):
return
dot
(
h
,
W1
+
W2
)
return
dot
(
h
,
W1
+
W2
)
o
,
_
=
scan
(
lambda_fn
,
outputs_info
=
h0
,
non_sequences
=
[
W1
,
W2
],
n_steps
=
5
)
o
=
scan
(
lambda_fn
,
outputs_info
=
h0
,
non_sequences
=
[
W1
,
W2
],
n_steps
=
5
,
return_updates
=
False
,
)
f
=
function
([
h0
,
W1
,
W2
],
o
,
mode
=
self
.
mode
)
f
=
function
([
h0
,
W1
,
W2
],
o
,
mode
=
self
.
mode
)
...
@@ -290,10 +324,11 @@ class TestPushOutDot:
...
@@ -290,10 +324,11 @@ class TestPushOutDot:
def
fn
(
i
,
i_tm1
):
def
fn
(
i
,
i_tm1
):
return
i
+
10
,
i_tm1
return
i
+
10
,
i_tm1
([
i_t
,
i_tm1
],
_
)
=
scan
(
[
i_t
,
i_tm1
]
=
scan
(
fn
,
fn
,
sequences
=
[
inp
],
sequences
=
[
inp
],
outputs_info
=
[
np
.
asarray
([
0.0
,
0.0
],
config
.
floatX
),
None
],
outputs_info
=
[
np
.
asarray
([
0.0
,
0.0
],
config
.
floatX
),
None
],
return_updates
=
False
,
)
)
f
=
function
([
inp
],
[
i_t
,
i_tm1
])
f
=
function
([
inp
],
[
i_t
,
i_tm1
])
val
=
np
.
arange
(
10
)
.
reshape
(
5
,
2
)
.
astype
(
config
.
floatX
)
val
=
np
.
arange
(
10
)
.
reshape
(
5
,
2
)
.
astype
(
config
.
floatX
)
...
@@ -397,17 +432,18 @@ class TestPushOutNonSeqScan:
...
@@ -397,17 +432,18 @@ class TestPushOutNonSeqScan:
@config.change_flags
(
on_opt_error
=
"raise"
)
@config.change_flags
(
on_opt_error
=
"raise"
)
def
test_pushout_seqs2
(
self
):
def
test_pushout_seqs2
(
self
):
x
=
matrix
()
x
=
matrix
()
outputs
,
updates
=
scan
(
outputs
=
scan
(
lambda
x
:
[
x
*
x
,
pt
.
constant
(
0
)
.
copy
()
.
copy
()],
lambda
x
:
[
x
*
x
,
pt
.
constant
(
0
)
.
copy
()
.
copy
()],
n_steps
=
2
,
n_steps
=
2
,
sequences
=
[],
sequences
=
[],
non_sequences
=
[],
non_sequences
=
[],
outputs_info
=
[
x
,
None
],
outputs_info
=
[
x
,
None
],
return_updates
=
False
,
)
)
# Compile an PyTensor function where any optimization error will lead to
# Compile an PyTensor function where any optimization error will lead to
# an exception being raised
# an exception being raised
function
([
x
],
outputs
,
updates
=
updates
)
function
([
x
],
outputs
)
@config.change_flags
(
on_opt_error
=
"raise"
)
@config.change_flags
(
on_opt_error
=
"raise"
)
def
test_pushout_nonseq
(
self
):
def
test_pushout_nonseq
(
self
):
...
@@ -418,7 +454,9 @@ class TestPushOutNonSeqScan:
...
@@ -418,7 +454,9 @@ class TestPushOutNonSeqScan:
outputs. This led the optimization to raise an exception.
outputs. This led the optimization to raise an exception.
"""
"""
outputs
,
_
=
scan
(
lambda
x
:
(
x
*
x
,
x
),
non_sequences
=
[
2
],
n_steps
=
2
)
outputs
=
scan
(
lambda
x
:
(
x
*
x
,
x
),
non_sequences
=
[
2
],
n_steps
=
2
,
return_updates
=
False
)
f
=
function
(
inputs
=
[],
outputs
=
outputs
)
f
=
function
(
inputs
=
[],
outputs
=
outputs
)
outs
=
f
()
outs
=
f
()
...
@@ -583,10 +621,12 @@ class TestPushOutNonSeqScan:
...
@@ -583,10 +621,12 @@ class TestPushOutNonSeqScan:
test_ofg
=
OpFromGraph
([],
[
y
])
test_ofg
=
OpFromGraph
([],
[
y
])
def
inner_func
(
x
):
def
inner_func
(
x
):
out
,
_
=
pytensor
.
scan
(
lambda
:
test_ofg
(),
n_steps
=
x
)
out
=
pytensor
.
scan
(
lambda
:
test_ofg
(),
n_steps
=
x
,
return_updates
=
False
)
return
out
return
out
out
,
_
=
pytensor
.
scan
(
inner_func
,
sequences
=
[
pt
.
arange
(
1
,
2
)])
out
=
pytensor
.
scan
(
inner_func
,
sequences
=
[
pt
.
arange
(
1
,
2
)],
return_updates
=
False
)
_
=
pytensor
.
function
([],
test_ofg
())
_
=
pytensor
.
function
([],
test_ofg
())
...
@@ -612,10 +652,11 @@ class TestPushOutAddScan:
...
@@ -612,10 +652,11 @@ class TestPushOutAddScan:
def
test_sum_dot
(
self
):
def
test_sum_dot
(
self
):
A
=
matrix
(
"A"
)
A
=
matrix
(
"A"
)
B
=
matrix
(
"B"
)
B
=
matrix
(
"B"
)
S
,
_
=
scan
(
S
=
scan
(
lambda
x1
,
x2
,
u
:
u
+
dot
(
x1
,
x2
),
lambda
x1
,
x2
,
u
:
u
+
dot
(
x1
,
x2
),
sequences
=
[
A
.
dimshuffle
(
0
,
1
,
"x"
),
B
.
dimshuffle
(
0
,
"x"
,
1
)],
sequences
=
[
A
.
dimshuffle
(
0
,
1
,
"x"
),
B
.
dimshuffle
(
0
,
"x"
,
1
)],
outputs_info
=
[
pt
.
zeros_like
(
A
)],
outputs_info
=
[
pt
.
zeros_like
(
A
)],
return_updates
=
False
,
)
)
# FIXME: This `s.owner.inputs[0][-1]` is a hack, users will never do that.
# FIXME: This `s.owner.inputs[0][-1]` is a hack, users will never do that.
# They will do `s[-1]` which the rewrite fails to identify since it explicitly looks for a `scan_out[-1]`
# They will do `s[-1]` which the rewrite fails to identify since it explicitly looks for a `scan_out[-1]`
...
@@ -636,13 +677,17 @@ class TestPushOutAddScan:
...
@@ -636,13 +677,17 @@ class TestPushOutAddScan:
bv
=
pt
.
zeros
((
5
,))
bv
=
pt
.
zeros
((
5
,))
bh
=
pt
.
zeros
((
4
,))
bh
=
pt
.
zeros
((
4
,))
v
=
matrix
(
"v"
)
v
=
matrix
(
"v"
)
(
bv_t
,
bh_t
),
_
=
scan
(
(
bv_t
,
bh_t
)
=
scan
(
lambda
_
:
[
bv
,
bh
],
sequences
=
v
,
outputs_info
=
[
None
,
None
]
lambda
_
:
[
bv
,
bh
],
sequences
=
v
,
outputs_info
=
[
None
,
None
],
return_updates
=
False
,
)
)
chain
,
_
=
scan
(
chain
=
scan
(
lambda
x
:
dot
(
dot
(
x
,
W
)
+
bh_t
,
W
.
T
)
+
bv_t
,
lambda
x
:
dot
(
dot
(
x
,
W
)
+
bh_t
,
W
.
T
)
+
bv_t
,
outputs_info
=
v
,
outputs_info
=
v
,
n_steps
=
2
,
n_steps
=
2
,
return_updates
=
False
,
)
)
# TODO FIXME: Make this a real test and assert something.
# TODO FIXME: Make this a real test and assert something.
chain_fn
=
function
([
v
],
chain
)
chain_fn
=
function
([
v
],
chain
)
...
@@ -710,26 +755,28 @@ class TestPushOutAddScan:
...
@@ -710,26 +755,28 @@ class TestPushOutAddScan:
# Compile the function twice, once with the optimization and once
# Compile the function twice, once with the optimization and once
# without
# without
opt_mode
=
mode
.
including
(
"scan"
)
opt_mode
=
mode
.
including
(
"scan"
)
h
,
_
=
pytensor
.
scan
(
h
=
pytensor
.
scan
(
rnn_step1
,
rnn_step1
,
sequences
=
[
x
,
ri
,
zi
],
sequences
=
[
x
,
ri
,
zi
],
n_steps
=
seq_len
,
n_steps
=
seq_len
,
outputs_info
=
init
,
outputs_info
=
init
,
name
=
"fpass1"
,
name
=
"fpass1"
,
mode
=
opt_mode
,
mode
=
opt_mode
,
return_updates
=
False
,
)
)
cost
=
h
[
-
1
]
.
sum
()
cost
=
h
[
-
1
]
.
sum
()
grad1
=
grad
(
cost
,
[
U
,
V
,
W
])
grad1
=
grad
(
cost
,
[
U
,
V
,
W
])
f_opt
=
pytensor
.
function
(
inputs
=
[
x
,
ri
,
zi
],
outputs
=
grad1
,
mode
=
opt_mode
)
f_opt
=
pytensor
.
function
(
inputs
=
[
x
,
ri
,
zi
],
outputs
=
grad1
,
mode
=
opt_mode
)
no_opt_mode
=
mode
.
excluding
(
"scan_pushout_add"
)
no_opt_mode
=
mode
.
excluding
(
"scan_pushout_add"
)
h
,
_
=
pytensor
.
scan
(
h
=
pytensor
.
scan
(
rnn_step1
,
rnn_step1
,
sequences
=
[
x
,
ri
,
zi
],
sequences
=
[
x
,
ri
,
zi
],
n_steps
=
seq_len
,
n_steps
=
seq_len
,
outputs_info
=
init
,
outputs_info
=
init
,
name
=
"fpass1"
,
name
=
"fpass1"
,
mode
=
no_opt_mode
,
mode
=
no_opt_mode
,
return_updates
=
False
,
)
)
cost
=
h
[
-
1
]
.
sum
()
cost
=
h
[
-
1
]
.
sum
()
grad1
=
grad
(
cost
,
[
U
,
V
,
W
])
grad1
=
grad
(
cost
,
[
U
,
V
,
W
])
...
@@ -773,21 +820,23 @@ class TestPushOutAddScan:
...
@@ -773,21 +820,23 @@ class TestPushOutAddScan:
# Compile the function twice, once with the optimization and once without
# Compile the function twice, once with the optimization and once without
opt_mode
=
mode
.
including
(
"scan"
)
opt_mode
=
mode
.
including
(
"scan"
)
h
,
_
=
pytensor
.
scan
(
h
=
pytensor
.
scan
(
inner_fct
,
inner_fct
,
sequences
=
[
input1
,
input2
,
input3
],
sequences
=
[
input1
,
input2
,
input3
],
outputs_info
=
init
,
outputs_info
=
init
,
mode
=
opt_mode
,
mode
=
opt_mode
,
return_updates
=
False
,
)
)
output
=
h
[
-
1
]
output
=
h
[
-
1
]
f_opt
=
pytensor
.
function
([
input1
,
input2
,
input3
],
output
,
mode
=
opt_mode
)
f_opt
=
pytensor
.
function
([
input1
,
input2
,
input3
],
output
,
mode
=
opt_mode
)
no_opt_mode
=
mode
.
excluding
(
"scan_pushout_add"
)
no_opt_mode
=
mode
.
excluding
(
"scan_pushout_add"
)
h
,
_
=
pytensor
.
scan
(
h
=
pytensor
.
scan
(
inner_fct
,
inner_fct
,
sequences
=
[
input1
,
input2
,
input3
],
sequences
=
[
input1
,
input2
,
input3
],
outputs_info
=
init
,
outputs_info
=
init
,
mode
=
no_opt_mode
,
mode
=
no_opt_mode
,
return_updates
=
False
,
)
)
output
=
h
[
-
1
]
output
=
h
[
-
1
]
f_no_opt
=
pytensor
.
function
([
input1
,
input2
,
input3
],
output
,
mode
=
no_opt_mode
)
f_no_opt
=
pytensor
.
function
([
input1
,
input2
,
input3
],
output
,
mode
=
no_opt_mode
)
...
@@ -892,13 +941,20 @@ class TestScanMerge:
...
@@ -892,13 +941,20 @@ class TestScanMerge:
"""
"""
inps
=
vector
()
inps
=
vector
()
state
=
scalar
()
state
=
scalar
()
y1
,
_
=
scan
(
lambda
x
,
y
:
x
*
y
,
sequences
=
inps
,
outputs_info
=
state
,
n_steps
=
5
)
y1
=
scan
(
lambda
x
,
y
:
x
*
y
,
sequences
=
inps
,
outputs_info
=
state
,
n_steps
=
5
,
return_updates
=
False
,
)
y2
,
_
=
scan
(
y2
=
scan
(
lambda
x
,
y
:
(
x
+
y
,
until
(
x
>
0
)),
lambda
x
,
y
:
(
x
+
y
,
until
(
x
>
0
)),
sequences
=
inps
,
sequences
=
inps
,
outputs_info
=
state
,
outputs_info
=
state
,
n_steps
=
5
,
n_steps
=
5
,
return_updates
=
False
,
)
)
scan_node1
=
y1
.
owner
.
inputs
[
0
]
.
owner
scan_node1
=
y1
.
owner
.
inputs
[
0
]
.
owner
assert
isinstance
(
scan_node1
.
op
,
Scan
)
assert
isinstance
(
scan_node1
.
op
,
Scan
)
...
@@ -958,8 +1014,8 @@ class TestScanMerge:
...
@@ -958,8 +1014,8 @@ class TestScanMerge:
def
sub
(
s1
,
s2
,
const
):
def
sub
(
s1
,
s2
,
const
):
return
s1
-
1
,
until
(
s2
>
const
)
return
s1
-
1
,
until
(
s2
>
const
)
sx
,
_
=
scan
(
add
,
sequences
=
[
x
,
z
],
non_sequences
=
[
c1
]
)
sx
=
scan
(
add
,
sequences
=
[
x
,
z
],
non_sequences
=
[
c1
],
return_updates
=
False
)
sy
,
_
=
scan
(
sub
,
sequences
=
[
y
,
-
z
],
non_sequences
=
[
c1
]
)
sy
=
scan
(
sub
,
sequences
=
[
y
,
-
z
],
non_sequences
=
[
c1
],
return_updates
=
False
)
f
=
pytensor
.
function
(
inputs
=
[
x
,
y
,
z
,
c1
],
outputs
=
[
sx
,
sy
],
mode
=
self
.
mode
)
f
=
pytensor
.
function
(
inputs
=
[
x
,
y
,
z
,
c1
],
outputs
=
[
sx
,
sy
],
mode
=
self
.
mode
)
assert
self
.
count_scans
(
f
)
==
2
assert
self
.
count_scans
(
f
)
==
2
...
@@ -972,8 +1028,8 @@ class TestScanMerge:
...
@@ -972,8 +1028,8 @@ class TestScanMerge:
np
.
testing
.
assert_array_equal
(
res_sx
,
[
1
,
1
])
np
.
testing
.
assert_array_equal
(
res_sx
,
[
1
,
1
])
np
.
testing
.
assert_array_equal
(
res_sy
,
[
-
1
,
-
1
,
-
1
,
-
1
,
-
1
])
np
.
testing
.
assert_array_equal
(
res_sy
,
[
-
1
,
-
1
,
-
1
,
-
1
,
-
1
])
sx
,
_
=
scan
(
add
,
sequences
=
[
x
,
z
],
non_sequences
=
[
c1
]
)
sx
=
scan
(
add
,
sequences
=
[
x
,
z
],
non_sequences
=
[
c1
],
return_updates
=
False
)
sy
,
_
=
scan
(
sub
,
sequences
=
[
y
,
z
],
non_sequences
=
[
c2
]
)
sy
=
scan
(
sub
,
sequences
=
[
y
,
z
],
non_sequences
=
[
c2
],
return_updates
=
False
)
f
=
pytensor
.
function
(
f
=
pytensor
.
function
(
inputs
=
[
x
,
y
,
z
,
c1
,
c2
],
outputs
=
[
sx
,
sy
],
mode
=
self
.
mode
inputs
=
[
x
,
y
,
z
,
c1
,
c2
],
outputs
=
[
sx
,
sy
],
mode
=
self
.
mode
...
@@ -989,22 +1045,23 @@ class TestScanMerge:
...
@@ -989,22 +1045,23 @@ class TestScanMerge:
np
.
testing
.
assert_array_equal
(
res_sx
,
[
1
,
1
,
1
,
1
,
1
])
np
.
testing
.
assert_array_equal
(
res_sx
,
[
1
,
1
,
1
,
1
,
1
])
np
.
testing
.
assert_array_equal
(
res_sy
,
[
-
1
,
-
1
,
-
1
])
np
.
testing
.
assert_array_equal
(
res_sy
,
[
-
1
,
-
1
,
-
1
])
sx
,
_
=
scan
(
add
,
sequences
=
[
x
,
z
],
non_sequences
=
[
c1
]
)
sx
=
scan
(
add
,
sequences
=
[
x
,
z
],
non_sequences
=
[
c1
],
return_updates
=
False
)
sy
,
_
=
scan
(
sub
,
sequences
=
[
y
,
z
],
non_sequences
=
[
c1
]
)
sy
=
scan
(
sub
,
sequences
=
[
y
,
z
],
non_sequences
=
[
c1
],
return_updates
=
False
)
f
=
pytensor
.
function
(
inputs
=
[
x
,
y
,
z
,
c1
],
outputs
=
[
sx
,
sy
],
mode
=
self
.
mode
)
f
=
pytensor
.
function
(
inputs
=
[
x
,
y
,
z
,
c1
],
outputs
=
[
sx
,
sy
],
mode
=
self
.
mode
)
assert
self
.
count_scans
(
f
)
==
1
assert
self
.
count_scans
(
f
)
==
1
def
nested_scan
(
c
,
x
,
z
):
def
nested_scan
(
c
,
x
,
z
):
sx
,
_
=
scan
(
add
,
sequences
=
[
x
,
z
],
non_sequences
=
[
c
]
)
sx
=
scan
(
add
,
sequences
=
[
x
,
z
],
non_sequences
=
[
c
],
return_updates
=
False
)
sy
,
_
=
scan
(
sub
,
sequences
=
[
x
,
z
],
non_sequences
=
[
c
]
)
sy
=
scan
(
sub
,
sequences
=
[
x
,
z
],
non_sequences
=
[
c
],
return_updates
=
False
)
return
sx
.
sum
()
+
sy
.
sum
()
return
sx
.
sum
()
+
sy
.
sum
()
sz
,
_
=
scan
(
sz
=
scan
(
nested_scan
,
nested_scan
,
sequences
=
[
stack
([
c1
,
c2
])],
sequences
=
[
stack
([
c1
,
c2
])],
non_sequences
=
[
x
,
z
],
non_sequences
=
[
x
,
z
],
mode
=
self
.
mode
,
mode
=
self
.
mode
,
return_updates
=
False
,
)
)
f
=
pytensor
.
function
(
inputs
=
[
x
,
z
,
c1
,
c2
],
outputs
=
sz
,
mode
=
mode
)
f
=
pytensor
.
function
(
inputs
=
[
x
,
z
,
c1
,
c2
],
outputs
=
sz
,
mode
=
mode
)
...
@@ -1023,9 +1080,8 @@ class TestScanInplaceOptimizer:
...
@@ -1023,9 +1080,8 @@ class TestScanInplaceOptimizer:
x
=
pt
.
vector
(
"x"
)
x
=
pt
.
vector
(
"x"
)
scan_out
,
_
=
pytensor
.
scan
(
scan_out
=
pytensor
.
scan
(
lambda
x
:
(
x
+
1
)
/
2
+
1
,
lambda
x
:
(
x
+
1
)
/
2
+
1
,
sequences
=
[
x
],
return_updates
=
False
sequences
=
[
x
],
)
)
fgraph
=
FunctionGraph
(
fgraph
=
FunctionGraph
(
...
@@ -1039,10 +1095,8 @@ class TestScanInplaceOptimizer:
...
@@ -1039,10 +1095,8 @@ class TestScanInplaceOptimizer:
assert
equal_computations
([
scan_out
],
fgraph
.
outputs
)
assert
equal_computations
([
scan_out
],
fgraph
.
outputs
)
def
test_inplace_basic
(
self
):
def
test_inplace_basic
(
self
):
scan_out
,
_
=
pytensor
.
scan
(
scan_out
=
pytensor
.
scan
(
lambda
x
:
x
+
1
,
lambda
x
:
x
+
1
,
outputs_info
=
[
pt
.
zeros
(
1
)],
n_steps
=
3
,
return_updates
=
False
outputs_info
=
[
pt
.
zeros
(
1
)],
n_steps
=
3
,
)
)
fgraph
=
FunctionGraph
(
fgraph
=
FunctionGraph
(
...
@@ -1089,7 +1143,7 @@ class TestScanInplaceOptimizer:
...
@@ -1089,7 +1143,7 @@ class TestScanInplaceOptimizer:
u0_t
*
W_in
+
x1_tm1
*
W
+
u1_t
+
u2_t
,
u0_t
*
W_in
+
x1_tm1
*
W
+
u1_t
+
u2_t
,
]
]
outputs
,
updates
=
scan
(
outputs
=
scan
(
f_rnn_shared
,
f_rnn_shared
,
[
u0
,
u1
,
u2
],
[
u0
,
u1
,
u2
],
[
dict
(
initial
=
x0
,
inplace
=
u2
),
dict
(
initial
=
x1
,
inplace
=
u1
)],
[
dict
(
initial
=
x0
,
inplace
=
u2
),
dict
(
initial
=
x1
,
inplace
=
u1
)],
...
@@ -1098,12 +1152,12 @@ class TestScanInplaceOptimizer:
...
@@ -1098,12 +1152,12 @@ class TestScanInplaceOptimizer:
truncate_gradient
=-
1
,
truncate_gradient
=-
1
,
go_backwards
=
False
,
go_backwards
=
False
,
mode
=
self
.
mode
,
mode
=
self
.
mode
,
return_updates
=
False
,
)
)
f9
=
function
(
f9
=
function
(
[
mu0
,
mu1
,
mu2
,
x0
,
x1
],
[
mu0
,
mu1
,
mu2
,
x0
,
x1
],
outputs
,
outputs
,
updates
=
updates
,
mode
=
self
.
mode
,
mode
=
self
.
mode
,
allow_input_downcast
=
True
,
allow_input_downcast
=
True
,
)
)
...
@@ -1155,7 +1209,7 @@ class TestScanInplaceOptimizer:
...
@@ -1155,7 +1209,7 @@ class TestScanInplaceOptimizer:
u0_t
*
W_in
+
x1_tm1
*
W
+
u2_tm1
+
u2_t
+
u2_tp1
,
u0_t
*
W_in
+
x1_tm1
*
W
+
u2_tm1
+
u2_t
+
u2_tp1
,
]
]
outputs
,
updates
=
scan
(
outputs
=
scan
(
f_rnn_shared
,
f_rnn_shared
,
[
u0
,
dict
(
input
=
u1
,
taps
=
[
0
,
1
]),
dict
(
input
=
u2
,
taps
=
[
-
1
,
0
,
+
1
])],
[
u0
,
dict
(
input
=
u1
,
taps
=
[
0
,
1
]),
dict
(
input
=
u2
,
taps
=
[
-
1
,
0
,
+
1
])],
[
dict
(
initial
=
x0
),
dict
(
initial
=
x1
)],
[
dict
(
initial
=
x0
),
dict
(
initial
=
x1
)],
...
@@ -1164,11 +1218,11 @@ class TestScanInplaceOptimizer:
...
@@ -1164,11 +1218,11 @@ class TestScanInplaceOptimizer:
truncate_gradient
=-
1
,
truncate_gradient
=-
1
,
go_backwards
=
False
,
go_backwards
=
False
,
mode
=
self
.
mode
,
mode
=
self
.
mode
,
return_updates
=
False
,
)
)
f9
=
function
(
f9
=
function
(
[
mu0
,
mu1
,
mu2
,
x0
,
x1
],
[
mu0
,
mu1
,
mu2
,
x0
,
x1
],
outputs
,
outputs
,
updates
=
updates
,
mode
=
self
.
mode
,
mode
=
self
.
mode
,
allow_input_downcast
=
True
,
allow_input_downcast
=
True
,
)
)
...
@@ -1202,8 +1256,12 @@ class TestScanInplaceOptimizer:
...
@@ -1202,8 +1256,12 @@ class TestScanInplaceOptimizer:
vx1
=
asarrayX
(
rng
.
uniform
())
vx1
=
asarrayX
(
rng
.
uniform
())
x0
=
shared
(
vx0
)
x0
=
shared
(
vx0
)
x1
=
shared
(
vx1
)
x1
=
shared
(
vx1
)
outputs
,
updates
=
scan
(
outputs
=
scan
(
lambda
x
,
y
:
(
x
+
asarrayX
(
1
),
y
+
asarrayX
(
1
)),
[],
[
x0
,
x1
],
n_steps
=
3
lambda
x
,
y
:
(
x
+
asarrayX
(
1
),
y
+
asarrayX
(
1
)),
[],
[
x0
,
x1
],
n_steps
=
3
,
return_updates
=
False
,
)
)
x0
=
asarrayX
(
np
.
zeros
((
4
,)))
x0
=
asarrayX
(
np
.
zeros
((
4
,)))
x0
[
0
]
=
vx0
x0
[
0
]
=
vx0
...
@@ -1212,7 +1270,7 @@ class TestScanInplaceOptimizer:
...
@@ -1212,7 +1270,7 @@ class TestScanInplaceOptimizer:
to_replace
=
outputs
[
0
]
.
owner
.
inputs
[
0
]
.
owner
.
inputs
[
1
]
to_replace
=
outputs
[
0
]
.
owner
.
inputs
[
0
]
.
owner
.
inputs
[
1
]
outputs
=
clone_replace
(
outputs
,
replace
=
[(
to_replace
,
x0
)])
outputs
=
clone_replace
(
outputs
,
replace
=
[(
to_replace
,
x0
)])
f9
=
function
([],
outputs
,
updates
=
updates
,
mode
=
self
.
mode
)
f9
=
function
([],
outputs
,
mode
=
self
.
mode
)
scan_node
=
[
x
for
x
in
f9
.
maker
.
fgraph
.
toposort
()
if
isinstance
(
x
.
op
,
Scan
)]
scan_node
=
[
x
for
x
in
f9
.
maker
.
fgraph
.
toposort
()
if
isinstance
(
x
.
op
,
Scan
)]
assert
0
not
in
scan_node
[
0
]
.
op
.
destroy_map
assert
0
not
in
scan_node
[
0
]
.
op
.
destroy_map
assert
1
in
scan_node
[
0
]
.
op
.
destroy_map
assert
1
in
scan_node
[
0
]
.
op
.
destroy_map
...
@@ -1249,7 +1307,7 @@ class TestSaveMem:
...
@@ -1249,7 +1307,7 @@ class TestSaveMem:
y_tm1
+
dot
(
x_tm1
,
W_out
),
y_tm1
+
dot
(
x_tm1
,
W_out
),
]
]
_outputs
,
update
s
=
scan
(
out
s
=
scan
(
f_rnn_cmpl
,
f_rnn_cmpl
,
[
u1
,
u2
],
[
u1
,
u2
],
[
None
,
dict
(
initial
=
x0
),
dict
(
initial
=
y0
,
taps
=
[
-
1
,
-
3
])],
[
None
,
dict
(
initial
=
x0
),
dict
(
initial
=
y0
,
taps
=
[
-
1
,
-
3
])],
...
@@ -1257,12 +1315,12 @@ class TestSaveMem:
...
@@ -1257,12 +1315,12 @@ class TestSaveMem:
n_steps
=
None
,
n_steps
=
None
,
truncate_gradient
=-
1
,
truncate_gradient
=-
1
,
go_backwards
=
False
,
go_backwards
=
False
,
return_updates
=
False
,
)
)
outputs
=
[
_outputs
[
0
][
-
1
],
_outputs
[
1
][
-
1
],
_outp
uts
[
2
][
-
1
]]
outputs
=
[
outs
[
0
][
-
1
],
outs
[
1
][
-
1
],
o
uts
[
2
][
-
1
]]
f4
=
function
(
f4
=
function
(
[
u1
,
u2
,
x0
,
y0
,
W_in1
],
[
u1
,
u2
,
x0
,
y0
,
W_in1
],
outputs
,
outputs
,
updates
=
updates
,
allow_input_downcast
=
True
,
allow_input_downcast
=
True
,
mode
=
self
.
mode
,
mode
=
self
.
mode
,
)
)
...
@@ -1297,14 +1355,18 @@ class TestSaveMem:
...
@@ -1297,14 +1355,18 @@ class TestSaveMem:
u
=
vector
(
"u"
)
u
=
vector
(
"u"
)
idx
=
iscalar
(
"idx"
)
idx
=
iscalar
(
"idx"
)
jdx
=
iscalar
(
"jdx"
)
jdx
=
iscalar
(
"jdx"
)
[
x1
,
x2
,
x3
,
x4
,
x5
,
x6
,
x7
],
updates
=
scan
(
[
x1
,
x2
,
x3
,
x4
,
x5
,
x6
,
x7
]
=
scan
(
f_rnn
,
u
,
n_steps
=
None
,
truncate_gradient
=-
1
,
go_backwards
=
False
f_rnn
,
u
,
n_steps
=
None
,
truncate_gradient
=-
1
,
go_backwards
=
False
,
return_updates
=
False
,
)
)
f2
=
function
(
f2
=
function
(
[
u
,
idx
,
jdx
],
[
u
,
idx
,
jdx
],
[
x1
[:
2
],
x2
[
4
],
x3
[
idx
],
x4
[:
idx
],
x5
[
-
10
],
x6
[
-
jdx
],
x7
[:
-
jdx
]],
[
x1
[:
2
],
x2
[
4
],
x3
[
idx
],
x4
[:
idx
],
x5
[
-
10
],
x6
[
-
jdx
],
x7
[:
-
jdx
]],
updates
=
updates
,
allow_input_downcast
=
True
,
allow_input_downcast
=
True
,
mode
=
self
.
mode
.
excluding
(
"scan_push_out_seq"
),
mode
=
self
.
mode
.
excluding
(
"scan_push_out_seq"
),
)
)
...
@@ -1341,10 +1403,8 @@ class TestSaveMem:
...
@@ -1341,10 +1403,8 @@ class TestSaveMem:
def
test_save_mem_reduced_number_of_steps_constant
(
self
):
def
test_save_mem_reduced_number_of_steps_constant
(
self
):
x0
=
pt
.
scalar
(
"x0"
)
x0
=
pt
.
scalar
(
"x0"
)
xs
,
_
=
scan
(
xs
=
scan
(
lambda
xtm1
:
xtm1
+
1
,
lambda
xtm1
:
xtm1
+
1
,
outputs_info
=
[
x0
],
n_steps
=
10
,
return_updates
=
False
outputs_info
=
[
x0
],
n_steps
=
10
,
)
)
fn
=
function
([
x0
],
xs
[:
5
],
mode
=
self
.
mode
)
fn
=
function
([
x0
],
xs
[:
5
],
mode
=
self
.
mode
)
...
@@ -1358,10 +1418,11 @@ class TestSaveMem:
...
@@ -1358,10 +1418,11 @@ class TestSaveMem:
def
test_save_mem_cannot_reduce_constant_number_of_steps
(
self
):
def
test_save_mem_cannot_reduce_constant_number_of_steps
(
self
):
x0
=
pt
.
scalar
(
"x0"
)
x0
=
pt
.
scalar
(
"x0"
)
[
xs
,
ys
]
,
_
=
scan
(
[
xs
,
ys
]
=
scan
(
lambda
xtm1
,
ytm1
:
(
xtm1
+
1
,
ytm1
-
1
),
lambda
xtm1
,
ytm1
:
(
xtm1
+
1
,
ytm1
-
1
),
outputs_info
=
[
x0
,
x0
],
outputs_info
=
[
x0
,
x0
],
n_steps
=
10
,
n_steps
=
10
,
return_updates
=
False
,
)
)
# Because of ys[-1] we need all the steps!
# Because of ys[-1] we need all the steps!
...
@@ -1399,7 +1460,7 @@ class TestSaveMem:
...
@@ -1399,7 +1460,7 @@ class TestSaveMem:
x20
=
scalar
(
"x20"
)
x20
=
scalar
(
"x20"
)
x30
=
vector
(
"x30"
)
x30
=
vector
(
"x30"
)
x40
=
scalar
(
"x40"
)
x40
=
scalar
(
"x40"
)
[
x1
,
x2
,
x3
,
x4
,
x5
,
_x6
,
_x7
]
,
updates
=
scan
(
[
x1
,
x2
,
x3
,
x4
,
x5
,
_x6
,
_x7
]
=
scan
(
step
,
step
,
u
,
u
,
[
[
...
@@ -1414,12 +1475,12 @@ class TestSaveMem:
...
@@ -1414,12 +1475,12 @@ class TestSaveMem:
n_steps
=
None
,
n_steps
=
None
,
truncate_gradient
=-
1
,
truncate_gradient
=-
1
,
go_backwards
=
False
,
go_backwards
=
False
,
return_updates
=
False
,
)
)
f
=
function
(
f
=
function
(
[
u
,
x10
,
x20
,
x30
,
x40
],
[
u
,
x10
,
x20
,
x30
,
x40
],
[
x1
[
-
7
],
x2
[
-
3
:
-
1
],
x3
[
-
6
:],
x4
[
-
1
],
x5
[
-
1
]],
[
x1
[
-
7
],
x2
[
-
3
:
-
1
],
x3
[
-
6
:],
x4
[
-
1
],
x5
[
-
1
]],
updates
=
updates
,
allow_input_downcast
=
True
,
allow_input_downcast
=
True
,
mode
=
self
.
mode
,
mode
=
self
.
mode
,
)
)
...
@@ -1479,10 +1540,11 @@ class TestSaveMem:
...
@@ -1479,10 +1540,11 @@ class TestSaveMem:
def
test_savemem_does_not_duplicate_number_of_scan_nodes
(
self
):
def
test_savemem_does_not_duplicate_number_of_scan_nodes
(
self
):
var
=
pt
.
ones
(())
var
=
pt
.
ones
(())
values
,
_
=
scan
(
values
=
scan
(
lambda
x
:
([
x
],
(),
until
(
x
)),
lambda
x
:
([
x
],
(),
until
(
x
)),
outputs_info
=
[
var
],
outputs_info
=
[
var
],
n_steps
=
2
,
n_steps
=
2
,
return_updates
=
False
,
)
)
tmp_fn
=
function
([
var
],
values
,
mode
=
self
.
mode
)
tmp_fn
=
function
([
var
],
values
,
mode
=
self
.
mode
)
...
@@ -1493,10 +1555,11 @@ class TestSaveMem:
...
@@ -1493,10 +1555,11 @@ class TestSaveMem:
def
test_savemem_opt
(
self
,
benchmark
):
def
test_savemem_opt
(
self
,
benchmark
):
y0
=
shared
(
np
.
ones
((
2
,
10
)))
y0
=
shared
(
np
.
ones
((
2
,
10
)))
[
_y1
,
y2
]
,
_updates
=
scan
(
[
_y1
,
y2
]
=
scan
(
lambda
y
:
[
y
,
y
],
lambda
y
:
[
y
,
y
],
outputs_info
=
[
dict
(
initial
=
y0
,
taps
=
[
-
2
]),
None
],
outputs_info
=
[
dict
(
initial
=
y0
,
taps
=
[
-
2
]),
None
],
n_steps
=
5
,
n_steps
=
5
,
return_updates
=
False
,
)
)
# TODO FIXME: Make this a real test and assert something.
# TODO FIXME: Make this a real test and assert something.
fn
=
function
([],
y2
.
sum
(),
mode
=
self
.
mode
)
fn
=
function
([],
y2
.
sum
(),
mode
=
self
.
mode
)
...
@@ -1515,23 +1578,25 @@ class TestSaveMem:
...
@@ -1515,23 +1578,25 @@ class TestSaveMem:
return
dot
(
h_tm1
,
w
)
+
x_t_t
return
dot
(
h_tm1
,
w
)
+
x_t_t
def
outer_scan_step
(
x_t
,
w
):
def
outer_scan_step
(
x_t
,
w
):
h
,
_
=
scan
(
h
=
scan
(
inner_scan_step
,
inner_scan_step
,
sequences
=
[
x_t
[
1
:]],
sequences
=
[
x_t
[
1
:]],
outputs_info
=
[
x_t
[
0
]],
outputs_info
=
[
x_t
[
0
]],
non_sequences
=
[
w
],
non_sequences
=
[
w
],
strict
=
True
,
strict
=
True
,
name
=
"the_inner_scan"
,
name
=
"the_inner_scan"
,
return_updates
=
False
,
)
)
return
h
return
h
def
get_outputs
(
x
,
w
):
def
get_outputs
(
x
,
w
):
features
,
_
=
scan
(
features
=
scan
(
outer_scan_step
,
outer_scan_step
,
sequences
=
[
x
],
sequences
=
[
x
],
non_sequences
=
[
w
],
non_sequences
=
[
w
],
strict
=
True
,
strict
=
True
,
name
=
"the_outer_scan"
,
name
=
"the_outer_scan"
,
return_updates
=
False
,
)
)
return_val
=
grad
(
features
.
sum
(),
w
)
return_val
=
grad
(
features
.
sum
(),
w
)
...
@@ -1571,7 +1636,7 @@ class TestSaveMem:
...
@@ -1571,7 +1636,7 @@ class TestSaveMem:
state
=
vector
(
"state"
)
state
=
vector
(
"state"
)
n_steps
=
iscalar
(
"nsteps"
)
n_steps
=
iscalar
(
"nsteps"
)
output
,
updates
=
scan
(
output
=
scan
(
f_pow2
,
f_pow2
,
[],
[],
state
,
state
,
...
@@ -1579,13 +1644,13 @@ class TestSaveMem:
...
@@ -1579,13 +1644,13 @@ class TestSaveMem:
n_steps
=
n_steps
,
n_steps
=
n_steps
,
truncate_gradient
=-
1
,
truncate_gradient
=-
1
,
go_backwards
=
False
,
go_backwards
=
False
,
return_updates
=
False
,
)
)
nw_shape
=
ivector
(
"nw_shape"
)
nw_shape
=
ivector
(
"nw_shape"
)
# Note that the output is reshaped to 3 dimensional tensor, and
# Note that the output is reshaped to 3 dimensional tensor, and
my_f
=
function
(
my_f
=
function
(
[
state
,
n_steps
,
nw_shape
],
[
state
,
n_steps
,
nw_shape
],
[
reshape
(
output
,
nw_shape
,
ndim
=
3
)[:
-
2
],
output
[:
-
4
]],
[
reshape
(
output
,
nw_shape
,
ndim
=
3
)[:
-
2
],
output
[:
-
4
]],
updates
=
updates
,
allow_input_downcast
=
True
,
allow_input_downcast
=
True
,
)
)
nodes
=
[
x
for
x
in
my_f
.
maker
.
fgraph
.
toposort
()
if
isinstance
(
x
.
op
,
Scan
)]
nodes
=
[
x
for
x
in
my_f
.
maker
.
fgraph
.
toposort
()
if
isinstance
(
x
.
op
,
Scan
)]
...
@@ -1599,11 +1664,12 @@ class TestSaveMem:
...
@@ -1599,11 +1664,12 @@ class TestSaveMem:
n_steps
=
scalar
(
"n_steps"
,
dtype
=
"int64"
)
n_steps
=
scalar
(
"n_steps"
,
dtype
=
"int64"
)
x0
=
vector
(
"x0"
)
x0
=
vector
(
"x0"
)
ys
,
_
=
pytensor
.
scan
(
ys
=
pytensor
.
scan
(
# Fibonacci Sequence
# Fibonacci Sequence
lambda
xtm2
,
xtm1
:
(
xtm1
+
xtm2
,
{},
until
(
xtm1
>=
34
)),
lambda
xtm2
,
xtm1
:
(
xtm1
+
xtm2
,
{},
until
(
xtm1
>=
34
)),
outputs_info
=
[{
"initial"
:
x0
,
"taps"
:
[
-
2
,
-
1
]}],
outputs_info
=
[{
"initial"
:
x0
,
"taps"
:
[
-
2
,
-
1
]}],
n_steps
=
n_steps
,
n_steps
=
n_steps
,
return_updates
=
False
,
)
)
# Save memory is triggered by choosing only last value
# Save memory is triggered by choosing only last value
y
=
ys
[
-
1
]
y
=
ys
[
-
1
]
...
@@ -1629,10 +1695,11 @@ class TestSaveMem:
...
@@ -1629,10 +1695,11 @@ class TestSaveMem:
def
test_while_scan_map
(
self
):
def
test_while_scan_map
(
self
):
xs
=
vector
(
"xs"
)
xs
=
vector
(
"xs"
)
ys
,
_
=
pytensor
.
scan
(
ys
=
pytensor
.
scan
(
lambda
x
:
(
x
+
1
,
{},
until
(
x
+
1
>=
10
)),
lambda
x
:
(
x
+
1
,
{},
until
(
x
+
1
>=
10
)),
outputs_info
=
[
None
],
outputs_info
=
[
None
],
sequences
=
[
xs
],
sequences
=
[
xs
],
return_updates
=
False
,
)
)
# Save memory is triggered by choosing only last value
# Save memory is triggered by choosing only last value
y
=
ys
[
-
1
]
y
=
ys
[
-
1
]
...
@@ -1656,11 +1723,12 @@ class TestSaveMem:
...
@@ -1656,11 +1723,12 @@ class TestSaveMem:
n_steps
=
scalar
(
"n_steps"
,
dtype
=
"int64"
)
n_steps
=
scalar
(
"n_steps"
,
dtype
=
"int64"
)
# while loop
# while loop
[
ys
,
zs
]
,
_
=
pytensor
.
scan
(
[
ys
,
zs
]
=
pytensor
.
scan
(
lambda
s
,
xtm1
:
((
xtm1
+
1
,
xtm1
+
1
+
s
),
{},
until
(
xtm1
>=
99
)),
lambda
s
,
xtm1
:
((
xtm1
+
1
,
xtm1
+
1
+
s
),
{},
until
(
xtm1
>=
99
)),
sequences
=
[
seq
],
sequences
=
[
seq
],
outputs_info
=
[
x0
,
None
],
outputs_info
=
[
x0
,
None
],
n_steps
=
n_steps
,
n_steps
=
n_steps
,
return_updates
=
False
,
)
)
# Save memory is triggered by choosing only last value
# Save memory is triggered by choosing only last value
y
=
ys
[
-
1
]
y
=
ys
[
-
1
]
...
@@ -1696,10 +1764,11 @@ class TestSaveMem:
...
@@ -1696,10 +1764,11 @@ class TestSaveMem:
val_test
=
np
.
zeros
(
val_shape
,
dtype
=
val
.
dtype
)
val_test
=
np
.
zeros
(
val_shape
,
dtype
=
val
.
dtype
)
init
=
pt
.
full
((
2
,),
val
)
init
=
pt
.
full
((
2
,),
val
)
ys
,
_
=
pytensor
.
scan
(
ys
=
pytensor
.
scan
(
fn
=
lambda
*
args
:
pt
.
add
(
*
args
),
fn
=
lambda
*
args
:
pt
.
add
(
*
args
),
outputs_info
=
[{
"initial"
:
init
,
"taps"
:
(
-
2
,
-
1
)}],
outputs_info
=
[{
"initial"
:
init
,
"taps"
:
(
-
2
,
-
1
)}],
n_steps
=
100
,
n_steps
=
100
,
return_updates
=
False
,
)
)
out
=
ys
[:
-
50
]
if
keep_beginning
else
ys
[
-
50
:]
out
=
ys
[:
-
50
]
if
keep_beginning
else
ys
[
-
50
:]
...
@@ -1729,12 +1798,13 @@ def test_inner_replace_dot():
...
@@ -1729,12 +1798,13 @@ def test_inner_replace_dot():
mode
=
get_default_mode
()
.
including
(
"scan"
)
# .excluding("BlasOpt")
mode
=
get_default_mode
()
.
including
(
"scan"
)
# .excluding("BlasOpt")
o
,
_
=
scan
(
o
=
scan
(
lambda
hi
,
him1
,
W
:
(
hi
,
dot
(
hi
+
him1
,
W
)),
lambda
hi
,
him1
,
W
:
(
hi
,
dot
(
hi
+
him1
,
W
)),
outputs_info
=
[
pt
.
zeros
([
h
.
shape
[
1
]]),
None
],
outputs_info
=
[
pt
.
zeros
([
h
.
shape
[
1
]]),
None
],
sequences
=
[
h
],
sequences
=
[
h
],
non_sequences
=
[
W
],
non_sequences
=
[
W
],
mode
=
mode
,
mode
=
mode
,
return_updates
=
False
,
)
)
f
=
function
([
W
,
h
],
o
,
mode
=
mode
)
f
=
function
([
W
,
h
],
o
,
mode
=
mode
)
...
@@ -1753,11 +1823,12 @@ def test_alloc_inputs1():
...
@@ -1753,11 +1823,12 @@ def test_alloc_inputs1():
def
lambda_fn
(
h
,
W1
,
W2
):
def
lambda_fn
(
h
,
W1
,
W2
):
return
dot
(
h
,
W1
*
W2
)
return
dot
(
h
,
W1
*
W2
)
o
,
_
=
scan
(
o
=
scan
(
lambda_fn
,
lambda_fn
,
outputs_info
=
h0
,
outputs_info
=
h0
,
non_sequences
=
[
W1
,
pt
.
zeros_like
(
W2
)],
non_sequences
=
[
W1
,
pt
.
zeros_like
(
W2
)],
n_steps
=
5
,
n_steps
=
5
,
return_updates
=
False
,
)
)
f
=
function
([
h0
,
W1
,
W2
],
o
,
mode
=
get_default_mode
()
.
including
(
"scan"
))
f
=
function
([
h0
,
W1
,
W2
],
o
,
mode
=
get_default_mode
()
.
including
(
"scan"
))
...
@@ -1786,12 +1857,13 @@ def test_alloc_inputs2():
...
@@ -1786,12 +1857,13 @@ def test_alloc_inputs2():
def
lambda_fn
(
W1
,
h
,
W2
):
def
lambda_fn
(
W1
,
h
,
W2
):
return
W1
*
dot
(
h
,
W2
)
return
W1
*
dot
(
h
,
W2
)
o
,
_
=
scan
(
o
=
scan
(
lambda_fn
,
lambda_fn
,
sequences
=
pt
.
zeros_like
(
W1
),
sequences
=
pt
.
zeros_like
(
W1
),
outputs_info
=
h0
,
outputs_info
=
h0
,
non_sequences
=
[
pt
.
zeros_like
(
W2
)],
non_sequences
=
[
pt
.
zeros_like
(
W2
)],
n_steps
=
5
,
n_steps
=
5
,
return_updates
=
False
,
)
)
f
=
function
([
h0
,
W1
,
W2
],
o
,
mode
=
get_default_mode
()
.
including
(
"scan"
))
f
=
function
([
h0
,
W1
,
W2
],
o
,
mode
=
get_default_mode
()
.
including
(
"scan"
))
...
@@ -1821,12 +1893,13 @@ def test_alloc_inputs3():
...
@@ -1821,12 +1893,13 @@ def test_alloc_inputs3():
def
lambda_fn
(
W1
,
h
,
W2
):
def
lambda_fn
(
W1
,
h
,
W2
):
return
W1
*
dot
(
h
,
W2
)
return
W1
*
dot
(
h
,
W2
)
o
,
_
=
scan
(
o
=
scan
(
lambda_fn
,
lambda_fn
,
sequences
=
pt
.
zeros_like
(
W1
),
sequences
=
pt
.
zeros_like
(
W1
),
outputs_info
=
h0
,
outputs_info
=
h0
,
non_sequences
=
[
pt
.
zeros_like
(
W2
)],
non_sequences
=
[
pt
.
zeros_like
(
W2
)],
n_steps
=
5
,
n_steps
=
5
,
return_updates
=
False
,
)
)
# TODO FIXME: This result depends on unrelated rewrites in the "fast" mode.
# TODO FIXME: This result depends on unrelated rewrites in the "fast" mode.
...
@@ -1848,7 +1921,7 @@ def test_opt_order():
...
@@ -1848,7 +1921,7 @@ def test_opt_order():
x
=
matrix
(
"x"
)
x
=
matrix
(
"x"
)
A
=
matrix
(
"A"
)
A
=
matrix
(
"A"
)
z
,
_updates
=
scan
(
dot
,
sequences
=
[],
non_sequences
=
[
x
,
A
],
n_steps
=
2
)
z
=
scan
(
dot
,
sequences
=
[],
non_sequences
=
[
x
,
A
],
n_steps
=
2
,
return_updates
=
False
)
f
=
function
([
x
,
A
],
z
,
mode
=
"FAST_RUN"
)
f
=
function
([
x
,
A
],
z
,
mode
=
"FAST_RUN"
)
topo
=
f
.
maker
.
fgraph
.
toposort
()
topo
=
f
.
maker
.
fgraph
.
toposort
()
...
...
tests/tensor/linalg/test_rewriting.py
浏览文件 @
abedb7fb
...
@@ -170,11 +170,12 @@ def test_lu_decomposition_reused_scan(assume_a, counter, transposed):
...
@@ -170,11 +170,12 @@ def test_lu_decomposition_reused_scan(assume_a, counter, transposed):
A
=
tensor
(
"A"
,
shape
=
(
3
,
3
))
A
=
tensor
(
"A"
,
shape
=
(
3
,
3
))
x0
=
tensor
(
"b"
,
shape
=
(
3
,
4
))
x0
=
tensor
(
"b"
,
shape
=
(
3
,
4
))
xs
,
_
=
scan
(
xs
=
scan
(
lambda
xtm1
,
A
:
solve
(
A
,
xtm1
,
assume_a
=
assume_a
,
transposed
=
transposed
),
lambda
xtm1
,
A
:
solve
(
A
,
xtm1
,
assume_a
=
assume_a
,
transposed
=
transposed
),
outputs_info
=
[
x0
],
outputs_info
=
[
x0
],
non_sequences
=
[
A
],
non_sequences
=
[
A
],
n_steps
=
10
,
n_steps
=
10
,
return_updates
=
False
,
)
)
fn_no_opt
=
function
(
fn_no_opt
=
function
(
...
...
tests/tensor/test_blockwise.py
浏览文件 @
abedb7fb
...
@@ -694,10 +694,11 @@ def test_blockwise_grad_core_type():
...
@@ -694,10 +694,11 @@ def test_blockwise_grad_core_type():
def
test_scan_gradient_core_type
():
def
test_scan_gradient_core_type
():
n_steps
=
3
n_steps
=
3
seq
=
tensor
(
"seq"
,
shape
=
(
n_steps
,
1
),
dtype
=
"float64"
)
seq
=
tensor
(
"seq"
,
shape
=
(
n_steps
,
1
),
dtype
=
"float64"
)
out
,
_
=
scan
(
out
=
scan
(
lambda
s
:
s
,
lambda
s
:
s
,
sequences
=
[
seq
],
sequences
=
[
seq
],
n_steps
=
n_steps
,
n_steps
=
n_steps
,
return_updates
=
False
,
)
)
vec_seq
=
tensor
(
"vec_seq"
,
shape
=
(
None
,
n_steps
,
1
),
dtype
=
"float64"
)
vec_seq
=
tensor
(
"vec_seq"
,
shape
=
(
None
,
n_steps
,
1
),
dtype
=
"float64"
)
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论