Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
abedb7fb
提交
abedb7fb
authored
10月 31, 2025
作者:
ricardoV94
提交者:
Ricardo Vieira
11月 08, 2025
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Start using new API in tests that don't involve shared updates
上级
78293400
隐藏空白字符变更
内嵌
并排
正在显示
6 个修改的文件
包含
505 行增加
和
266 行删除
+505
-266
test_scan.py
tests/link/jax/test_scan.py
+40
-20
test_scan.py
tests/link/numba/test_scan.py
+29
-13
test_basic.py
tests/scan/test_basic.py
+270
-142
test_rewriting.py
tests/scan/test_rewriting.py
+162
-89
test_rewriting.py
tests/tensor/linalg/test_rewriting.py
+2
-1
test_blockwise.py
tests/tensor/test_blockwise.py
+2
-1
没有找到文件。
tests/link/jax/test_scan.py
浏览文件 @
abedb7fb
...
...
@@ -23,10 +23,11 @@ jax = pytest.importorskip("jax")
@pytest.mark.parametrize
(
"view"
,
[
None
,
(
-
1
,),
slice
(
-
2
,
None
,
None
)])
def
test_scan_sit_sot
(
view
):
x0
=
pt
.
scalar
(
"x0"
,
dtype
=
"float64"
)
xs
,
_
=
scan
(
xs
=
scan
(
lambda
xtm1
:
xtm1
+
1
,
outputs_info
=
[
x0
],
n_steps
=
10
,
return_updates
=
False
,
)
if
view
:
xs
=
xs
[
view
]
...
...
@@ -37,10 +38,11 @@ def test_scan_sit_sot(view):
@pytest.mark.parametrize
(
"view"
,
[
None
,
(
-
1
,),
slice
(
-
4
,
-
1
,
None
)])
def
test_scan_mit_sot
(
view
):
x0
=
pt
.
vector
(
"x0"
,
dtype
=
"float64"
,
shape
=
(
3
,))
xs
,
_
=
scan
(
xs
=
scan
(
lambda
xtm3
,
xtm1
:
xtm3
+
xtm1
+
1
,
outputs_info
=
[{
"initial"
:
x0
,
"taps"
:
[
-
3
,
-
1
]}],
n_steps
=
10
,
return_updates
=
False
,
)
if
view
:
xs
=
xs
[
view
]
...
...
@@ -57,13 +59,14 @@ def test_scan_multiple_mit_sot(view_x, view_y):
def
step
(
xtm3
,
xtm1
,
ytm4
,
ytm2
):
return
xtm3
+
ytm4
+
1
,
xtm1
+
ytm2
+
2
[
xs
,
ys
]
,
_
=
scan
(
[
xs
,
ys
]
=
scan
(
fn
=
step
,
outputs_info
=
[
{
"initial"
:
x0
,
"taps"
:
[
-
3
,
-
1
]},
{
"initial"
:
y0
,
"taps"
:
[
-
4
,
-
2
]},
],
n_steps
=
10
,
return_updates
=
False
,
)
if
view_x
:
xs
=
xs
[
view_x
]
...
...
@@ -80,10 +83,8 @@ def test_scan_nit_sot(view):
xs
=
pt
.
vector
(
"x0"
,
dtype
=
"float64"
,
shape
=
(
10
,))
ys
,
_
=
scan
(
lambda
x
:
pt
.
exp
(
x
),
outputs_info
=
[
None
],
sequences
=
[
xs
],
ys
=
scan
(
lambda
x
:
pt
.
exp
(
x
),
outputs_info
=
[
None
],
sequences
=
[
xs
],
return_updates
=
False
)
if
view
:
ys
=
ys
[
view
]
...
...
@@ -106,11 +107,12 @@ def test_scan_mit_mot():
rho
=
pt
.
scalar
(
"rho"
,
dtype
=
"float64"
)
x0
=
pt
.
vector
(
"xs"
,
shape
=
(
2
,))
y0
=
pt
.
vector
(
"ys"
,
shape
=
(
3
,))
[
outs
,
_
]
,
_
=
scan
(
[
outs
,
_
]
=
scan
(
step
,
outputs_info
=
[
x0
,
{
"initial"
:
y0
,
"taps"
:
[
-
3
,
-
1
]}],
non_sequences
=
[
rho
],
n_steps
=
10
,
return_updates
=
False
,
)
grads
=
pt
.
grad
(
outs
.
sum
(),
wrt
=
[
x0
,
y0
,
rho
])
compare_jax_and_py
(
...
...
@@ -191,10 +193,11 @@ def test_scan_rng_update():
@pytest.mark.xfail
(
raises
=
NotImplementedError
)
def
test_scan_while
():
xs
,
_
=
scan
(
xs
=
scan
(
lambda
x
:
(
x
+
1
,
until
(
x
<
10
)),
outputs_info
=
[
pt
.
zeros
(())],
n_steps
=
100
,
return_updates
=
False
,
)
compare_jax_and_py
([],
[
xs
],
[])
...
...
@@ -210,7 +213,7 @@ def test_scan_mitsot_with_nonseq():
res
.
name
=
"y_t"
return
res
y_scan_pt
,
_
=
scan
(
y_scan_pt
=
scan
(
fn
=
input_step_fn
,
outputs_info
=
[
{
...
...
@@ -223,6 +226,7 @@ def test_scan_mitsot_with_nonseq():
non_sequences
=
[
a_pt
],
n_steps
=
10
,
name
=
"y_scan"
,
return_updates
=
False
,
)
y_scan_pt
.
name
=
"y"
y_scan_pt
.
owner
.
inputs
[
0
]
.
name
=
"y_all"
...
...
@@ -241,11 +245,12 @@ def test_nd_scan_sit_sot(x0_func, A_func):
k
=
3
# Must specify mode = JAX for the inner func to avoid a GEMM Op in the JAX graph
xs
,
_
=
scan
(
xs
=
scan
(
lambda
X
,
A
:
A
@
X
,
non_sequences
=
[
A
],
outputs_info
=
[
x0
],
n_steps
=
n_steps
,
return_updates
=
False
,
)
x0_val
=
(
...
...
@@ -267,11 +272,12 @@ def test_nd_scan_sit_sot_with_seq():
A
=
pt
.
matrix
(
"A"
,
shape
=
(
k
,
k
))
# Must specify mode = JAX for the inner func to avoid a GEMM Op in the JAX graph
xs
,
_
=
scan
(
xs
=
scan
(
lambda
X
,
A
:
A
@
X
,
non_sequences
=
[
A
],
sequences
=
[
x
],
n_steps
=
n_steps
,
return_updates
=
False
,
)
x_val
=
np
.
arange
(
n_steps
*
k
,
dtype
=
config
.
floatX
)
.
reshape
(
n_steps
,
k
)
...
...
@@ -287,11 +293,12 @@ def test_nd_scan_mit_sot():
B
=
pt
.
matrix
(
"B"
,
shape
=
(
3
,
3
))
# Must specify mode = JAX for the inner func to avoid a GEMM Op in the JAX graph
xs
,
_
=
scan
(
xs
=
scan
(
lambda
xtm3
,
xtm1
,
A
,
B
:
A
@
xtm3
+
B
@
xtm1
,
outputs_info
=
[{
"initial"
:
x0
,
"taps"
:
[
-
3
,
-
1
]}],
non_sequences
=
[
A
,
B
],
n_steps
=
10
,
return_updates
=
False
,
)
x0_val
=
np
.
arange
(
9
,
dtype
=
config
.
floatX
)
.
reshape
(
3
,
3
)
...
...
@@ -310,12 +317,13 @@ def test_nd_scan_sit_sot_with_carry():
return
A
@
x
,
x
.
sum
()
# Must specify mode = JAX for the inner func to avoid a GEMM Op in the JAX graph
xs
,
_
=
scan
(
xs
=
scan
(
step
,
outputs_info
=
[
x0
,
None
],
non_sequences
=
[
A
],
n_steps
=
10
,
mode
=
get_mode
(
"JAX"
),
return_updates
=
False
,
)
x0_val
=
np
.
arange
(
3
,
dtype
=
config
.
floatX
)
...
...
@@ -329,7 +337,13 @@ def test_default_mode_excludes_incompatible_rewrites():
# See issue #426
A
=
matrix
(
"A"
)
B
=
matrix
(
"B"
)
out
,
_
=
scan
(
lambda
a
,
b
:
a
@
b
,
outputs_info
=
[
A
],
non_sequences
=
[
B
],
n_steps
=
2
)
out
=
scan
(
lambda
a
,
b
:
a
@
b
,
outputs_info
=
[
A
],
non_sequences
=
[
B
],
n_steps
=
2
,
return_updates
=
False
,
)
compare_jax_and_py
([
A
,
B
],
[
out
],
[
np
.
eye
(
3
),
np
.
eye
(
3
)],
jax_mode
=
"JAX"
)
...
...
@@ -353,8 +367,11 @@ def test_dynamic_sequence_length():
x
=
pt
.
tensor
(
"x"
,
shape
=
(
None
,
3
))
out
,
_
=
scan
(
lambda
x
:
inc_without_static_shape
(
x
),
outputs_info
=
[
None
],
sequences
=
[
x
]
out
=
scan
(
lambda
x
:
inc_without_static_shape
(
x
),
outputs_info
=
[
None
],
sequences
=
[
x
],
return_updates
=
False
,
)
f
=
function
([
x
],
out
,
mode
=
get_mode
(
"JAX"
)
.
excluding
(
"scan"
))
assert
sum
(
isinstance
(
node
.
op
,
Scan
)
for
node
in
f
.
maker
.
fgraph
.
apply_nodes
)
==
1
...
...
@@ -364,10 +381,11 @@ def test_dynamic_sequence_length():
np
.
testing
.
assert_allclose
(
f
(
np
.
zeros
((
0
,
3
))),
np
.
empty
((
0
,
3
)))
# With known static shape we should always manage, regardless of the internal implementation
out2
,
_
=
scan
(
out2
=
scan
(
lambda
x
:
pt
.
specify_shape
(
inc_without_static_shape
(
x
),
x
.
shape
),
outputs_info
=
[
None
],
sequences
=
[
x
],
return_updates
=
False
,
)
f2
=
function
([
x
],
out2
,
mode
=
get_mode
(
"JAX"
)
.
excluding
(
"scan"
))
np
.
testing
.
assert_allclose
(
f2
([[
1
,
2
,
3
]]),
np
.
array
([[
2
,
3
,
4
]]))
...
...
@@ -418,11 +436,12 @@ def SEIR_model_logp():
it1
=
it0
+
ct0
-
dt0
return
st1
,
et1
,
it1
,
logp_c1
,
logp_d1
(
st
,
et
,
it
,
logp_c_all
,
logp_d_all
)
,
_
=
scan
(
(
st
,
et
,
it
,
logp_c_all
,
logp_d_all
)
=
scan
(
fn
=
seir_one_step
,
sequences
=
[
C_t
,
D_t
],
outputs_info
=
[
st0
,
et0
,
it0
,
None
,
None
],
non_sequences
=
[
beta
,
gamma
,
delta
],
return_updates
=
False
,
)
st
.
name
=
"S_t"
et
.
name
=
"E_t"
...
...
@@ -511,11 +530,12 @@ def cyclical_reduction():
max_iter
=
100
tol
=
1e-7
(
*
_
,
A1_hat
,
norm
,
_n_steps
)
,
_
=
scan
(
(
*
_
,
A1_hat
,
norm
,
_n_steps
)
=
scan
(
step
,
outputs_info
=
[
A
,
B
,
C
,
B
,
norm
,
step_num
],
non_sequences
=
[
tol
],
n_steps
=
max_iter
,
return_updates
=
False
,
)
A1_hat
=
A1_hat
[
-
1
]
...
...
tests/link/numba/test_scan.py
浏览文件 @
abedb7fb
...
...
@@ -206,11 +206,12 @@ def test_scan_multiple_output(benchmark):
it1
=
it0
+
ct0
-
dt0
return
st1
,
et1
,
it1
,
logp_c1
,
logp_d1
(
st
,
et
,
it
,
logp_c_all
,
logp_d_all
)
,
_
=
scan
(
(
st
,
et
,
it
,
logp_c_all
,
logp_d_all
)
=
scan
(
fn
=
seir_one_step
,
sequences
=
[
pt_C
,
pt_D
],
outputs_info
=
[
st0
,
et0
,
it0
,
logp_c
,
logp_d
],
non_sequences
=
[
beta
,
gamma
,
delta
],
return_updates
=
False
,
)
st
.
name
=
"S_t"
et
.
name
=
"E_t"
...
...
@@ -268,7 +269,7 @@ def test_scan_tap_output():
y_t
.
name
=
"y_t"
return
x_t
,
y_t
,
pt
.
fill
((
10
,),
z_t
)
scan_res
,
_
=
scan
(
scan_res
=
scan
(
fn
=
input_step_fn
,
sequences
=
[
{
...
...
@@ -297,6 +298,7 @@ def test_scan_tap_output():
n_steps
=
5
,
name
=
"yz_scan"
,
strict
=
True
,
return_updates
=
False
,
)
test_input_vals
=
[
...
...
@@ -312,11 +314,12 @@ def test_scan_while():
return
previous_power
*
2
,
until
(
previous_power
*
2
>
max_value
)
max_value
=
pt
.
scalar
()
values
,
_
=
scan
(
values
=
scan
(
power_of_2
,
outputs_info
=
pt
.
constant
(
1.0
),
non_sequences
=
max_value
,
n_steps
=
1024
,
return_updates
=
False
,
)
test_input_vals
=
[
...
...
@@ -331,11 +334,12 @@ def test_scan_multiple_none_output():
def
power_step
(
prior_result
,
x
):
return
prior_result
*
x
,
prior_result
*
x
*
x
,
prior_result
*
x
*
x
*
x
result
,
_
=
scan
(
result
=
scan
(
power_step
,
non_sequences
=
[
A
],
outputs_info
=
[
pt
.
ones_like
(
A
),
None
,
None
],
n_steps
=
3
,
return_updates
=
False
,
)
test_input_vals
=
(
np
.
array
([
1.0
,
2.0
]),)
compare_numba_and_py
([
A
],
result
,
test_input_vals
)
...
...
@@ -343,8 +347,12 @@ def test_scan_multiple_none_output():
def
test_grad_sitsot
():
def
get_sum_of_grad
(
inp
):
scan_outputs
,
_updates
=
scan
(
fn
=
lambda
x
:
x
*
2
,
outputs_info
=
[
inp
],
n_steps
=
5
,
mode
=
"NUMBA"
scan_outputs
=
scan
(
fn
=
lambda
x
:
x
*
2
,
outputs_info
=
[
inp
],
n_steps
=
5
,
mode
=
"NUMBA"
,
return_updates
=
False
,
)
return
grad
(
scan_outputs
.
sum
(),
inp
)
.
sum
()
...
...
@@ -362,8 +370,11 @@ def test_mitmots_basic():
def
inner_fct
(
seq
,
state_old
,
state_current
):
return
state_old
*
2
+
state_current
+
seq
out
,
_
=
scan
(
inner_fct
,
sequences
=
seq
,
outputs_info
=
{
"initial"
:
init_x
,
"taps"
:
[
-
2
,
-
1
]}
out
=
scan
(
inner_fct
,
sequences
=
seq
,
outputs_info
=
{
"initial"
:
init_x
,
"taps"
:
[
-
2
,
-
1
]},
return_updates
=
False
,
)
g_outs
=
grad
(
out
.
sum
(),
[
seq
,
init_x
])
...
...
@@ -383,10 +394,11 @@ def test_mitmots_basic():
def
test_inner_graph_optimized
():
"""Test that inner graph of Scan is optimized"""
xs
=
vector
(
"xs"
)
seq
,
_
=
scan
(
seq
=
scan
(
fn
=
lambda
x
:
log
(
1
+
x
),
sequences
=
[
xs
],
mode
=
get_mode
(
"NUMBA"
),
return_updates
=
False
,
)
# Disable scan pushout, in which case the whole scan is replaced by an Elemwise
...
...
@@ -421,13 +433,14 @@ def test_vector_taps_benchmark(benchmark):
sitsot2
=
(
sitsot1
+
mitsot3
)
/
np
.
sqrt
(
2
)
return
mitsot3
,
sitsot2
outs
,
_
=
scan
(
outs
=
scan
(
fn
=
step
,
sequences
=
[
seq1
,
seq2
],
outputs_info
=
[
dict
(
initial
=
mitsot_init
,
taps
=
[
-
2
,
-
1
]),
dict
(
initial
=
sitsot_init
,
taps
=
[
-
1
]),
],
return_updates
=
False
,
)
rng
=
np
.
random
.
default_rng
(
474
)
...
...
@@ -468,7 +481,7 @@ def test_inplace_taps(n_steps_constant):
y
=
ytm1
+
1
+
ytm2
+
a
return
z
,
x
,
z
+
x
+
y
,
y
[
zs
,
xs
,
ws
,
ys
]
,
_
=
scan
(
[
zs
,
xs
,
ws
,
ys
]
=
scan
(
fn
=
step
,
outputs_info
=
[
dict
(
initial
=
z0
,
taps
=
[
-
3
,
-
1
]),
...
...
@@ -478,6 +491,7 @@ def test_inplace_taps(n_steps_constant):
],
non_sequences
=
[
a
],
n_steps
=
n_steps
,
return_updates
=
False
,
)
numba_fn
,
_
=
compare_numba_and_py
(
[
n_steps
]
*
(
not
n_steps_constant
)
+
[
a
,
x0
,
y0
,
z0
],
...
...
@@ -529,10 +543,11 @@ def test_inplace_taps(n_steps_constant):
class
TestScanSITSOTBuffer
:
def
buffer_tester
(
self
,
n_steps
,
op_size
,
buffer_size
,
benchmark
=
None
):
x0
=
pt
.
vector
(
shape
=
(
op_size
,),
dtype
=
"float64"
)
xs
,
_
=
pytensor
.
scan
(
xs
=
pytensor
.
scan
(
fn
=
lambda
xtm1
:
(
xtm1
+
1
),
outputs_info
=
[
x0
],
n_steps
=
n_steps
-
1
,
# 1- makes it easier to align/misalign
return_updates
=
False
,
)
if
buffer_size
==
"unit"
:
xs_kept
=
xs
[
-
1
]
# Only last state is used
...
...
@@ -588,12 +603,13 @@ class TestScanMITSOTBuffer:
init_x
=
pt
.
vector
(
"init_x"
,
shape
=
(
2
,))
n_steps
=
pt
.
iscalar
(
"n_steps"
)
output
,
_
=
scan
(
output
=
scan
(
f_pow2
,
sequences
=
[],
outputs_info
=
[{
"initial"
:
init_x
,
"taps"
:
[
-
2
,
-
1
]}],
non_sequences
=
[],
n_steps
=
n_steps_val
if
constant_n_steps
else
n_steps
,
return_updates
=
False
,
)
init_x_val
=
np
.
array
([
1.0
,
2.0
],
dtype
=
init_x
.
type
.
dtype
)
...
...
tests/scan/test_basic.py
浏览文件 @
abedb7fb
...
...
@@ -294,7 +294,7 @@ class TestScan:
def
test_clone
(
self
):
a
=
vector
()
output
,
_
=
scan
(
fn
=
lambda
x
:
x
**
2
,
sequences
=
[
a
]
)
output
=
scan
(
fn
=
lambda
x
:
x
**
2
,
sequences
=
[
a
],
return_updates
=
False
)
scan_op
=
output
.
owner
.
op
assert
isinstance
(
scan_op
,
Scan
)
...
...
@@ -320,7 +320,7 @@ class TestScan:
state
=
scalar
(
"state"
)
n_steps
=
iscalar
(
"nsteps"
)
output
,
updates
=
scan
(
output
=
scan
(
f_pow2
,
[],
state
,
...
...
@@ -328,10 +328,9 @@ class TestScan:
n_steps
=
n_steps
,
truncate_gradient
=-
1
,
go_backwards
=
False
,
return_updates
=
False
,
)
_my_f
=
function
(
[
state
,
n_steps
],
output
,
updates
=
updates
,
allow_input_downcast
=
True
)
_my_f
=
function
([
state
,
n_steps
],
output
,
allow_input_downcast
=
True
)
origdir
=
Path
.
cwd
()
tmpdir
=
None
...
...
@@ -368,11 +367,9 @@ class TestScan:
state
=
scalar
(
"state"
)
n_steps
=
iscalar
(
"nsteps"
)
output
,
updates
=
scan
(
f_pow2
,
[],
state
,
[],
n_steps
=
n_steps
)
output
=
scan
(
f_pow2
,
[],
state
,
[],
n_steps
=
n_steps
,
return_updates
=
False
)
f
=
function
(
[
state
,
n_steps
],
output
,
updates
=
updates
,
allow_input_downcast
=
True
)
f
=
function
([
state
,
n_steps
],
output
,
allow_input_downcast
=
True
)
scan_node
=
[
node
for
node
in
f
.
maker
.
fgraph
.
toposort
()
if
isinstance
(
node
.
op
,
Scan
)
...
...
@@ -410,7 +407,9 @@ class TestScan:
return
2
*
x_tm1
n_steps
=
iscalar
(
"n_steps"
)
values
,
_
=
scan
(
f_pow
,
outputs_info
=
(
x_init
,),
n_steps
=
n_steps
)
values
=
scan
(
f_pow
,
outputs_info
=
(
x_init
,),
n_steps
=
n_steps
,
return_updates
=
False
)
update_fn
=
function
((
x_init
,
n_steps
),
values
,
mode
=
mode
)
...
...
@@ -443,7 +442,9 @@ class TestScan:
return
2
*
x_i
with
config
.
change_flags
(
mode
=
mode
):
values
,
_
=
scan
(
inner_fn
,
outputs_info
=
(
x_init
,),
sequences
=
x
)
values
=
scan
(
inner_fn
,
outputs_info
=
(
x_init
,),
sequences
=
x
,
return_updates
=
False
)
values_fn
=
function
((
x_init
,
x
),
values
)
assert
isinstance
(
values
.
owner
.
inputs
[
0
]
.
owner
.
op
,
Scan
)
...
...
@@ -474,7 +475,7 @@ class TestScan:
return
2
*
x_i
with
config
.
change_flags
(
mode
=
mode
):
values
,
_
=
scan
(
inner_fn
,
sequences
=
x
)
values
=
scan
(
inner_fn
,
sequences
=
x
,
return_updates
=
False
)
values_fn
=
function
((
x
,),
values
)
assert
isinstance
(
values
.
owner
.
op
,
Scan
)
...
...
@@ -491,7 +492,9 @@ class TestScan:
# Compile the PyTensor function
n_steps
=
2
inp
=
matrix
()
broadcasted_inp
,
_
=
scan
(
lambda
x
:
x
,
non_sequences
=
[
inp
],
n_steps
=
n_steps
)
broadcasted_inp
=
scan
(
lambda
x
:
x
,
non_sequences
=
[
inp
],
n_steps
=
n_steps
,
return_updates
=
False
)
out
=
broadcasted_inp
.
sum
()
gr
=
grad
(
out
,
inp
)
fun
=
function
([
inp
],
[
broadcasted_inp
,
gr
])
...
...
@@ -519,7 +522,7 @@ class TestScan:
W_in
=
scalar
(
"win"
)
W
=
scalar
(
"w"
)
output
,
updates
=
scan
(
output
=
scan
(
f_rnn
,
u
,
x0
,
...
...
@@ -527,11 +530,10 @@ class TestScan:
n_steps
=
None
,
truncate_gradient
=-
1
,
go_backwards
=
False
,
return_updates
=
False
,
)
f2
=
function
(
[
u
,
x0
,
W_in
,
W
],
output
,
updates
=
updates
,
allow_input_downcast
=
True
)
f2
=
function
([
u
,
x0
,
W_in
,
W
],
output
,
allow_input_downcast
=
True
)
# get random initial values
rng
=
np
.
random
.
default_rng
(
utt
.
fetch_seed
())
v_u
=
rng
.
uniform
(
-
5.0
,
5.0
,
size
=
(
4
,))
...
...
@@ -561,7 +563,7 @@ class TestScan:
def
f_rnn_shared
(
u_t
,
x_tm1
,
tmp_W_in
,
tmp_W
):
return
u_t
*
tmp_W_in
+
x_tm1
*
tmp_W
output
,
updates
=
scan
(
output
=
scan
(
f_rnn_shared
,
u
,
x0
,
...
...
@@ -569,8 +571,9 @@ class TestScan:
n_steps
=
None
,
truncate_gradient
=-
1
,
go_backwards
=
False
,
return_updates
=
False
,
)
f3
=
function
([
u
,
x0
],
output
,
updates
=
updates
,
allow_input_downcast
=
True
)
f3
=
function
([
u
,
x0
],
output
,
allow_input_downcast
=
True
)
# get random initial values
v_u
=
rng
.
uniform
(
-
5.0
,
5.0
,
size
=
(
4
,))
...
...
@@ -688,11 +691,14 @@ class TestScan:
# this test refers to a bug reported by Nicolas
# Boulanger-Lewandowski June 6th
x
=
dvector
()
y
,
updates
=
scan
(
lambda
x
:
[
x
],
sequences
=
dict
(
input
=
x
,
taps
=
[
-
1
]),
outputs_info
=
[
None
]
y
=
scan
(
lambda
x
:
[
x
],
sequences
=
dict
(
input
=
x
,
taps
=
[
-
1
]),
outputs_info
=
[
None
],
return_updates
=
False
,
)
inp
=
np
.
arange
(
5
)
.
astype
(
"float64"
)
rval
=
function
([
x
],
y
,
updates
=
updates
)(
inp
)
rval
=
function
([
x
],
y
)(
inp
)
assert
np
.
all
(
rval
==
inp
[:
-
1
])
def
test_output_only
(
self
):
...
...
@@ -701,11 +707,18 @@ class TestScan:
u
=
vector
(
"u"
)
outputs
,
updates
=
scan
(
f_rnn
,
u
,
[],
[],
n_steps
=
None
,
truncate_gradient
=-
1
,
go_backwards
=
False
outputs
=
scan
(
f_rnn
,
u
,
[],
[],
n_steps
=
None
,
truncate_gradient
=-
1
,
go_backwards
=
False
,
return_updates
=
False
,
)
f2
=
function
([
u
],
outputs
,
updates
=
updates
,
allow_input_downcast
=
True
)
f2
=
function
([
u
],
outputs
,
allow_input_downcast
=
True
)
rng
=
np
.
random
.
default_rng
(
utt
.
fetch_seed
())
v_u
=
rng
.
uniform
(
-
5.0
,
5.0
,
size
=
(
5
,))
...
...
@@ -722,7 +735,7 @@ class TestScan:
W_in
=
scalar
(
"win"
)
W
=
scalar
(
"w"
)
output
,
updates
=
scan
(
output
=
scan
(
f_rnn
,
u
,
x0
,
...
...
@@ -730,11 +743,10 @@ class TestScan:
n_steps
=
None
,
truncate_gradient
=-
1
,
go_backwards
=
True
,
return_updates
=
False
,
)
f2
=
function
(
[
u
,
x0
,
W_in
,
W
],
output
,
updates
=
updates
,
allow_input_downcast
=
True
)
f2
=
function
([
u
,
x0
,
W_in
,
W
],
output
,
allow_input_downcast
=
True
)
# get random initial values
rng
=
np
.
random
.
default_rng
(
utt
.
fetch_seed
())
v_u
=
rng
.
uniform
(
-
5.0
,
5.0
,
size
=
(
4
,))
...
...
@@ -797,8 +809,8 @@ class TestScan:
def
test_hash
(
self
):
x
=
vector
()
y
=
vector
()
scan1
,
_updates
=
scan
(
lambda
_x
:
_x
+
1
,
x
)
scan2
,
_updates
=
scan
(
lambda
_x
:
_x
+
1
,
y
)
scan1
=
scan
(
lambda
_x
:
_x
+
1
,
x
,
return_updates
=
False
)
scan2
=
scan
(
lambda
_x
:
_x
+
1
,
y
,
return_updates
=
False
)
assert
scan1
.
owner
.
op
==
scan2
.
owner
.
op
assert
hash
(
scan1
.
owner
.
op
)
==
hash
(
scan2
.
owner
.
op
)
...
...
@@ -809,9 +821,24 @@ class TestScan:
y
=
vector
(
"y"
)
c
=
scalar
(
"c"
)
scan_a
,
_
=
scan
(
lambda
x
,
y
,
c
:
x
+
y
+
c
,
sequences
=
[
x
,
y
],
non_sequences
=
[
c
])
scan_b
,
_
=
scan
(
lambda
x
,
y
,
c
:
x
+
y
+
c
,
sequences
=
[
x
,
y
],
non_sequences
=
[
c
])
scan_c
,
_
=
scan
(
lambda
x
,
y
,
c
:
x
+
y
+
c
,
sequences
=
[
y
,
x
],
non_sequences
=
[
c
])
scan_a
=
scan
(
lambda
x
,
y
,
c
:
x
+
y
+
c
,
sequences
=
[
x
,
y
],
non_sequences
=
[
c
],
return_updates
=
False
,
)
scan_b
=
scan
(
lambda
x
,
y
,
c
:
x
+
y
+
c
,
sequences
=
[
x
,
y
],
non_sequences
=
[
c
],
return_updates
=
False
,
)
scan_c
=
scan
(
lambda
x
,
y
,
c
:
x
+
y
+
c
,
sequences
=
[
y
,
x
],
non_sequences
=
[
c
],
return_updates
=
False
,
)
assert
scan_b
is
not
scan_a
assert
scan_c
is
not
scan_a
...
...
@@ -1006,7 +1033,7 @@ class TestScan:
def
lambda_fn
(
x_t
):
return
x_t
+
1
,
until
(
x_t
>
3
)
o
,
_
=
scan
(
lambda_fn
,
x
)
o
=
scan
(
lambda_fn
,
x
,
return_updates
=
False
)
f
=
function
([
x
],
o
)
vx
=
np
.
zeros
((
50
,),
dtype
=
config
.
floatX
)
vx
[
23
]
=
4
...
...
@@ -1019,7 +1046,7 @@ class TestScan:
def
lambda_fn
(
x_t
):
return
x_t
+
1
,
until
(
x_t
>
3
)
o
,
_
=
scan
(
lambda_fn
,
x
)
o
=
scan
(
lambda_fn
,
x
,
return_updates
=
False
)
f
=
function
([
x
],
o
.
shape
[
0
],
mode
=
mode_with_opt
)
vx
=
np
.
zeros
((
50
,),
dtype
=
config
.
floatX
)
...
...
@@ -1029,11 +1056,12 @@ class TestScan:
def
test_infer_shape_nsteps_smaller_seq_length
(
self
):
x
=
vector
(
"x"
)
[
o1
,
o2
]
,
_
=
scan
(
[
o1
,
o2
]
=
scan
(
lambda
x
,
y
:
(
x
+
1
,
y
+
x
),
sequences
=
x
,
outputs_info
=
[
None
,
x
[
0
]],
n_steps
=
20
,
return_updates
=
False
,
)
f
=
function
([
x
],
[
o1
.
shape
[
0
],
o2
.
shape
[
0
]],
mode
=
mode_with_opt
)
...
...
@@ -1071,17 +1099,18 @@ class TestScan:
mode
=
MonitorMode
(
post_func
=
detect_large_outputs
)
# Symbolic description of the result
result
,
updates
=
scan
(
result
=
scan
(
fn
=
lambda
prior_result
,
A
:
prior_result
*
A
,
outputs_info
=
pt
.
ones_like
(
A
),
non_sequences
=
A
,
n_steps
=
k
,
mode
=
mode
,
return_updates
=
False
,
)
final_result
=
result
[
-
1
]
f
=
function
(
inputs
=
[
A
,
k
],
outputs
=
final_result
,
updates
=
updates
)
f
=
function
(
inputs
=
[
A
,
k
],
outputs
=
final_result
)
f
(
np
.
asarray
([
2
,
3
,
0.1
,
0
,
1
],
dtype
=
config
.
floatX
),
4
)
# There should be 3 outputs greater than 10: prior_result[0] at step 3,
...
...
@@ -1103,10 +1132,11 @@ class TestScan:
y
.
name
=
"y"
gy
=
grad
(
y
,
x
)
gy
.
name
=
"gy"
hy
,
_updates
=
scan
(
hy
=
scan
(
lambda
i
,
gy
,
x
:
grad
(
gy
[
i
]
*
fc2
,
x
),
sequences
=
pt
.
arange
(
gy
.
shape
[
0
]),
non_sequences
=
[
gy
,
x
],
return_updates
=
False
,
)
f
=
function
([
x
,
A
],
hy
,
allow_input_downcast
=
True
)
...
...
@@ -1123,8 +1153,13 @@ class TestScan:
def
test_sequence_is_scan
(
self
,
mode
):
"""Make sure that a `Scan` can be used as a sequence input to another `Scan`."""
x0
=
scalar
(
"x0"
)
scan_1
,
_
=
scan
(
lambda
x
:
x
+
1
,
outputs_info
=
{
"initial"
:
x0
},
n_steps
=
10
)
scan_2
,
_
=
scan
(
lambda
x
:
x
+
1
,
sequences
=
[
scan_1
])
scan_1
=
scan
(
lambda
x
:
x
+
1
,
outputs_info
=
{
"initial"
:
x0
},
n_steps
=
10
,
return_updates
=
False
,
)
scan_2
=
scan
(
lambda
x
:
x
+
1
,
sequences
=
[
scan_1
],
return_updates
=
False
)
with
config
.
change_flags
(
mode
=
mode
):
scan_2_fn
=
function
([
x0
],
scan_2
)
...
...
@@ -1185,7 +1220,7 @@ class TestScan:
def
test_blockwise_scan
(
self
):
x
=
pt
.
tensor
(
"x"
,
shape
=
())
out
,
_
=
scan
(
lambda
x
:
x
+
1
,
outputs_info
=
[
x
],
n_steps
=
10
)
out
=
scan
(
lambda
x
:
x
+
1
,
outputs_info
=
[
x
],
n_steps
=
10
,
return_updates
=
False
)
x_vec
=
pt
.
tensor
(
"x_vec"
,
shape
=
(
None
,))
out_vec
=
vectorize_graph
(
out
,
{
x
:
x_vec
})
...
...
@@ -1203,13 +1238,14 @@ class TestScan:
a0
=
shared
(
np
.
arange
(
2
))
b0
=
shared
(
np
.
arange
(
2
))
(
a
,
_b
)
,
_
=
scan
(
(
a
,
_b
)
=
scan
(
fn
,
outputs_info
=
[
{
"initial"
:
a0
,
"taps"
:
[
-
2
,
-
1
]},
{
"initial"
:
b0
,
"taps"
:
[
-
2
,
-
1
]},
],
n_steps
=
2
,
return_updates
=
False
,
)
grad
(
a
[
-
1
],
a0
)
...
...
@@ -1241,8 +1277,11 @@ class TestScan:
state_next
=
state_old
*
2
+
state_current
+
seq
return
state_next
out
,
_
=
scan
(
inner_fct
,
sequences
=
seq
,
outputs_info
=
{
"initial"
:
x
,
"taps"
:
[
-
2
,
-
1
]}
out
=
scan
(
inner_fct
,
sequences
=
seq
,
outputs_info
=
{
"initial"
:
x
,
"taps"
:
[
-
2
,
-
1
]},
return_updates
=
False
,
)
g_out
=
grad
(
out
.
sum
(),
[
seq
,
x
])
...
...
@@ -1302,12 +1341,13 @@ class TestScan:
new_y
=
pt
.
switch
(
cond
,
y
,
sigmoid
(
x
))
return
new_cond
,
new_x
,
new_y
values
,
_
=
scan
(
values
=
scan
(
inner_fn
,
outputs_info
=
[
c
,
x
,
y
],
n_steps
=
10
,
truncate_gradient
=-
1
,
go_backwards
=
False
,
return_updates
=
False
,
)
gX
,
gY
=
grad
(
values
[
1
]
.
sum
(),
[
x
,
y
])
f
=
function
([
c
,
x
,
y
],
[
gX
,
gY
],
allow_input_downcast
=
True
)
...
...
@@ -1762,11 +1802,12 @@ class TestScan:
outputs_info
=
[
None
,
dict
(
initial
=
out_init
,
taps
=
[
-
3
])]
scan_outputs
,
_
=
scan
(
scan_outputs
=
scan
(
fn
=
inner_fct
,
sequences
=
seq
,
outputs_info
=
outputs_info
,
non_sequences
=
non_seq
,
return_updates
=
False
,
)
# Attempt to take various gradients
...
...
@@ -1834,7 +1875,9 @@ class TestScan:
dict
(
initial
=
out_init
[
3
],
taps
=
[
-
2
,
-
1
]),
]
scan_outputs
,
_
=
scan
(
fn
=
inner_fct
,
outputs_info
=
outputs_info
,
n_steps
=
10
)
scan_outputs
=
scan
(
fn
=
inner_fct
,
outputs_info
=
outputs_info
,
n_steps
=
10
,
return_updates
=
False
)
grad
(
scan_outputs
[
0
]
.
sum
(),
out_init
[
1
])
...
...
@@ -1857,11 +1900,12 @@ class TestScan:
x
=
scalar
(
"x"
)
_max_coefficients_supported
=
1000
full_range
=
pt
.
arange
(
_max_coefficients_supported
)
components
,
_updates
=
scan
(
components
=
scan
(
fn
=
lambda
coeff
,
power
,
free_var
:
coeff
*
(
free_var
**
power
),
outputs_info
=
None
,
sequences
=
[
c
,
full_range
],
non_sequences
=
x
,
return_updates
=
False
,
)
P
=
components
.
sum
()
dP
=
grad
(
P
,
x
)
...
...
@@ -1877,11 +1921,12 @@ class TestScan:
x
=
scalar
(
"x"
)
_max_coefficients_supported
=
1000
full_range
=
pt
.
arange
(
_max_coefficients_supported
)
components
,
_updates
=
scan
(
components
=
scan
(
fn
=
lambda
coeff
,
power
,
free_var
:
coeff
*
(
free_var
**
power
),
outputs_info
=
None
,
sequences
=
[
c
,
full_range
],
non_sequences
=
x
,
return_updates
=
False
,
)
P
=
components
.
sum
()
dP
=
grad
(
P
,
x
)
.
sum
()
...
...
@@ -1968,8 +2013,13 @@ class TestScan:
_W
=
specify_shape
(
W
,
v_W
.
shape
)
_W
.
name
=
"_W"
o
,
_
=
scan
(
rnn_fn
,
sequences
=
_u
,
outputs_info
=
_h0
,
non_sequences
=
_W
,
name
=
"rnn_fn"
o
=
scan
(
rnn_fn
,
sequences
=
_u
,
outputs_info
=
_h0
,
non_sequences
=
_W
,
name
=
"rnn_fn"
,
return_updates
=
False
,
)
o
=
o
[
-
1
]
eu
=
matrix
(
"eu"
)
...
...
@@ -1983,25 +2033,28 @@ class TestScan:
[
u
,
h0
,
W
,
eu
,
eh0
,
eW
],
[
nwo_u
,
nwo_h0
,
nwo_W
],
on_unused_input
=
"ignore"
)
n2o_u
,
_
=
scan
(
n2o_u
=
scan
(
lambda
i
,
o
,
u
,
h0
,
W
,
eu
:
(
grad
(
o
[
i
],
u
)
*
eu
)
.
sum
(),
sequences
=
pt
.
arange
(
o
.
shape
[
0
]),
non_sequences
=
[
o
,
u
,
h0
,
W
,
eu
],
name
=
"jacobU"
,
return_updates
=
False
,
)
n2o_h0
,
_
=
scan
(
n2o_h0
=
scan
(
lambda
i
,
o
,
u
,
h0
,
W
,
eh0
:
(
grad
(
o
[
i
],
h0
)
*
eh0
)
.
sum
(),
sequences
=
pt
.
arange
(
o
.
shape
[
0
]),
non_sequences
=
[
o
,
u
,
h0
,
W
,
eh0
],
name
=
"jacobh"
,
return_updates
=
False
,
)
n2o_W
,
_
=
scan
(
n2o_W
=
scan
(
lambda
i
,
o
,
u
,
h0
,
W
,
eW
:
(
grad
(
o
[
i
],
W
)
*
eW
)
.
sum
(),
sequences
=
pt
.
arange
(
o
.
shape
[
0
]),
non_sequences
=
[
o
,
u
,
h0
,
W
,
eW
],
name
=
"jacobW"
,
return_updates
=
False
,
)
fn_test
=
function
(
...
...
@@ -2132,10 +2185,11 @@ class TestScan:
transfer
=
sigmoid
hidden_rec
,
_
=
scan
(
hidden_rec
=
scan
(
lambda
x
,
h_tm1
:
transfer
(
dot
(
h_tm1
,
W2
)
+
x
),
sequences
=
hidden
,
outputs_info
=
[
pt
.
zeros_like
(
hidden
[
0
])],
return_updates
=
False
,
)
hidden_rec
.
reshape
(
...
...
@@ -2168,12 +2222,13 @@ class TestScan:
def
step
(
s
,
xtm2
,
xtm1
,
z
):
return
s
*
((
xtm2
*
0
+
xtm1
)
**
2
)
*
(
z
/
2
)
xs
,
_
=
scan
(
xs
=
scan
(
step
,
sequences
=
[
seq
],
outputs_info
=
[{
"initial"
:
x0
,
"taps"
:
(
-
2
,
-
1
)}],
non_sequences
=
[
z
],
n_steps
=
2
,
return_updates
=
False
,
)
last_x
=
xs
[
-
1
]
...
...
@@ -2254,11 +2309,12 @@ class TestScan:
raise
ValueError
(
f
"Invalid case: {case}"
)
seq
=
vector
(
"seq"
)
xs
,
_
=
scan
(
xs
=
scan
(
step
,
sequences
=
[
seq
],
non_sequences
=
non_sequences
,
strict
=
strict
,
return_updates
=
False
,
)
x0
=
xs
[
0
]
...
...
@@ -2298,7 +2354,7 @@ def test_cvm_exception_handling(mode):
def
scan_fn
():
return
myop
(
pt
.
as_tensor
(
1
))
res
,
_
=
scan
(
scan_fn
,
n_steps
=
4
,
mode
=
mod
e
)
res
=
scan
(
scan_fn
,
n_steps
=
4
,
mode
=
mode
,
return_updates
=
Fals
e
)
res_fn
=
function
([],
res
,
mode
=
mode
)
...
...
@@ -2328,14 +2384,14 @@ def test_cython_performance(benchmark):
py_res
=
f_py
()
s_r
=
pt
.
as_tensor_variable
(
r
,
dtype
=
config
.
floatX
)
s_y
,
updates
=
scan
(
s_y
=
scan
(
fn
=
lambda
ri
,
rii
,
M
:
ri
+
M
*
rii
,
sequences
=
[
s_r
[
1
:]],
non_sequences
=
[
pt
.
as_tensor_variable
(
M
,
dtype
=
config
.
floatX
)],
outputs_info
=
s_r
[
0
],
mode
=
Mode
(
linker
=
"cvm"
,
optimizer
=
"fast_run"
),
return_updates
=
False
,
)
assert
not
updates
f_cvm
=
function
([],
s_y
,
mode
=
"FAST_RUN"
)
f_cvm
.
trust_input
=
True
...
...
@@ -2357,9 +2413,7 @@ def test_compute_test_values():
y
=
shared
(
np
.
arange
(
3
,
dtype
=
config
.
floatX
),
name
=
"y"
)
z
,
updates
=
scan
(
fn
=
lambda
u
,
v
:
u
+
v
,
sequences
=
[
x
,
y
])
assert
not
updates
z
=
scan
(
fn
=
lambda
u
,
v
:
u
+
v
,
sequences
=
[
x
,
y
],
return_updates
=
False
)
z_grad
=
grad
(
z
.
sum
(),
x
)
...
...
@@ -2368,9 +2422,9 @@ def test_compute_test_values():
# Use `non_sequences` this time
y
=
shared
(
np
.
arange
(
9
,
dtype
=
config
.
floatX
)
.
reshape
(
3
,
3
),
name
=
"y"
)
z
,
updates
=
scan
(
fn
=
lambda
u
,
v
:
u
+
v
,
sequences
=
[
x
],
non_sequences
=
[
y
])
assert
not
updates
z
=
scan
(
fn
=
lambda
u
,
v
:
u
+
v
,
sequences
=
[
x
],
non_sequences
=
[
y
],
return_updates
=
False
)
z_grad
=
grad
(
z
.
sum
(),
x
)
...
...
@@ -2399,20 +2453,22 @@ def test_compute_test_value_grad():
def
loss_ti
(
ti
,
sum_ti
,
mi
,
W
):
return
W
.
sum
()
.
sum
()
.
sum
()
+
sum_ti
result_ti
,
_
=
scan
(
result_ti
=
scan
(
fn
=
loss_ti
,
outputs_info
=
outputs_ti
,
sequences
=
pt
.
arange
(
W
.
shape
[
1
],
dtype
=
"int32"
),
non_sequences
=
[
mi
,
W
],
return_updates
=
False
,
)
lossmi
=
result_ti
[
-
1
]
return
sum_mi
+
lossmi
result_mi
,
_
=
scan
(
result_mi
=
scan
(
fn
=
loss_mi
,
outputs_info
=
outputs_mi
,
sequences
=
pt
.
arange
(
W
.
shape
[
0
],
dtype
=
"int32"
),
non_sequences
=
[
W
],
return_updates
=
False
,
)
loss
=
result_mi
[
-
1
]
...
...
@@ -2436,11 +2492,12 @@ def test_compute_test_value_grad_cast():
name
=
"w"
,
)
outputs
,
_
=
scan
(
outputs
=
scan
(
lambda
i
,
h
,
w
:
(
dot
(
h
[
i
],
w
),
i
),
outputs_info
=
[
None
,
0
],
non_sequences
=
[
h
,
w
],
n_steps
=
3
,
return_updates
=
False
,
)
grad
(
outputs
[
0
]
.
sum
(),
w
)
...
...
@@ -2449,11 +2506,12 @@ def test_compute_test_value_grad_cast():
def
test_constant_folding_n_steps
():
# The following code used to crash at revision 2060b8f, in the constant
# folding optimization step.
res
,
_
=
scan
(
res
=
scan
(
lambda
x
:
x
*
2
,
outputs_info
=
pt
.
ones
(()),
# The constant `n_steps` was causing the crash.
n_steps
=
10
,
return_updates
=
False
,
)
with
config
.
change_flags
(
on_opt_error
=
"raise"
):
function
([],
res
)()
...
...
@@ -2478,10 +2536,11 @@ def test_outputs_taps_check():
def
test_inconsistent_broadcast_error
():
x
=
tensor3
()
initial_x
=
pt
.
constant
(
np
.
zeros
((
1
,
10
)))
y
,
_updates
=
scan
(
y
=
scan
(
fn
=
lambda
x
,
prev_x
:
x
+
prev_x
,
sequences
=
x
,
outputs_info
=
[
dict
(
initial
=
initial_x
)],
return_updates
=
False
,
)
# Error, because the broadcast patterns are inconsistent.
with
pytest
.
raises
(
TypeError
):
...
...
@@ -2509,10 +2568,11 @@ class TestGradUntil:
self
.
numpy_gradient
=
2
*
np
.
concatenate
([
self
.
seq
[:
7
],
z
],
axis
=
0
)
def
test_grad_until
(
self
):
r
,
_
=
scan
(
r
=
scan
(
lambda
x
,
u
:
(
x
*
x
,
until
(
x
>
u
)),
sequences
=
self
.
x
,
non_sequences
=
[
self
.
threshold
],
return_updates
=
False
,
)
g
=
grad
(
r
.
sum
(),
self
.
x
)
f
=
function
([
self
.
x
,
self
.
threshold
],
[
r
,
g
])
...
...
@@ -2528,10 +2588,11 @@ class TestGradUntil:
X
=
matrix
(
name
=
"x"
)
arr
=
tile_array
(
self
.
seq
)
r
,
_
=
scan
(
r
=
scan
(
lambda
x
,
u
:
(
x
*
x
,
until
(
pt_all
(
x
>
u
))),
sequences
=
X
,
non_sequences
=
[
self
.
threshold
],
return_updates
=
False
,
)
g
=
grad
(
r
.
sum
(),
X
)
f
=
function
([
X
,
self
.
threshold
],
[
r
,
g
])
...
...
@@ -2542,11 +2603,12 @@ class TestGradUntil:
def
test_grad_until_and_truncate
(
self
):
n
=
3
r
,
_
=
scan
(
r
=
scan
(
lambda
x
,
u
:
(
x
*
x
,
until
(
x
>
u
)),
sequences
=
self
.
x
,
non_sequences
=
[
self
.
threshold
],
truncate_gradient
=
n
,
return_updates
=
False
,
)
g
=
grad
(
r
.
sum
(),
self
.
x
)
f
=
function
([
self
.
x
,
self
.
threshold
],
[
r
,
g
])
...
...
@@ -2558,11 +2620,12 @@ class TestGradUntil:
def
test_grad_until_and_truncate_sequence_taps
(
self
):
n
=
3
r
,
_
=
scan
(
r
=
scan
(
lambda
x
,
y
,
u
:
(
x
*
y
,
until
(
y
>
u
)),
sequences
=
dict
(
input
=
self
.
x
,
taps
=
[
-
2
,
0
]),
non_sequences
=
[
self
.
threshold
],
truncate_gradient
=
n
,
return_updates
=
False
,
)
g
=
grad
(
r
.
sum
(),
self
.
x
)
f
=
function
([
self
.
x
,
self
.
threshold
],
[
r
,
g
])
...
...
@@ -2581,8 +2644,12 @@ def test_mintap_onestep():
new_sum
=
prev_sum
+
seq_t
return
new_sum
rs
,
_updates
=
scan
(
fn
=
accum
,
sequences
=
{
"input"
:
seq
,
"taps"
:
[
2
]},
outputs_info
=
0
,
n_steps
=
1
rs
=
scan
(
fn
=
accum
,
sequences
=
{
"input"
:
seq
,
"taps"
:
[
2
]},
outputs_info
=
0
,
n_steps
=
1
,
return_updates
=
False
,
)
f
=
function
(
inputs
=
[
seq
],
outputs
=
rs
)
...
...
@@ -2667,7 +2734,12 @@ def test_inner_get_vector_length():
def
test_profile_info
():
from
pytensor.scan.utils
import
ScanProfileStats
z
,
_updates
=
scan
(
fn
=
lambda
u
:
u
+
1
,
sequences
=
[
pt
.
arange
(
10
)],
profile
=
True
)
z
=
scan
(
fn
=
lambda
u
:
u
+
1
,
sequences
=
[
pt
.
arange
(
10
)],
profile
=
True
,
return_updates
=
False
,
)
assert
isinstance
(
z
.
owner
.
op
,
Scan
)
fn
=
z
.
owner
.
op
.
fn
...
...
@@ -2676,8 +2748,11 @@ def test_profile_info():
assert
fn
.
profile
.
name
==
"scan_fn"
# Set the `ScanProfileStats` name
z
,
_updates
=
scan
(
fn
=
lambda
u
:
u
+
1
,
sequences
=
[
pt
.
arange
(
10
)],
profile
=
"profile_name"
z
=
scan
(
fn
=
lambda
u
:
u
+
1
,
sequences
=
[
pt
.
arange
(
10
)],
profile
=
"profile_name"
,
return_updates
=
False
,
)
assert
isinstance
(
z
.
owner
.
op
,
Scan
)
...
...
@@ -2688,7 +2763,12 @@ def test_profile_info():
# Use an existing profile object
profile
=
fn
.
profile
z
,
_updates
=
scan
(
fn
=
lambda
u
:
u
+
1
,
sequences
=
[
pt
.
arange
(
10
)],
profile
=
profile
)
z
=
scan
(
fn
=
lambda
u
:
u
+
1
,
sequences
=
[
pt
.
arange
(
10
)],
profile
=
profile
,
return_updates
=
False
,
)
assert
isinstance
(
z
.
owner
.
op
,
Scan
)
fn
=
z
.
owner
.
op
.
fn
...
...
@@ -2819,7 +2899,7 @@ class TestExamples:
y_tm1
+
dot
(
x_tm1
,
W_out
),
]
outputs
,
updates
=
scan
(
outputs
=
scan
(
f_rnn_cmpl
,
[
u1
,
u2
],
[
None
,
None
,
x0
,
dict
(
initial
=
y0
,
taps
=
[
-
1
,
-
3
])],
...
...
@@ -2827,11 +2907,10 @@ class TestExamples:
n_steps
=
None
,
truncate_gradient
=-
1
,
go_backwards
=
False
,
return_updates
=
False
,
)
f4
=
function
(
[
u1
,
u2
,
x0
,
y0
,
W_in1
],
outputs
,
updates
=
updates
,
allow_input_downcast
=
True
)
f4
=
function
([
u1
,
u2
,
x0
,
y0
,
W_in1
],
outputs
,
allow_input_downcast
=
True
)
# compute the values in numpy
v_x
=
np
.
zeros
((
3
,
2
),
dtype
=
config
.
floatX
)
...
...
@@ -2857,8 +2936,12 @@ class TestExamples:
def
scanStep
(
prev
,
seq
,
f1
):
return
prev
+
f1
*
seq
scanned
,
_
=
scan
(
fn
=
scanStep
,
sequences
=
[
seq
],
outputs_info
=
[
to_scan
],
non_sequences
=
[
f1
]
scanned
=
scan
(
fn
=
scanStep
,
sequences
=
[
seq
],
outputs_info
=
[
to_scan
],
non_sequences
=
[
f1
],
return_updates
=
False
,
)
function
(
inputs
=
[
to_scan
,
seq
,
f1
],
outputs
=
scanned
,
allow_input_downcast
=
True
)
...
...
@@ -2879,8 +2962,12 @@ class TestExamples:
expr
=
dot
(
h_tm1
,
W
)
+
x_t
return
expr
expr
,
_
=
scan
(
fn
=
one_step
,
sequences
=
[
inpt
],
outputs_info
=
[
initial
],
non_sequences
=
[
W
]
expr
=
scan
(
fn
=
one_step
,
sequences
=
[
inpt
],
outputs_info
=
[
initial
],
non_sequences
=
[
W
],
return_updates
=
False
,
)
v1
=
shared
(
np
.
ones
(
5
,
dtype
=
config
.
floatX
))
...
...
@@ -2917,11 +3004,12 @@ class TestExamples:
x
=
scalar
()
seq
=
vector
()
outputs_info
=
[
x
,
pt
.
zeros_like
(
x
)]
(
out1
,
out2
)
,
_updates
=
scan
(
(
out1
,
out2
)
=
scan
(
lambda
a
,
b
,
c
:
(
a
+
b
,
b
+
c
),
sequences
=
seq
,
outputs_info
=
outputs_info
,
mode
=
mode
,
return_updates
=
False
,
)
# Obtain a reference to the scan outputs before the subtensor and
...
...
@@ -2956,8 +3044,11 @@ class TestExamples:
x
=
dcol
()
seq
=
dcol
()
outputs_info
=
[
x
,
pt
.
zeros_like
(
x
)]
(
out1
,
out2
),
_updates
=
scan
(
lambda
a
,
b
,
c
:
(
a
+
b
,
a
+
c
),
sequences
=
seq
,
outputs_info
=
outputs_info
(
out1
,
out2
)
=
scan
(
lambda
a
,
b
,
c
:
(
a
+
b
,
a
+
c
),
sequences
=
seq
,
outputs_info
=
outputs_info
,
return_updates
=
False
,
)
# Obtain a reference to the scan outputs before the subtensor and
...
...
@@ -3096,7 +3187,9 @@ class TestExamples:
seq
=
matrix
()
initial_value
=
shared
(
np
.
zeros
((
4
,
1
),
dtype
=
config
.
floatX
))
outputs_info
=
[{
"initial"
:
initial_value
,
"taps"
:
[
-
4
]},
None
]
results
,
_updates
=
scan
(
fn
=
onestep
,
sequences
=
seq
,
outputs_info
=
outputs_info
)
results
=
scan
(
fn
=
onestep
,
sequences
=
seq
,
outputs_info
=
outputs_info
,
return_updates
=
False
)
f
=
function
([
seq
],
results
[
1
])
assert
np
.
all
(
exp_out
==
f
(
inp
))
...
...
@@ -3119,7 +3212,9 @@ class TestExamples:
seq
=
matrix
()
initial_value
=
shared
(
np
.
zeros
((
4
,
1
),
dtype
=
config
.
floatX
))
outputs_info
=
[{
"initial"
:
initial_value
,
"taps"
:
[
-
4
]},
None
]
results
,
_
=
scan
(
fn
=
onestep
,
sequences
=
seq
,
outputs_info
=
outputs_info
)
results
=
scan
(
fn
=
onestep
,
sequences
=
seq
,
outputs_info
=
outputs_info
,
return_updates
=
False
)
sharedvar
=
shared
(
np
.
zeros
((
1
,
1
),
dtype
=
config
.
floatX
))
updates
=
{
sharedvar
:
results
[
0
][
-
1
:]}
...
...
@@ -3164,7 +3259,7 @@ class TestExamples:
init
=
matrix
()
outputs_info
=
[
None
,
None
,
None
,
None
,
dict
(
initial
=
init
,
taps
=
[
-
3
,
-
2
,
-
1
])]
out
,
_
=
scan
(
inner_fn
,
outputs_info
=
outputs_info
,
n_steps
=
3
)
out
=
scan
(
inner_fn
,
outputs_info
=
outputs_info
,
n_steps
=
3
,
return_updates
=
False
)
fct
=
function
([
init
],
out
)
# Compare obtained outputs with expected outputs
...
...
@@ -3197,21 +3292,23 @@ class TestExamples:
def
loss_inner
(
sum_inner
,
W
):
return
sum_inner
+
(
W
**
2
)
.
sum
()
result_inner
,
_
=
scan
(
result_inner
=
scan
(
fn
=
loss_inner
,
outputs_info
=
pt
.
as_tensor_variable
(
np
.
asarray
(
0
,
dtype
=
np
.
float32
)),
non_sequences
=
[
W
],
n_steps
=
1
,
return_updates
=
False
,
)
return
sum_outer
+
result_inner
[
-
1
]
# Also test return_list for that case.
result_outer
,
_
=
scan
(
result_outer
=
scan
(
fn
=
loss_outer
,
outputs_info
=
pt
.
as_tensor_variable
(
np
.
asarray
(
0
,
dtype
=
np
.
float32
)),
non_sequences
=
[
W
],
n_steps
=
n_steps
,
return_list
=
True
,
return_updates
=
False
,
)
cost
=
result_outer
[
0
][
-
1
]
...
...
@@ -3230,7 +3327,9 @@ class TestExamples:
x0
=
vector
(
"X"
)
y0
=
vector
(
"y0"
)
z0
=
vector
(
"Z"
)
[
x
,
y
,
z
],
_
=
scan
(
inner_fn
,
outputs_info
=
[
x0
,
y0
,
z0
],
n_steps
=
10
)
[
x
,
y
,
z
]
=
scan
(
inner_fn
,
outputs_info
=
[
x0
,
y0
,
z0
],
n_steps
=
10
,
return_updates
=
False
)
cost
=
(
x
+
y
+
z
)
.
sum
()
grad
(
cost
,
x0
)
# defined
...
...
@@ -3247,7 +3346,12 @@ class TestExamples:
m
=
matrix
(
"m"
)
u0
=
pt
.
zeros
((
7
,))
[
_u
,
m2
],
_
=
scan
(
lambda
_
,
u
:
[
u
,
v
],
sequences
=
m
,
outputs_info
=
[
u0
,
None
])
[
_u
,
m2
]
=
scan
(
lambda
_
,
u
:
[
u
,
v
],
sequences
=
m
,
outputs_info
=
[
u0
,
None
],
return_updates
=
False
,
)
# This used to raise an exception with older versions because for a
# disconnected gradient a non disconnected type was returned
grad
((
m
*
m2
)
.
sum
(),
v
)
...
...
@@ -3257,8 +3361,11 @@ class TestExamples:
m
=
matrix
(
"m"
)
u0
=
pt
.
zeros
((
7
,))
[
_u
,
m2
],
_
=
scan
(
lambda
x
,
u
:
[
x
+
u
,
u
+
v
],
sequences
=
m
,
outputs_info
=
[
u0
,
None
]
[
_u
,
m2
]
=
scan
(
lambda
x
,
u
:
[
x
+
u
,
u
+
v
],
sequences
=
m
,
outputs_info
=
[
u0
,
None
],
return_updates
=
False
,
)
# This used to raise an exception with older versions because
# scan could not detect the connection between `m2` and `x`
...
...
@@ -3278,7 +3385,7 @@ class TestExamples:
out2
=
out1
+
1
return
out1
,
out2
[
_out1
,
out2
]
,
_
=
scan
(
step
,
sequences
=
v
)
[
_out1
,
out2
]
=
scan
(
step
,
sequences
=
v
,
return_updates
=
False
)
gv
=
grad
(
out2
.
sum
(),
[
v
])
f
=
function
([
v
],
gv
)
...
...
@@ -3289,7 +3396,13 @@ class TestExamples:
def
test_grad_bug_disconnected_input
(
self
):
W
=
shared
(
np
.
zeros
((
3
,
3
)),
name
=
"W"
)
v
=
ivector
(
name
=
"v"
)
y
,
_
=
scan
(
lambda
i
,
W
:
W
[
i
],
sequences
=
v
,
outputs_info
=
None
,
non_sequences
=
W
)
y
=
scan
(
lambda
i
,
W
:
W
[
i
],
sequences
=
v
,
outputs_info
=
None
,
non_sequences
=
W
,
return_updates
=
False
,
)
# This used to raise an exception
f
=
function
([
v
],
grad
(
y
.
sum
(),
W
))
...
...
@@ -3299,10 +3412,8 @@ class TestExamples:
w
=
shared
(
np
.
array
(
0
,
dtype
=
"float32"
),
name
=
"w"
)
init
=
fscalar
(
"init"
)
out
,
_
=
scan
(
fn
=
lambda
prev
:
w
,
outputs_info
=
init
,
n_steps
=
2
,
out
=
scan
(
fn
=
lambda
prev
:
w
,
outputs_info
=
init
,
n_steps
=
2
,
return_updates
=
False
)
grad
(
out
[
-
1
],
w
)
...
...
@@ -3326,7 +3437,7 @@ class TestExamples:
def
f_rnn_shared
(
u_tm2
,
x_tm1
,
x_tm2
):
return
u_tm2
*
W_in
+
x_tm1
*
W
+
x_tm2
outputs
,
updates
=
scan
(
outputs
=
scan
(
f_rnn_shared
,
dict
(
input
=
u
,
taps
=-
2
),
dict
(
initial
=
x0
,
taps
=
[
-
1
,
-
2
]),
...
...
@@ -3334,9 +3445,10 @@ class TestExamples:
n_steps
=
None
,
truncate_gradient
=-
1
,
go_backwards
=
False
,
return_updates
=
False
,
)
f7
=
function
([
u
,
x0
],
outputs
,
updates
=
updates
,
allow_input_downcast
=
True
)
f7
=
function
([
u
,
x0
],
outputs
,
allow_input_downcast
=
True
)
pytensor_out
=
f7
(
vu
,
vx0
)
# compute output in numpy
...
...
@@ -3372,7 +3484,7 @@ class TestExamples:
def
f_rnn_shared
(
u_tm2
,
u_tp2
,
x_tm1
,
x_tm2
):
return
(
u_tm2
+
u_tp2
)
*
W_in
+
x_tm1
*
W
+
x_tm2
output
,
updates
=
scan
(
output
=
scan
(
f_rnn_shared
,
dict
(
input
=
u
,
taps
=
[
-
2
,
2
]),
dict
(
initial
=
x0
,
taps
=
[
-
1
,
-
2
]),
...
...
@@ -3380,9 +3492,10 @@ class TestExamples:
n_steps
=
None
,
truncate_gradient
=-
1
,
go_backwards
=
False
,
return_updates
=
False
,
)
f8
=
function
([
u
,
x0
],
output
,
updates
=
updates
,
allow_input_downcast
=
True
)
f8
=
function
([
u
,
x0
],
output
,
allow_input_downcast
=
True
)
pytensor_out
=
f8
(
vu
,
vx0
)
# compute output in numpy
numpy_out
=
np
.
zeros
(
2
)
...
...
@@ -3404,7 +3517,7 @@ class TestExamples:
state
=
scalar
(
"state"
)
n_steps
=
iscalar
(
"nsteps"
)
# Test return_list at the same time.
output
,
updates
=
scan
(
output
=
scan
(
f_pow2
,
[],
state
,
...
...
@@ -3413,10 +3526,9 @@ class TestExamples:
truncate_gradient
=-
1
,
return_list
=
True
,
go_backwards
=
False
,
return_updates
=
False
,
)
my_f
=
function
(
[
state
,
n_steps
],
output
,
updates
=
updates
,
allow_input_downcast
=
True
)
my_f
=
function
([
state
,
n_steps
],
output
,
allow_input_downcast
=
True
)
rng
=
np
.
random
.
default_rng
(
utt
.
fetch_seed
())
state
=
rng
.
uniform
()
...
...
@@ -3446,10 +3558,11 @@ class TestExamples:
pre_h
=
dot
(
x
,
W_x
)
return
pre_h
value
,
_scan_updates
=
scan
(
value
=
scan
(
_active
,
sequences
=
X
,
outputs_info
=
[
pt
.
alloc
(
floatx
(
0.0
),
1
,
out_size
)],
return_updates
=
False
,
)
cost
=
mean
(
value
)
gW_x
=
grad
(
cost
,
W_x
)
...
...
@@ -3467,7 +3580,7 @@ class TestExamples:
condition
=
until
(
new_value
>
max_value
)
return
[
new_value
,
new_step
],
condition
rs
,
_updates
=
scan
(
fn
=
accum
,
outputs_info
=
[
0
,
0
],
n_steps
=
n_steps
)
rs
=
scan
(
fn
=
accum
,
outputs_info
=
[
0
,
0
],
n_steps
=
n_steps
,
return_updates
=
False
)
f
=
function
(
inputs
=
[
max_value
,
n_steps
],
outputs
=
rs
)
...
...
@@ -3487,33 +3600,37 @@ class TestExamples:
# Generate the components of the polynomial
full_range
=
pt
.
arange
(
max_coefficients_supported
)
components
,
_updates
=
scan
(
components
=
scan
(
fn
=
lambda
coeff
,
power
,
free_var
:
coeff
*
(
free_var
**
power
),
sequences
=
[
coefficients
,
full_range
],
non_sequences
=
x
,
return_updates
=
False
,
)
polynomial1
=
components
.
sum
()
polynomial2
,
_updates
=
scan
(
polynomial2
=
scan
(
fn
=
lambda
coeff
,
power
,
prev
,
free_var
:
prev
+
coeff
*
(
free_var
**
power
),
outputs_info
=
pt
.
constant
(
0
,
dtype
=
"floatX"
),
sequences
=
[
coefficients
,
full_range
],
non_sequences
=
x
,
return_updates
=
False
,
)
# python int
polynomial3
,
_updates
=
scan
(
polynomial3
=
scan
(
fn
=
lambda
coeff
,
power
,
prev
,
free_var
:
prev
+
coeff
*
(
free_var
**
power
),
outputs_info
=
0
,
sequences
=
[
coefficients
,
full_range
],
non_sequences
=
x
,
return_updates
=
False
,
)
# python float
polynomial4
,
_updates
=
scan
(
polynomial4
=
scan
(
fn
=
lambda
coeff
,
power
,
prev
,
free_var
:
prev
+
coeff
*
(
free_var
**
power
),
outputs_info
=
0.0
,
sequences
=
[
coefficients
,
full_range
],
non_sequences
=
x
,
return_updates
=
False
,
)
calculate_polynomial
=
function
(
...
...
@@ -3576,8 +3693,12 @@ class TestExamples:
# o = v + 1 # <-- this line works
return
o
OS
,
_updates
=
scan
(
fn
=
one_step
,
sequences
=
V
,
outputs_info
=
[
None
],
non_sequences
=
[
W
]
OS
=
scan
(
fn
=
one_step
,
sequences
=
V
,
outputs_info
=
[
None
],
non_sequences
=
[
W
],
return_updates
=
False
,
)
O
=
OS
.
sum
()
+
W
.
sum
()
...
...
@@ -3591,11 +3712,12 @@ class TestExamples:
)
def
test_infershape_seq_shorter_nsteps
(
self
):
x
=
vector
(
"x"
)
[
o1
,
o2
]
,
_
=
scan
(
[
o1
,
o2
]
=
scan
(
lambda
x
,
y
:
(
x
+
1
,
y
+
x
),
sequences
=
x
,
outputs_info
=
[
None
,
x
[
0
]],
n_steps
=
20
,
return_updates
=
False
,
)
f
=
function
([
x
],
[
o1
,
o2
],
mode
=
mode_with_opt
)
...
...
@@ -3667,10 +3789,14 @@ class TestExamples:
condition
=
until
(
previous_val
>
5
)
return
new_val
,
condition
out
,
_
updates
=
scan
(
inner_fct
,
outputs_info
=
x
,
n_steps
=
10
)
out
,
updates
=
scan
(
inner_fct
,
outputs_info
=
x
,
n_steps
=
10
)
g_out
=
grad
(
out
.
sum
(),
x
)
fct
=
function
([
x
],
[
out
,
g_out
])
fct
=
function
(
[
x
],
[
out
,
g_out
],
updates
=
updates
,
)
for
i
in
range
(
-
5
,
5
):
output
,
g_output
=
fct
(
i
)
...
...
@@ -3702,7 +3828,7 @@ class TestExamples:
)
return
next_sitsot_val
,
next_mitsot_val
,
nitsot_out
out
,
_updates
=
scan
(
out
=
scan
(
fn
=
step
,
sequences
=
seq
,
outputs_info
=
[
...
...
@@ -3711,6 +3837,7 @@ class TestExamples:
None
,
],
n_steps
=
5
,
return_updates
=
False
,
)
f
=
function
([
seq
,
sitsot_init
,
mitsot_init
],
out
[
2
]
.
shape
)
...
...
@@ -3746,7 +3873,7 @@ class TestExamples:
dot
(
x_tm1
,
W_out
),
]
outputs
,
updates
=
scan
(
outputs
=
scan
(
f_rnn_cmpl
,
[
u1
,
u2
],
[
x0
,
y0
],
...
...
@@ -3754,11 +3881,10 @@ class TestExamples:
n_steps
=
None
,
truncate_gradient
=-
1
,
go_backwards
=
False
,
return_updates
=
False
,
)
f4
=
function
(
[
u1
,
u2
,
x0
,
y0
,
W_in1
],
outputs
,
updates
=
updates
,
allow_input_downcast
=
True
)
f4
=
function
([
u1
,
u2
,
x0
,
y0
,
W_in1
],
outputs
,
allow_input_downcast
=
True
)
# compute the values in numpy
v_x
=
np
.
zeros
((
3
,
2
),
dtype
=
config
.
floatX
)
...
...
@@ -3802,7 +3928,7 @@ class TestExamples:
dot
(
u1_t
,
W_in1
),
]
outputs
,
updates
=
scan
(
outputs
=
scan
(
f_rnn_cmpl
,
[
u1
,
dict
(
input
=
u2
,
taps
=
[
-
1
,
0
,
1
])],
[
x0
,
dict
(
initial
=
y0
,
taps
=
[
-
1
,
-
3
]),
None
],
...
...
@@ -3810,11 +3936,10 @@ class TestExamples:
n_steps
=
None
,
truncate_gradient
=-
1
,
go_backwards
=
False
,
return_updates
=
False
,
)
f
=
function
(
[
u1
,
u2
,
x0
,
y0
,
W_in1
],
outputs
,
updates
=
updates
,
allow_input_downcast
=
True
)
f
=
function
([
u1
,
u2
,
x0
,
y0
,
W_in1
],
outputs
,
allow_input_downcast
=
True
)
ny0
=
np
.
zeros
((
5
,
2
))
ny1
=
np
.
zeros
((
5
,))
...
...
@@ -3904,13 +4029,14 @@ class TestExamples:
return
[
h_t
,
y_t
]
# hidden and outputs of the entire sequence
[
_h
,
y
]
,
_
=
scan
(
[
_h
,
y
]
=
scan
(
fn
=
one_step
,
sequences
=
dict
(
input
=
x
),
# corresponds to the return type of one_step
outputs_info
=
[
dict
(
initial
=
h0
,
taps
=
[
-
2
,
-
1
]),
None
],
non_sequences
=
[
W_ih
,
W_hh
,
b_h
,
W_ho
,
b_o
],
mode
=
mode
,
return_updates
=
False
,
)
# target values
...
...
@@ -4084,7 +4210,7 @@ def test_output_storage_reuse(linker_mode):
outer-output arrays are initialized using the outer-input arrays, the
shape difference needs to be handled correctly.
"""
s_in_y
,
_
=
scan
(
s_in_y
=
scan
(
fn
=
lambda
z
:
(
z
+
1
,
until
(
z
>
2
)),
outputs_info
=
[
{
"taps"
:
[
-
1
],
"initial"
:
pt
.
as_tensor
(
0.0
,
dtype
=
np
.
float64
)}
...
...
@@ -4092,16 +4218,18 @@ def test_output_storage_reuse(linker_mode):
mode
=
mode
,
n_steps
=
n
-
1
,
allow_gc
=
False
,
return_updates
=
False
,
)
return
s_in_y
.
sum
()
s_y
,
_updates
=
scan
(
s_y
=
scan
(
fn
=
fn
,
outputs_info
=
[
None
],
sequences
=
[
pt
.
as_tensor
([
3
,
2
,
1
],
dtype
=
np
.
int64
)],
mode
=
mode
,
allow_gc
=
False
,
return_updates
=
False
,
)
f_cvm
=
function
([],
s_y
,
mode
=
mode
)
...
...
@@ -4121,14 +4249,14 @@ def test_rng_outputs_info():
)
.
owner
.
outputs
return
next_x
,
next_rng
[
xs
,
rng_final
]
,
updates
=
scan
(
[
xs
,
rng_final
]
=
scan
(
fn
=
step
,
outputs_info
=
[
x0
,
rng_x0
],
n_steps
=
10
,
return_updates
=
False
,
)
assert
isinstance
(
xs
.
type
,
TensorType
)
assert
isinstance
(
rng_final
.
type
,
RandomGeneratorType
)
assert
not
updates
fn
=
function
([
rng_init
],
[
xs
,
rng_final
])
xs_eval
,
rng_final_eval
=
fn
(
np
.
random
.
default_rng
(
0
))
...
...
tests/scan/test_rewriting.py
浏览文件 @
abedb7fb
...
...
@@ -47,38 +47,47 @@ class TestRemoveConstantsAndUnusedInputsScan:
"""Test the rewrite `remove_constants_and_unused_inputs_scan` for non-sequences."""
W
=
matrix
(
name
=
"W"
)
v
=
ivector
(
name
=
"v"
)
y1
,
_
=
scan
(
lambda
i
,
W
:
W
[
i
],
sequences
=
v
,
outputs_info
=
None
,
non_sequences
=
[
W
]
y1
=
scan
(
lambda
i
,
W
:
W
[
i
],
sequences
=
v
,
outputs_info
=
None
,
non_sequences
=
[
W
],
return_updates
=
False
,
)
y2
,
_
=
scan
(
y2
=
scan
(
lambda
i
,
_
,
W
:
W
[
i
],
sequences
=
v
,
outputs_info
=
None
,
non_sequences
=
[
W
[
0
],
W
],
return_updates
=
False
,
)
y3
,
_
=
scan
(
y3
=
scan
(
lambda
i
,
W
,
_
:
W
[
i
],
sequences
=
v
,
outputs_info
=
None
,
non_sequences
=
[
W
,
W
[
0
]],
return_updates
=
False
,
)
y4
,
_
=
scan
(
y4
=
scan
(
lambda
i
,
_
,
_2
,
W
:
W
[
i
],
sequences
=
v
,
outputs_info
=
None
,
non_sequences
=
[
W
[
0
],
W
[
0
],
W
],
return_updates
=
False
,
)
y5
,
_
=
scan
(
y5
=
scan
(
lambda
i
,
_
,
W
,
_2
:
W
[
i
],
sequences
=
v
,
outputs_info
=
None
,
non_sequences
=
[
W
[
0
],
W
,
W
[
0
]],
return_updates
=
False
,
)
y6
,
_
=
scan
(
y6
=
scan
(
lambda
i
,
W
,
_
,
_2
:
W
[
i
],
sequences
=
v
,
outputs_info
=
None
,
non_sequences
=
[
W
,
W
[
0
],
W
[
0
]],
return_updates
=
False
,
)
# TODO: y7 have problem during run time. I think it should
# raise an error during the scan construction.
...
...
@@ -112,47 +121,61 @@ class TestRemoveConstantsAndUnusedInputsScan:
W
=
matrix
(
name
=
"W"
)
v
=
ivector
(
name
=
"v"
)
vv
=
matrix
(
name
=
"vv"
)
y1
,
_
=
scan
(
lambda
i
,
W
:
W
[
i
],
sequences
=
v
,
outputs_info
=
None
,
non_sequences
=
[
W
]
y1
=
scan
(
lambda
i
,
W
:
W
[
i
],
sequences
=
v
,
outputs_info
=
None
,
non_sequences
=
[
W
],
return_updates
=
False
,
)
y2
,
_
=
scan
(
lambda
i
,
_
,
W
:
W
[
i
],
sequences
=
[
v
,
v
],
outputs_info
=
None
,
non_sequences
=
W
y2
=
scan
(
lambda
i
,
_
,
W
:
W
[
i
],
sequences
=
[
v
,
v
],
outputs_info
=
None
,
non_sequences
=
W
,
return_updates
=
False
,
)
y3
,
_
=
scan
(
y3
=
scan
(
lambda
i
,
_
,
W
:
W
[
i
],
sequences
=
[
v
,
vv
[
0
]],
outputs_info
=
None
,
non_sequences
=
W
,
return_updates
=
False
,
)
y4
,
_
=
scan
(
y4
=
scan
(
lambda
_
,
i
,
W
:
W
[
i
],
sequences
=
[
vv
[
0
],
v
],
outputs_info
=
None
,
non_sequences
=
W
,
return_updates
=
False
,
)
y5
,
_
=
scan
(
y5
=
scan
(
lambda
_
,
i
,
_2
,
W
:
W
[
i
],
sequences
=
[
vv
,
v
,
vv
[
0
]],
outputs_info
=
None
,
non_sequences
=
W
,
return_updates
=
False
,
)
y6
,
_
=
scan
(
y6
=
scan
(
lambda
_
,
_2
,
i
,
W
:
W
[
i
],
sequences
=
[
vv
[
0
],
vv
,
v
],
outputs_info
=
None
,
non_sequences
=
W
,
return_updates
=
False
,
)
y7
,
_
=
scan
(
y7
=
scan
(
lambda
i
,
_
,
_2
,
W
:
W
[
i
],
sequences
=
[
v
,
vv
[
0
],
vv
[
0
]],
outputs_info
=
None
,
non_sequences
=
W
,
return_updates
=
False
,
)
y8
,
_
=
scan
(
y8
=
scan
(
lambda
_
,
i
,
W
,
_2
,
_3
:
W
[
i
],
sequences
=
[
vv
[
0
],
v
],
outputs_info
=
None
,
non_sequences
=
[
W
,
W
[
0
],
W
[
0
]],
return_updates
=
False
,
)
W_val
=
np
.
random
.
normal
(
size
=
(
3
,
3
))
.
astype
(
config
.
floatX
)
...
...
@@ -195,7 +218,7 @@ class TestPushOutDot:
def
lambda_fn
(
h
,
W1
,
W2
):
return
dot
(
h
,
W1
+
W2
)
o
,
_
=
scan
(
lambda_fn
,
non_sequences
=
[
h0
,
W1
,
W2
],
n_steps
=
5
)
o
=
scan
(
lambda_fn
,
non_sequences
=
[
h0
,
W1
,
W2
],
n_steps
=
5
,
return_updates
=
False
)
f
=
function
([
h0
,
W1
,
W2
],
o
,
mode
=
self
.
mode
)
...
...
@@ -232,19 +255,24 @@ class TestPushOutDot:
return
dot
(
W1
,
W2
),
until_condition
# Compile a function with the optimization
o
,
_
=
scan
(
lambda_fn
,
sequences
=
[
step_indices
,
W1
],
non_sequences
=
[
W2
],
n_steps
=
5
o
=
scan
(
lambda_fn
,
sequences
=
[
step_indices
,
W1
],
non_sequences
=
[
W2
],
n_steps
=
5
,
return_updates
=
False
,
)
f
=
function
([
W1
,
W2
,
step_indices
],
o
,
mode
=
self
.
mode
)
# Compule an pytensor function without the optimization
o
,
_
=
scan
(
o
=
scan
(
lambda_fn
,
sequences
=
[
step_indices
,
W1
],
non_sequences
=
[
W2
],
n_steps
=
5
,
mode
=
"FAST_COMPILE"
,
return_updates
=
False
,
)
f_ref
=
function
([
W1
,
W2
,
step_indices
],
o
,
mode
=
self
.
mode
)
...
...
@@ -268,7 +296,13 @@ class TestPushOutDot:
def
lambda_fn
(
h
,
W1
,
W2
):
return
dot
(
h
,
W1
+
W2
)
o
,
_
=
scan
(
lambda_fn
,
outputs_info
=
h0
,
non_sequences
=
[
W1
,
W2
],
n_steps
=
5
)
o
=
scan
(
lambda_fn
,
outputs_info
=
h0
,
non_sequences
=
[
W1
,
W2
],
n_steps
=
5
,
return_updates
=
False
,
)
f
=
function
([
h0
,
W1
,
W2
],
o
,
mode
=
self
.
mode
)
...
...
@@ -290,10 +324,11 @@ class TestPushOutDot:
def
fn
(
i
,
i_tm1
):
return
i
+
10
,
i_tm1
([
i_t
,
i_tm1
],
_
)
=
scan
(
[
i_t
,
i_tm1
]
=
scan
(
fn
,
sequences
=
[
inp
],
outputs_info
=
[
np
.
asarray
([
0.0
,
0.0
],
config
.
floatX
),
None
],
return_updates
=
False
,
)
f
=
function
([
inp
],
[
i_t
,
i_tm1
])
val
=
np
.
arange
(
10
)
.
reshape
(
5
,
2
)
.
astype
(
config
.
floatX
)
...
...
@@ -397,17 +432,18 @@ class TestPushOutNonSeqScan:
@config.change_flags
(
on_opt_error
=
"raise"
)
def
test_pushout_seqs2
(
self
):
x
=
matrix
()
outputs
,
updates
=
scan
(
outputs
=
scan
(
lambda
x
:
[
x
*
x
,
pt
.
constant
(
0
)
.
copy
()
.
copy
()],
n_steps
=
2
,
sequences
=
[],
non_sequences
=
[],
outputs_info
=
[
x
,
None
],
return_updates
=
False
,
)
# Compile an PyTensor function where any optimization error will lead to
# an exception being raised
function
([
x
],
outputs
,
updates
=
updates
)
function
([
x
],
outputs
)
@config.change_flags
(
on_opt_error
=
"raise"
)
def
test_pushout_nonseq
(
self
):
...
...
@@ -418,7 +454,9 @@ class TestPushOutNonSeqScan:
outputs. This led the optimization to raise an exception.
"""
outputs
,
_
=
scan
(
lambda
x
:
(
x
*
x
,
x
),
non_sequences
=
[
2
],
n_steps
=
2
)
outputs
=
scan
(
lambda
x
:
(
x
*
x
,
x
),
non_sequences
=
[
2
],
n_steps
=
2
,
return_updates
=
False
)
f
=
function
(
inputs
=
[],
outputs
=
outputs
)
outs
=
f
()
...
...
@@ -583,10 +621,12 @@ class TestPushOutNonSeqScan:
test_ofg
=
OpFromGraph
([],
[
y
])
def
inner_func
(
x
):
out
,
_
=
pytensor
.
scan
(
lambda
:
test_ofg
(),
n_steps
=
x
)
out
=
pytensor
.
scan
(
lambda
:
test_ofg
(),
n_steps
=
x
,
return_updates
=
False
)
return
out
out
,
_
=
pytensor
.
scan
(
inner_func
,
sequences
=
[
pt
.
arange
(
1
,
2
)])
out
=
pytensor
.
scan
(
inner_func
,
sequences
=
[
pt
.
arange
(
1
,
2
)],
return_updates
=
False
)
_
=
pytensor
.
function
([],
test_ofg
())
...
...
@@ -612,10 +652,11 @@ class TestPushOutAddScan:
def
test_sum_dot
(
self
):
A
=
matrix
(
"A"
)
B
=
matrix
(
"B"
)
S
,
_
=
scan
(
S
=
scan
(
lambda
x1
,
x2
,
u
:
u
+
dot
(
x1
,
x2
),
sequences
=
[
A
.
dimshuffle
(
0
,
1
,
"x"
),
B
.
dimshuffle
(
0
,
"x"
,
1
)],
outputs_info
=
[
pt
.
zeros_like
(
A
)],
return_updates
=
False
,
)
# FIXME: This `s.owner.inputs[0][-1]` is a hack, users will never do that.
# They will do `s[-1]` which the rewrite fails to identify since it explicitly looks for a `scan_out[-1]`
...
...
@@ -636,13 +677,17 @@ class TestPushOutAddScan:
bv
=
pt
.
zeros
((
5
,))
bh
=
pt
.
zeros
((
4
,))
v
=
matrix
(
"v"
)
(
bv_t
,
bh_t
),
_
=
scan
(
lambda
_
:
[
bv
,
bh
],
sequences
=
v
,
outputs_info
=
[
None
,
None
]
(
bv_t
,
bh_t
)
=
scan
(
lambda
_
:
[
bv
,
bh
],
sequences
=
v
,
outputs_info
=
[
None
,
None
],
return_updates
=
False
,
)
chain
,
_
=
scan
(
chain
=
scan
(
lambda
x
:
dot
(
dot
(
x
,
W
)
+
bh_t
,
W
.
T
)
+
bv_t
,
outputs_info
=
v
,
n_steps
=
2
,
return_updates
=
False
,
)
# TODO FIXME: Make this a real test and assert something.
chain_fn
=
function
([
v
],
chain
)
...
...
@@ -710,26 +755,28 @@ class TestPushOutAddScan:
# Compile the function twice, once with the optimization and once
# without
opt_mode
=
mode
.
including
(
"scan"
)
h
,
_
=
pytensor
.
scan
(
h
=
pytensor
.
scan
(
rnn_step1
,
sequences
=
[
x
,
ri
,
zi
],
n_steps
=
seq_len
,
outputs_info
=
init
,
name
=
"fpass1"
,
mode
=
opt_mode
,
return_updates
=
False
,
)
cost
=
h
[
-
1
]
.
sum
()
grad1
=
grad
(
cost
,
[
U
,
V
,
W
])
f_opt
=
pytensor
.
function
(
inputs
=
[
x
,
ri
,
zi
],
outputs
=
grad1
,
mode
=
opt_mode
)
no_opt_mode
=
mode
.
excluding
(
"scan_pushout_add"
)
h
,
_
=
pytensor
.
scan
(
h
=
pytensor
.
scan
(
rnn_step1
,
sequences
=
[
x
,
ri
,
zi
],
n_steps
=
seq_len
,
outputs_info
=
init
,
name
=
"fpass1"
,
mode
=
no_opt_mode
,
return_updates
=
False
,
)
cost
=
h
[
-
1
]
.
sum
()
grad1
=
grad
(
cost
,
[
U
,
V
,
W
])
...
...
@@ -773,21 +820,23 @@ class TestPushOutAddScan:
# Compile the function twice, once with the optimization and once without
opt_mode
=
mode
.
including
(
"scan"
)
h
,
_
=
pytensor
.
scan
(
h
=
pytensor
.
scan
(
inner_fct
,
sequences
=
[
input1
,
input2
,
input3
],
outputs_info
=
init
,
mode
=
opt_mode
,
return_updates
=
False
,
)
output
=
h
[
-
1
]
f_opt
=
pytensor
.
function
([
input1
,
input2
,
input3
],
output
,
mode
=
opt_mode
)
no_opt_mode
=
mode
.
excluding
(
"scan_pushout_add"
)
h
,
_
=
pytensor
.
scan
(
h
=
pytensor
.
scan
(
inner_fct
,
sequences
=
[
input1
,
input2
,
input3
],
outputs_info
=
init
,
mode
=
no_opt_mode
,
return_updates
=
False
,
)
output
=
h
[
-
1
]
f_no_opt
=
pytensor
.
function
([
input1
,
input2
,
input3
],
output
,
mode
=
no_opt_mode
)
...
...
@@ -892,13 +941,20 @@ class TestScanMerge:
"""
inps
=
vector
()
state
=
scalar
()
y1
,
_
=
scan
(
lambda
x
,
y
:
x
*
y
,
sequences
=
inps
,
outputs_info
=
state
,
n_steps
=
5
)
y1
=
scan
(
lambda
x
,
y
:
x
*
y
,
sequences
=
inps
,
outputs_info
=
state
,
n_steps
=
5
,
return_updates
=
False
,
)
y2
,
_
=
scan
(
y2
=
scan
(
lambda
x
,
y
:
(
x
+
y
,
until
(
x
>
0
)),
sequences
=
inps
,
outputs_info
=
state
,
n_steps
=
5
,
return_updates
=
False
,
)
scan_node1
=
y1
.
owner
.
inputs
[
0
]
.
owner
assert
isinstance
(
scan_node1
.
op
,
Scan
)
...
...
@@ -958,8 +1014,8 @@ class TestScanMerge:
def
sub
(
s1
,
s2
,
const
):
return
s1
-
1
,
until
(
s2
>
const
)
sx
,
_
=
scan
(
add
,
sequences
=
[
x
,
z
],
non_sequences
=
[
c1
]
)
sy
,
_
=
scan
(
sub
,
sequences
=
[
y
,
-
z
],
non_sequences
=
[
c1
]
)
sx
=
scan
(
add
,
sequences
=
[
x
,
z
],
non_sequences
=
[
c1
],
return_updates
=
False
)
sy
=
scan
(
sub
,
sequences
=
[
y
,
-
z
],
non_sequences
=
[
c1
],
return_updates
=
False
)
f
=
pytensor
.
function
(
inputs
=
[
x
,
y
,
z
,
c1
],
outputs
=
[
sx
,
sy
],
mode
=
self
.
mode
)
assert
self
.
count_scans
(
f
)
==
2
...
...
@@ -972,8 +1028,8 @@ class TestScanMerge:
np
.
testing
.
assert_array_equal
(
res_sx
,
[
1
,
1
])
np
.
testing
.
assert_array_equal
(
res_sy
,
[
-
1
,
-
1
,
-
1
,
-
1
,
-
1
])
sx
,
_
=
scan
(
add
,
sequences
=
[
x
,
z
],
non_sequences
=
[
c1
]
)
sy
,
_
=
scan
(
sub
,
sequences
=
[
y
,
z
],
non_sequences
=
[
c2
]
)
sx
=
scan
(
add
,
sequences
=
[
x
,
z
],
non_sequences
=
[
c1
],
return_updates
=
False
)
sy
=
scan
(
sub
,
sequences
=
[
y
,
z
],
non_sequences
=
[
c2
],
return_updates
=
False
)
f
=
pytensor
.
function
(
inputs
=
[
x
,
y
,
z
,
c1
,
c2
],
outputs
=
[
sx
,
sy
],
mode
=
self
.
mode
...
...
@@ -989,22 +1045,23 @@ class TestScanMerge:
np
.
testing
.
assert_array_equal
(
res_sx
,
[
1
,
1
,
1
,
1
,
1
])
np
.
testing
.
assert_array_equal
(
res_sy
,
[
-
1
,
-
1
,
-
1
])
sx
,
_
=
scan
(
add
,
sequences
=
[
x
,
z
],
non_sequences
=
[
c1
]
)
sy
,
_
=
scan
(
sub
,
sequences
=
[
y
,
z
],
non_sequences
=
[
c1
]
)
sx
=
scan
(
add
,
sequences
=
[
x
,
z
],
non_sequences
=
[
c1
],
return_updates
=
False
)
sy
=
scan
(
sub
,
sequences
=
[
y
,
z
],
non_sequences
=
[
c1
],
return_updates
=
False
)
f
=
pytensor
.
function
(
inputs
=
[
x
,
y
,
z
,
c1
],
outputs
=
[
sx
,
sy
],
mode
=
self
.
mode
)
assert
self
.
count_scans
(
f
)
==
1
def
nested_scan
(
c
,
x
,
z
):
sx
,
_
=
scan
(
add
,
sequences
=
[
x
,
z
],
non_sequences
=
[
c
]
)
sy
,
_
=
scan
(
sub
,
sequences
=
[
x
,
z
],
non_sequences
=
[
c
]
)
sx
=
scan
(
add
,
sequences
=
[
x
,
z
],
non_sequences
=
[
c
],
return_updates
=
False
)
sy
=
scan
(
sub
,
sequences
=
[
x
,
z
],
non_sequences
=
[
c
],
return_updates
=
False
)
return
sx
.
sum
()
+
sy
.
sum
()
sz
,
_
=
scan
(
sz
=
scan
(
nested_scan
,
sequences
=
[
stack
([
c1
,
c2
])],
non_sequences
=
[
x
,
z
],
mode
=
self
.
mode
,
return_updates
=
False
,
)
f
=
pytensor
.
function
(
inputs
=
[
x
,
z
,
c1
,
c2
],
outputs
=
sz
,
mode
=
mode
)
...
...
@@ -1023,9 +1080,8 @@ class TestScanInplaceOptimizer:
x
=
pt
.
vector
(
"x"
)
scan_out
,
_
=
pytensor
.
scan
(
lambda
x
:
(
x
+
1
)
/
2
+
1
,
sequences
=
[
x
],
scan_out
=
pytensor
.
scan
(
lambda
x
:
(
x
+
1
)
/
2
+
1
,
sequences
=
[
x
],
return_updates
=
False
)
fgraph
=
FunctionGraph
(
...
...
@@ -1039,10 +1095,8 @@ class TestScanInplaceOptimizer:
assert
equal_computations
([
scan_out
],
fgraph
.
outputs
)
def
test_inplace_basic
(
self
):
scan_out
,
_
=
pytensor
.
scan
(
lambda
x
:
x
+
1
,
outputs_info
=
[
pt
.
zeros
(
1
)],
n_steps
=
3
,
scan_out
=
pytensor
.
scan
(
lambda
x
:
x
+
1
,
outputs_info
=
[
pt
.
zeros
(
1
)],
n_steps
=
3
,
return_updates
=
False
)
fgraph
=
FunctionGraph
(
...
...
@@ -1089,7 +1143,7 @@ class TestScanInplaceOptimizer:
u0_t
*
W_in
+
x1_tm1
*
W
+
u1_t
+
u2_t
,
]
outputs
,
updates
=
scan
(
outputs
=
scan
(
f_rnn_shared
,
[
u0
,
u1
,
u2
],
[
dict
(
initial
=
x0
,
inplace
=
u2
),
dict
(
initial
=
x1
,
inplace
=
u1
)],
...
...
@@ -1098,12 +1152,12 @@ class TestScanInplaceOptimizer:
truncate_gradient
=-
1
,
go_backwards
=
False
,
mode
=
self
.
mode
,
return_updates
=
False
,
)
f9
=
function
(
[
mu0
,
mu1
,
mu2
,
x0
,
x1
],
outputs
,
updates
=
updates
,
mode
=
self
.
mode
,
allow_input_downcast
=
True
,
)
...
...
@@ -1155,7 +1209,7 @@ class TestScanInplaceOptimizer:
u0_t
*
W_in
+
x1_tm1
*
W
+
u2_tm1
+
u2_t
+
u2_tp1
,
]
outputs
,
updates
=
scan
(
outputs
=
scan
(
f_rnn_shared
,
[
u0
,
dict
(
input
=
u1
,
taps
=
[
0
,
1
]),
dict
(
input
=
u2
,
taps
=
[
-
1
,
0
,
+
1
])],
[
dict
(
initial
=
x0
),
dict
(
initial
=
x1
)],
...
...
@@ -1164,11 +1218,11 @@ class TestScanInplaceOptimizer:
truncate_gradient
=-
1
,
go_backwards
=
False
,
mode
=
self
.
mode
,
return_updates
=
False
,
)
f9
=
function
(
[
mu0
,
mu1
,
mu2
,
x0
,
x1
],
outputs
,
updates
=
updates
,
mode
=
self
.
mode
,
allow_input_downcast
=
True
,
)
...
...
@@ -1202,8 +1256,12 @@ class TestScanInplaceOptimizer:
vx1
=
asarrayX
(
rng
.
uniform
())
x0
=
shared
(
vx0
)
x1
=
shared
(
vx1
)
outputs
,
updates
=
scan
(
lambda
x
,
y
:
(
x
+
asarrayX
(
1
),
y
+
asarrayX
(
1
)),
[],
[
x0
,
x1
],
n_steps
=
3
outputs
=
scan
(
lambda
x
,
y
:
(
x
+
asarrayX
(
1
),
y
+
asarrayX
(
1
)),
[],
[
x0
,
x1
],
n_steps
=
3
,
return_updates
=
False
,
)
x0
=
asarrayX
(
np
.
zeros
((
4
,)))
x0
[
0
]
=
vx0
...
...
@@ -1212,7 +1270,7 @@ class TestScanInplaceOptimizer:
to_replace
=
outputs
[
0
]
.
owner
.
inputs
[
0
]
.
owner
.
inputs
[
1
]
outputs
=
clone_replace
(
outputs
,
replace
=
[(
to_replace
,
x0
)])
f9
=
function
([],
outputs
,
updates
=
updates
,
mode
=
self
.
mode
)
f9
=
function
([],
outputs
,
mode
=
self
.
mode
)
scan_node
=
[
x
for
x
in
f9
.
maker
.
fgraph
.
toposort
()
if
isinstance
(
x
.
op
,
Scan
)]
assert
0
not
in
scan_node
[
0
]
.
op
.
destroy_map
assert
1
in
scan_node
[
0
]
.
op
.
destroy_map
...
...
@@ -1249,7 +1307,7 @@ class TestSaveMem:
y_tm1
+
dot
(
x_tm1
,
W_out
),
]
_outputs
,
update
s
=
scan
(
out
s
=
scan
(
f_rnn_cmpl
,
[
u1
,
u2
],
[
None
,
dict
(
initial
=
x0
),
dict
(
initial
=
y0
,
taps
=
[
-
1
,
-
3
])],
...
...
@@ -1257,12 +1315,12 @@ class TestSaveMem:
n_steps
=
None
,
truncate_gradient
=-
1
,
go_backwards
=
False
,
return_updates
=
False
,
)
outputs
=
[
_outputs
[
0
][
-
1
],
_outputs
[
1
][
-
1
],
_outp
uts
[
2
][
-
1
]]
outputs
=
[
outs
[
0
][
-
1
],
outs
[
1
][
-
1
],
o
uts
[
2
][
-
1
]]
f4
=
function
(
[
u1
,
u2
,
x0
,
y0
,
W_in1
],
outputs
,
updates
=
updates
,
allow_input_downcast
=
True
,
mode
=
self
.
mode
,
)
...
...
@@ -1297,14 +1355,18 @@ class TestSaveMem:
u
=
vector
(
"u"
)
idx
=
iscalar
(
"idx"
)
jdx
=
iscalar
(
"jdx"
)
[
x1
,
x2
,
x3
,
x4
,
x5
,
x6
,
x7
],
updates
=
scan
(
f_rnn
,
u
,
n_steps
=
None
,
truncate_gradient
=-
1
,
go_backwards
=
False
[
x1
,
x2
,
x3
,
x4
,
x5
,
x6
,
x7
]
=
scan
(
f_rnn
,
u
,
n_steps
=
None
,
truncate_gradient
=-
1
,
go_backwards
=
False
,
return_updates
=
False
,
)
f2
=
function
(
[
u
,
idx
,
jdx
],
[
x1
[:
2
],
x2
[
4
],
x3
[
idx
],
x4
[:
idx
],
x5
[
-
10
],
x6
[
-
jdx
],
x7
[:
-
jdx
]],
updates
=
updates
,
allow_input_downcast
=
True
,
mode
=
self
.
mode
.
excluding
(
"scan_push_out_seq"
),
)
...
...
@@ -1341,10 +1403,8 @@ class TestSaveMem:
def
test_save_mem_reduced_number_of_steps_constant
(
self
):
x0
=
pt
.
scalar
(
"x0"
)
xs
,
_
=
scan
(
lambda
xtm1
:
xtm1
+
1
,
outputs_info
=
[
x0
],
n_steps
=
10
,
xs
=
scan
(
lambda
xtm1
:
xtm1
+
1
,
outputs_info
=
[
x0
],
n_steps
=
10
,
return_updates
=
False
)
fn
=
function
([
x0
],
xs
[:
5
],
mode
=
self
.
mode
)
...
...
@@ -1358,10 +1418,11 @@ class TestSaveMem:
def
test_save_mem_cannot_reduce_constant_number_of_steps
(
self
):
x0
=
pt
.
scalar
(
"x0"
)
[
xs
,
ys
]
,
_
=
scan
(
[
xs
,
ys
]
=
scan
(
lambda
xtm1
,
ytm1
:
(
xtm1
+
1
,
ytm1
-
1
),
outputs_info
=
[
x0
,
x0
],
n_steps
=
10
,
return_updates
=
False
,
)
# Because of ys[-1] we need all the steps!
...
...
@@ -1399,7 +1460,7 @@ class TestSaveMem:
x20
=
scalar
(
"x20"
)
x30
=
vector
(
"x30"
)
x40
=
scalar
(
"x40"
)
[
x1
,
x2
,
x3
,
x4
,
x5
,
_x6
,
_x7
]
,
updates
=
scan
(
[
x1
,
x2
,
x3
,
x4
,
x5
,
_x6
,
_x7
]
=
scan
(
step
,
u
,
[
...
...
@@ -1414,12 +1475,12 @@ class TestSaveMem:
n_steps
=
None
,
truncate_gradient
=-
1
,
go_backwards
=
False
,
return_updates
=
False
,
)
f
=
function
(
[
u
,
x10
,
x20
,
x30
,
x40
],
[
x1
[
-
7
],
x2
[
-
3
:
-
1
],
x3
[
-
6
:],
x4
[
-
1
],
x5
[
-
1
]],
updates
=
updates
,
allow_input_downcast
=
True
,
mode
=
self
.
mode
,
)
...
...
@@ -1479,10 +1540,11 @@ class TestSaveMem:
def
test_savemem_does_not_duplicate_number_of_scan_nodes
(
self
):
var
=
pt
.
ones
(())
values
,
_
=
scan
(
values
=
scan
(
lambda
x
:
([
x
],
(),
until
(
x
)),
outputs_info
=
[
var
],
n_steps
=
2
,
return_updates
=
False
,
)
tmp_fn
=
function
([
var
],
values
,
mode
=
self
.
mode
)
...
...
@@ -1493,10 +1555,11 @@ class TestSaveMem:
def
test_savemem_opt
(
self
,
benchmark
):
y0
=
shared
(
np
.
ones
((
2
,
10
)))
[
_y1
,
y2
]
,
_updates
=
scan
(
[
_y1
,
y2
]
=
scan
(
lambda
y
:
[
y
,
y
],
outputs_info
=
[
dict
(
initial
=
y0
,
taps
=
[
-
2
]),
None
],
n_steps
=
5
,
return_updates
=
False
,
)
# TODO FIXME: Make this a real test and assert something.
fn
=
function
([],
y2
.
sum
(),
mode
=
self
.
mode
)
...
...
@@ -1515,23 +1578,25 @@ class TestSaveMem:
return
dot
(
h_tm1
,
w
)
+
x_t_t
def
outer_scan_step
(
x_t
,
w
):
h
,
_
=
scan
(
h
=
scan
(
inner_scan_step
,
sequences
=
[
x_t
[
1
:]],
outputs_info
=
[
x_t
[
0
]],
non_sequences
=
[
w
],
strict
=
True
,
name
=
"the_inner_scan"
,
return_updates
=
False
,
)
return
h
def
get_outputs
(
x
,
w
):
features
,
_
=
scan
(
features
=
scan
(
outer_scan_step
,
sequences
=
[
x
],
non_sequences
=
[
w
],
strict
=
True
,
name
=
"the_outer_scan"
,
return_updates
=
False
,
)
return_val
=
grad
(
features
.
sum
(),
w
)
...
...
@@ -1571,7 +1636,7 @@ class TestSaveMem:
state
=
vector
(
"state"
)
n_steps
=
iscalar
(
"nsteps"
)
output
,
updates
=
scan
(
output
=
scan
(
f_pow2
,
[],
state
,
...
...
@@ -1579,13 +1644,13 @@ class TestSaveMem:
n_steps
=
n_steps
,
truncate_gradient
=-
1
,
go_backwards
=
False
,
return_updates
=
False
,
)
nw_shape
=
ivector
(
"nw_shape"
)
# Note that the output is reshaped to 3 dimensional tensor, and
my_f
=
function
(
[
state
,
n_steps
,
nw_shape
],
[
reshape
(
output
,
nw_shape
,
ndim
=
3
)[:
-
2
],
output
[:
-
4
]],
updates
=
updates
,
allow_input_downcast
=
True
,
)
nodes
=
[
x
for
x
in
my_f
.
maker
.
fgraph
.
toposort
()
if
isinstance
(
x
.
op
,
Scan
)]
...
...
@@ -1599,11 +1664,12 @@ class TestSaveMem:
n_steps
=
scalar
(
"n_steps"
,
dtype
=
"int64"
)
x0
=
vector
(
"x0"
)
ys
,
_
=
pytensor
.
scan
(
ys
=
pytensor
.
scan
(
# Fibonacci Sequence
lambda
xtm2
,
xtm1
:
(
xtm1
+
xtm2
,
{},
until
(
xtm1
>=
34
)),
outputs_info
=
[{
"initial"
:
x0
,
"taps"
:
[
-
2
,
-
1
]}],
n_steps
=
n_steps
,
return_updates
=
False
,
)
# Save memory is triggered by choosing only last value
y
=
ys
[
-
1
]
...
...
@@ -1629,10 +1695,11 @@ class TestSaveMem:
def
test_while_scan_map
(
self
):
xs
=
vector
(
"xs"
)
ys
,
_
=
pytensor
.
scan
(
ys
=
pytensor
.
scan
(
lambda
x
:
(
x
+
1
,
{},
until
(
x
+
1
>=
10
)),
outputs_info
=
[
None
],
sequences
=
[
xs
],
return_updates
=
False
,
)
# Save memory is triggered by choosing only last value
y
=
ys
[
-
1
]
...
...
@@ -1656,11 +1723,12 @@ class TestSaveMem:
n_steps
=
scalar
(
"n_steps"
,
dtype
=
"int64"
)
# while loop
[
ys
,
zs
]
,
_
=
pytensor
.
scan
(
[
ys
,
zs
]
=
pytensor
.
scan
(
lambda
s
,
xtm1
:
((
xtm1
+
1
,
xtm1
+
1
+
s
),
{},
until
(
xtm1
>=
99
)),
sequences
=
[
seq
],
outputs_info
=
[
x0
,
None
],
n_steps
=
n_steps
,
return_updates
=
False
,
)
# Save memory is triggered by choosing only last value
y
=
ys
[
-
1
]
...
...
@@ -1696,10 +1764,11 @@ class TestSaveMem:
val_test
=
np
.
zeros
(
val_shape
,
dtype
=
val
.
dtype
)
init
=
pt
.
full
((
2
,),
val
)
ys
,
_
=
pytensor
.
scan
(
ys
=
pytensor
.
scan
(
fn
=
lambda
*
args
:
pt
.
add
(
*
args
),
outputs_info
=
[{
"initial"
:
init
,
"taps"
:
(
-
2
,
-
1
)}],
n_steps
=
100
,
return_updates
=
False
,
)
out
=
ys
[:
-
50
]
if
keep_beginning
else
ys
[
-
50
:]
...
...
@@ -1729,12 +1798,13 @@ def test_inner_replace_dot():
mode
=
get_default_mode
()
.
including
(
"scan"
)
# .excluding("BlasOpt")
o
,
_
=
scan
(
o
=
scan
(
lambda
hi
,
him1
,
W
:
(
hi
,
dot
(
hi
+
him1
,
W
)),
outputs_info
=
[
pt
.
zeros
([
h
.
shape
[
1
]]),
None
],
sequences
=
[
h
],
non_sequences
=
[
W
],
mode
=
mode
,
return_updates
=
False
,
)
f
=
function
([
W
,
h
],
o
,
mode
=
mode
)
...
...
@@ -1753,11 +1823,12 @@ def test_alloc_inputs1():
def
lambda_fn
(
h
,
W1
,
W2
):
return
dot
(
h
,
W1
*
W2
)
o
,
_
=
scan
(
o
=
scan
(
lambda_fn
,
outputs_info
=
h0
,
non_sequences
=
[
W1
,
pt
.
zeros_like
(
W2
)],
n_steps
=
5
,
return_updates
=
False
,
)
f
=
function
([
h0
,
W1
,
W2
],
o
,
mode
=
get_default_mode
()
.
including
(
"scan"
))
...
...
@@ -1786,12 +1857,13 @@ def test_alloc_inputs2():
def
lambda_fn
(
W1
,
h
,
W2
):
return
W1
*
dot
(
h
,
W2
)
o
,
_
=
scan
(
o
=
scan
(
lambda_fn
,
sequences
=
pt
.
zeros_like
(
W1
),
outputs_info
=
h0
,
non_sequences
=
[
pt
.
zeros_like
(
W2
)],
n_steps
=
5
,
return_updates
=
False
,
)
f
=
function
([
h0
,
W1
,
W2
],
o
,
mode
=
get_default_mode
()
.
including
(
"scan"
))
...
...
@@ -1821,12 +1893,13 @@ def test_alloc_inputs3():
def
lambda_fn
(
W1
,
h
,
W2
):
return
W1
*
dot
(
h
,
W2
)
o
,
_
=
scan
(
o
=
scan
(
lambda_fn
,
sequences
=
pt
.
zeros_like
(
W1
),
outputs_info
=
h0
,
non_sequences
=
[
pt
.
zeros_like
(
W2
)],
n_steps
=
5
,
return_updates
=
False
,
)
# TODO FIXME: This result depends on unrelated rewrites in the "fast" mode.
...
...
@@ -1848,7 +1921,7 @@ def test_opt_order():
x
=
matrix
(
"x"
)
A
=
matrix
(
"A"
)
z
,
_updates
=
scan
(
dot
,
sequences
=
[],
non_sequences
=
[
x
,
A
],
n_steps
=
2
)
z
=
scan
(
dot
,
sequences
=
[],
non_sequences
=
[
x
,
A
],
n_steps
=
2
,
return_updates
=
False
)
f
=
function
([
x
,
A
],
z
,
mode
=
"FAST_RUN"
)
topo
=
f
.
maker
.
fgraph
.
toposort
()
...
...
tests/tensor/linalg/test_rewriting.py
浏览文件 @
abedb7fb
...
...
@@ -170,11 +170,12 @@ def test_lu_decomposition_reused_scan(assume_a, counter, transposed):
A
=
tensor
(
"A"
,
shape
=
(
3
,
3
))
x0
=
tensor
(
"b"
,
shape
=
(
3
,
4
))
xs
,
_
=
scan
(
xs
=
scan
(
lambda
xtm1
,
A
:
solve
(
A
,
xtm1
,
assume_a
=
assume_a
,
transposed
=
transposed
),
outputs_info
=
[
x0
],
non_sequences
=
[
A
],
n_steps
=
10
,
return_updates
=
False
,
)
fn_no_opt
=
function
(
...
...
tests/tensor/test_blockwise.py
浏览文件 @
abedb7fb
...
...
@@ -694,10 +694,11 @@ def test_blockwise_grad_core_type():
def
test_scan_gradient_core_type
():
n_steps
=
3
seq
=
tensor
(
"seq"
,
shape
=
(
n_steps
,
1
),
dtype
=
"float64"
)
out
,
_
=
scan
(
out
=
scan
(
lambda
s
:
s
,
sequences
=
[
seq
],
n_steps
=
n_steps
,
return_updates
=
False
,
)
vec_seq
=
tensor
(
"vec_seq"
,
shape
=
(
None
,
n_steps
,
1
),
dtype
=
"float64"
)
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论