Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
63f8d6e7
提交
63f8d6e7
authored
2月 14, 2023
作者:
Ricardo Vieira
提交者:
Ricardo Vieira
2月 24, 2023
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Optimize while scans when only last state is needed
上级
01e92baa
隐藏空白字符变更
内嵌
并排
正在显示
4 个修改的文件
包含
243 行增加
和
27 行删除
+243
-27
op.py
pytensor/scan/op.py
+1
-1
rewriting.py
pytensor/scan/rewriting.py
+139
-26
subtensor.py
pytensor/tensor/rewriting/subtensor.py
+11
-0
test_rewriting.py
tests/scan/test_rewriting.py
+92
-0
没有找到文件。
pytensor/scan/op.py
浏览文件 @
63f8d6e7
...
@@ -1182,7 +1182,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
...
@@ -1182,7 +1182,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
# these are states that do not feed anything back in the recurrent
# these are states that do not feed anything back in the recurrent
# computation, and hence they do not have an initial state. The scan
# computation, and hence they do not have an initial state. The scan
# node however receives an input for each such argument, the input
# node however receives an input for each such argument, the input
# in this case is just a int saying how many steps of this output we
# in this case is just an int saying how many steps of this output we
# need to store. This input does not have the same dtype, nor is it the same
# need to store. This input does not have the same dtype, nor is it the same
# type of tensor as the output, it is always a scalar int.
# type of tensor as the output, it is always a scalar int.
new_inputs
+=
[
as_tensor_variable
(
ons
)
for
ons
in
self
.
outer_nitsot
(
inputs
)]
new_inputs
+=
[
as_tensor_variable
(
ons
)
for
ons
in
self
.
outer_nitsot
(
inputs
)]
...
...
pytensor/scan/rewriting.py
浏览文件 @
63f8d6e7
...
@@ -28,10 +28,18 @@ from pytensor.graph.features import ReplaceValidate
...
@@ -28,10 +28,18 @@ from pytensor.graph.features import ReplaceValidate
from
pytensor.graph.fg
import
FunctionGraph
from
pytensor.graph.fg
import
FunctionGraph
from
pytensor.graph.op
import
compute_test_value
from
pytensor.graph.op
import
compute_test_value
from
pytensor.graph.replace
import
clone_replace
from
pytensor.graph.replace
import
clone_replace
from
pytensor.graph.rewriting.basic
import
GraphRewriter
,
in2out
,
node_rewriter
from
pytensor.graph.rewriting.basic
import
(
GraphRewriter
,
copy_stack_trace
,
in2out
,
node_rewriter
,
)
from
pytensor.graph.rewriting.db
import
EquilibriumDB
,
SequenceDB
from
pytensor.graph.rewriting.db
import
EquilibriumDB
,
SequenceDB
from
pytensor.graph.rewriting.utils
import
get_clients_at_depth
from
pytensor.graph.type
import
HasShape
from
pytensor.graph.type
import
HasShape
from
pytensor.graph.utils
import
InconsistencyError
from
pytensor.graph.utils
import
InconsistencyError
from
pytensor.raise_op
import
Assert
from
pytensor.scalar
import
ScalarConstant
from
pytensor.scan.op
import
Scan
,
ScanInfo
from
pytensor.scan.op
import
Scan
,
ScanInfo
from
pytensor.scan.utils
import
(
from
pytensor.scan.utils
import
(
ScanArgs
,
ScanArgs
,
...
@@ -1103,6 +1111,71 @@ def sanitize(x):
...
@@ -1103,6 +1111,71 @@ def sanitize(x):
return
at
.
as_tensor_variable
(
x
)
return
at
.
as_tensor_variable
(
x
)
@node_rewriter([Scan])
def while_scan_merge_subtensor_last_element(fgraph, scan_node):
    """Rewrite ``while_scan_out[abs(min(tap)):][-1]`` as ``while_scan_out[-1]``.

    The rewrite targets recurring (mit-sot and sit-sot) outputs of a while
    Scan and wraps the replacement in an ``Assert`` that ``n_steps > 0``.

    Only the first step can be guaranteed from the inputs alone (i.e.,
    ``n_steps > 0``), because a while Scan may abort at any later point.
    For that reason ``while_scan_out[abs(min(tap)):][-i]`` cannot be replaced
    by ``while_scan_out[-i]`` for any ``-i != -1``.

    Parameters
    ----------
    fgraph
        Graph that is searched for clients of ``scan_node``.
    scan_node
        Apply node whose ``op`` is a ``Scan``.

    Returns
    -------
    dict or None
        ``None`` when the Scan is not a while Scan. Otherwise a (possibly
        empty) mapping from each matching outer ``Subtensor`` output to its
        asserted ``x[-1]`` replacement.
    """
    scan_op = scan_node.op
    if not scan_op.info.as_while:
        return None

    # mit-mot outputs are not covered by this rewrite
    recurring_outs = scan_op.outer_mitsot_outs(
        scan_node.outputs
    ) + scan_op.outer_sitsot_outs(scan_node.outputs)
    recurring_taps = scan_op.info.mit_sot_in_slices + scan_op.info.sit_sot_in_slices

    at_least_one_step = scan_node.inputs[0] > 0
    check_steps = Assert("n_steps > 0")

    replacements = {}

    # Visit every node located two computations below the while Scan
    for outer_node in get_clients_at_depth(fgraph, scan_node, depth=2):
        if not isinstance(outer_node.op, Subtensor):
            continue

        inner_node = outer_node.inputs[0].owner
        if not inner_node or not isinstance(inner_node.op, Subtensor):
            continue

        scan_out = inner_node.inputs[0]
        if scan_out not in recurring_outs:
            continue

        inner_idx = get_idx_list(inner_node.inputs, inner_node.op.idx_list)
        outer_idx = get_idx_list(outer_node.inputs, outer_node.op.idx_list)

        n_init = abs(min(recurring_taps[recurring_outs.index(scan_out)]))

        # The inner indexing has to be exactly `[n_init:]` ...
        if len(inner_idx) != 1:
            continue
        head = inner_idx[0]
        if not isinstance(head, slice):
            continue
        if not isinstance(head.start, aes.ScalarConstant):
            continue
        if not (head.start.data == n_init):
            continue
        if head.stop is not None or head.step is not None:
            continue

        # ... and the outer indexing has to be exactly `[-1]`
        if len(outer_idx) != 1:
            continue
        tail = outer_idx[0]
        if not isinstance(tail, aes.ScalarConstant):
            continue
        if not (tail.data == -1):
            continue

        merged = check_steps(scan_out[-1], at_least_one_step)
        copy_stack_trace([outer_node.outputs[0], outer_node.inputs[0]], merged)
        replacements[outer_node.outputs[0]] = merged

    return replacements
@node_rewriter
([
Scan
])
@node_rewriter
([
Scan
])
def
save_mem_new_scan
(
fgraph
,
node
):
def
save_mem_new_scan
(
fgraph
,
node
):
r"""Graph optimizer that reduces scan memory consumption.
r"""Graph optimizer that reduces scan memory consumption.
...
@@ -1124,6 +1197,17 @@ def save_mem_new_scan(fgraph, node):
...
@@ -1124,6 +1197,17 @@ def save_mem_new_scan(fgraph, node):
that SITSOT output. Only the most recently computed timestep ever needs to
that SITSOT output. Only the most recently computed timestep ever needs to
be kept in memory.
be kept in memory.
There are two ways in which the Scan buffer size is controlled:
1. Each recurring output is saved in an input empty tensor x with the initial
state written at x[:abs(min(taps))]. The remaining x[abs(min(taps)):]
positions determine how many intermediate results should be stored.
This rewrite shortens x[abs(min(taps)):] to the smallest possible size.
2. Each non-recurrent output (nit-sot) is associated with a scalar integer
input that determines how many steps should be saved in the perform method.
This rewrite reduces this number to the smallest possible.
The scan perform implementation takes the output sizes into consideration,
saving the newest results over the oldest ones whenever the buffer is filled.
"""
"""
if
not
isinstance
(
node
.
op
,
Scan
):
if
not
isinstance
(
node
.
op
,
Scan
):
return
False
return
False
...
@@ -1172,13 +1256,16 @@ def save_mem_new_scan(fgraph, node):
...
@@ -1172,13 +1256,16 @@ def save_mem_new_scan(fgraph, node):
# index(step) for any output scan actually needs to compute
# index(step) for any output scan actually needs to compute
# In other words n_steps should be equal to this maximal !
# In other words n_steps should be equal to this maximal !
# Note: if we have a shared variable that gets updated at every step
# Note: if we have a shared variable that gets updated at every step
# of the loop, reducing the number of steps will affect the
the
# of the loop, reducing the number of steps will affect the
# value of the shared variable after the loop so we
need not to
# value of the shared variable after the loop so we cannot
# change the number of steps in that case. To do this we set
# change the number of steps in that case. To do this we set
# global_nsteps to None which is seen as a flag that nothing needs
# global_nsteps to None which is seen as a flag that nothing needs
# to be done
# to be done.
# Note: For simplicity while Scans also have global_nsteps set to None.
# All step optimizations require knowing the shape of the output, which
# cannot be determined from the inputs alone.
assert
len
(
node
.
outputs
)
>=
c_outs
assert
len
(
node
.
outputs
)
>=
c_outs
if
len
(
node
.
outputs
)
==
c_outs
:
if
len
(
node
.
outputs
)
==
c_outs
and
not
op
.
info
.
as_while
:
global_nsteps
=
{
"real"
:
-
1
,
"sym"
:
[]}
global_nsteps
=
{
"real"
:
-
1
,
"sym"
:
[]}
else
:
else
:
global_nsteps
=
None
global_nsteps
=
None
...
@@ -1257,7 +1344,7 @@ def save_mem_new_scan(fgraph, node):
...
@@ -1257,7 +1344,7 @@ def save_mem_new_scan(fgraph, node):
else
:
else
:
# there is a **gotcha** here ! Namely, scan returns an
# there is a **gotcha** here ! Namely, scan returns an
# array that contains the initial state of the output
# array that contains the initial state of the output
# as well. Which means that if
have a
initial state of
# as well. Which means that if y has an initial state of
# length 3, and you look for 5 steps you get an output
# length 3, and you look for 5 steps you get an output
# y of length 8. If you only use y[:5], this does not
# y of length 8. If you only use y[:5], this does not
# mean that you only need to loop for 5 steps but
# mean that you only need to loop for 5 steps but
...
@@ -1285,9 +1372,9 @@ def save_mem_new_scan(fgraph, node):
...
@@ -1285,9 +1372,9 @@ def save_mem_new_scan(fgraph, node):
# 2.3. Analyze global_nsteps to figure out for how many steps scan
# 2.3. Analyze global_nsteps to figure out for how many steps scan
# needs to iterate
# needs to iterate
if
global_nsteps
is
not
None
:
if
global_nsteps
is
None
:
nw_steps
=
node
.
inputs
[
0
]
nw_steps
=
node
.
inputs
[
0
]
else
:
# there are some symbolic tensors that limit the number of
# there are some symbolic tensors that limit the number of
# steps
# steps
if
len
(
global_nsteps
[
"sym"
])
==
0
:
if
len
(
global_nsteps
[
"sym"
])
==
0
:
...
@@ -1303,6 +1390,7 @@ def save_mem_new_scan(fgraph, node):
...
@@ -1303,6 +1390,7 @@ def save_mem_new_scan(fgraph, node):
real_steps
=
None
real_steps
=
None
nw_steps
=
select_min
(
select_max
(
sym_steps
,
real_steps
),
node
.
inputs
[
0
])
nw_steps
=
select_min
(
select_max
(
sym_steps
,
real_steps
),
node
.
inputs
[
0
])
# FIXME: This is not correct. Scan with 0 steps seems to be supported
# Make sure the ScanSaveMem optimization never makes the new
# Make sure the ScanSaveMem optimization never makes the new
# number of steps to be 0 (this could happen, for instance, if
# number of steps to be 0 (this could happen, for instance, if
# the optimization detects that the outputs of the Scan go through
# the optimization detects that the outputs of the Scan go through
...
@@ -1310,9 +1398,6 @@ def save_mem_new_scan(fgraph, node):
...
@@ -1310,9 +1398,6 @@ def save_mem_new_scan(fgraph, node):
# 0 iterations are not supported. Make sure the new number of steps
# 0 iterations are not supported. Make sure the new number of steps
# is at least 1.
# is at least 1.
nw_steps
=
select_max
(
nw_steps
,
1
)
nw_steps
=
select_max
(
nw_steps
,
1
)
else
:
nw_steps
=
node
.
inputs
[
0
]
global_nsteps
=
None
# 2.4 Loop over the clients again now looking just to see how many
# 2.4 Loop over the clients again now looking just to see how many
# intermediate steps to store
# intermediate steps to store
...
@@ -1335,19 +1420,33 @@ def save_mem_new_scan(fgraph, node):
...
@@ -1335,19 +1420,33 @@ def save_mem_new_scan(fgraph, node):
store_steps
[
i
]
=
0
store_steps
[
i
]
=
0
break
break
if
i
>
op_info
.
n_mit_mot
:
# Special case for recurrent outputs where only the last result
length
=
node
.
inputs
[
0
]
+
init_l
[
i
]
# is requested. This is needed for this rewrite to apply to
# do-while Scans at all. Otherwise, `get_canonical_form_slice` in
# the `else` branch would reintroduce a shape dependency on the
# original Scan which would lead this rewrite to abort in the end.
if
(
i
<=
op
.
info
.
n_mit_mot
and
isinstance
(
this_slice
[
0
],
ScalarConstant
)
and
this_slice
[
0
]
.
value
==
-
1
):
start
=
nw_steps
-
1
else
:
else
:
try
:
if
i
<=
op
.
info
.
n_mit_mot
:
length
=
shape_of
[
out
][
0
]
try
:
except
KeyError
:
length
=
shape_of
[
out
][
0
]
length
=
out
.
shape
[
0
]
except
KeyError
:
cf_slice
=
get_canonical_form_slice
(
this_slice
[
0
],
length
)
length
=
out
.
shape
[
0
]
else
:
length
=
node
.
inputs
[
0
]
+
init_l
[
i
]
cf_slice
=
get_canonical_form_slice
(
this_slice
[
0
],
length
)
if
isinstance
(
cf_slice
[
0
],
slice
):
start
=
at
.
extract_constant
(
cf_slice
[
0
]
.
start
)
else
:
start
=
at
.
extract_constant
(
cf_slice
[
0
])
if
isinstance
(
cf_slice
[
0
],
slice
):
start
=
at
.
extract_constant
(
cf_slice
[
0
]
.
start
)
else
:
start
=
at
.
extract_constant
(
cf_slice
[
0
])
if
start
==
0
or
store_steps
[
i
]
==
0
:
if
start
==
0
or
store_steps
[
i
]
==
0
:
store_steps
[
i
]
=
0
store_steps
[
i
]
=
0
else
:
else
:
...
@@ -1498,6 +1597,7 @@ def save_mem_new_scan(fgraph, node):
...
@@ -1498,6 +1597,7 @@ def save_mem_new_scan(fgraph, node):
nw_input
=
expand_empty
(
_nw_input
,
nw_steps
)
nw_input
=
expand_empty
(
_nw_input
,
nw_steps
)
nw_inputs
[
in_idx
]
=
nw_input
nw_inputs
[
in_idx
]
=
nw_input
else
:
else
:
# FIXME: This is never used
nw_input
=
nw_inputs
[
in_idx
][:
(
initl
+
nw_steps
)]
nw_input
=
nw_inputs
[
in_idx
][:
(
initl
+
nw_steps
)]
elif
(
elif
(
...
@@ -1554,8 +1654,8 @@ def save_mem_new_scan(fgraph, node):
...
@@ -1554,8 +1654,8 @@ def save_mem_new_scan(fgraph, node):
)
)
else
:
else
:
fslice
=
sanitize
(
cnf_slice
[
0
])
fslice
=
sanitize
(
cnf_slice
[
0
])
nw_slice
=
(
fslice
,)
+
tuple
(
old_slices
[
1
:])
nw_slice
=
(
fslice
,)
+
tuple
(
old_slices
[
1
:])
nw_pos
=
inv_compress_map
[
idx
]
nw_pos
=
inv_compress_map
[
idx
]
subtens
=
Subtensor
(
nw_slice
)
subtens
=
Subtensor
(
nw_slice
)
...
@@ -1604,9 +1704,16 @@ def save_mem_new_scan(fgraph, node):
...
@@ -1604,9 +1704,16 @@ def save_mem_new_scan(fgraph, node):
)
+
tuple
(
old_slices
[
1
:])
)
+
tuple
(
old_slices
[
1
:])
else
:
else
:
position
=
(
# Special case when only last value is requested
cnf_slice
[
0
]
-
nw_steps
-
init_l
[
pos
]
+
store_steps
[
pos
]
if
(
)
isinstance
(
old_slices
[
0
],
ScalarConstant
)
and
old_slices
[
0
]
.
value
==
-
1
):
position
=
old_slices
[
0
]
else
:
position
=
(
cnf_slice
[
0
]
-
nw_steps
-
init_l
[
pos
]
+
store_steps
[
pos
]
)
nw_slice
=
(
sanitize
(
position
),)
+
tuple
(
old_slices
[
1
:])
nw_slice
=
(
sanitize
(
position
),)
+
tuple
(
old_slices
[
1
:])
subtens
=
Subtensor
(
nw_slice
)
subtens
=
Subtensor
(
nw_slice
)
...
@@ -2403,6 +2510,12 @@ scan_seqopt1.register(
...
@@ -2403,6 +2510,12 @@ scan_seqopt1.register(
position
=
5
,
position
=
5
,
)
)
scan_eqopt2
.
register
(
"while_scan_merge_subtensor_last_element"
,
in2out
(
while_scan_merge_subtensor_last_element
,
ignore_newtrees
=
True
),
"fast_run"
,
"scan"
,
)
scan_eqopt2
.
register
(
scan_eqopt2
.
register
(
"constant_folding_for_scan2"
,
"constant_folding_for_scan2"
,
...
...
pytensor/tensor/rewriting/subtensor.py
浏览文件 @
63f8d6e7
...
@@ -479,6 +479,7 @@ def local_subtensor_merge(fgraph, node):
...
@@ -479,6 +479,7 @@ def local_subtensor_merge(fgraph, node):
expresses all slices in a canonical form, and then merges them together.
expresses all slices in a canonical form, and then merges them together.
"""
"""
from
pytensor.scan.op
import
Scan
if
isinstance
(
node
.
op
,
Subtensor
):
if
isinstance
(
node
.
op
,
Subtensor
):
u
=
node
.
inputs
[
0
]
u
=
node
.
inputs
[
0
]
...
@@ -489,6 +490,16 @@ def local_subtensor_merge(fgraph, node):
...
@@ -489,6 +490,16 @@ def local_subtensor_merge(fgraph, node):
# slices of the first applied subtensor
# slices of the first applied subtensor
slices1
=
get_idx_list
(
u
.
owner
.
inputs
,
u
.
owner
.
op
.
idx_list
)
slices1
=
get_idx_list
(
u
.
owner
.
inputs
,
u
.
owner
.
op
.
idx_list
)
slices2
=
get_idx_list
(
node
.
inputs
,
node
.
op
.
idx_list
)
slices2
=
get_idx_list
(
node
.
inputs
,
node
.
op
.
idx_list
)
# Don't try to do the optimization on do-while scan outputs,
# as it will create a dependency on the shape of the outputs
if
(
x
.
owner
is
not
None
and
isinstance
(
x
.
owner
.
op
,
Scan
)
and
x
.
owner
.
op
.
info
.
as_while
):
return
None
# Get the shapes of the vectors !
# Get the shapes of the vectors !
try
:
try
:
# try not to introduce new shape into the graph
# try not to introduce new shape into the graph
...
...
tests/scan/test_rewriting.py
浏览文件 @
63f8d6e7
...
@@ -1395,6 +1395,98 @@ class TestSaveMem:
...
@@ -1395,6 +1395,98 @@ class TestSaveMem:
rng
=
np
.
random
.
default_rng
(
utt
.
fetch_seed
())
rng
=
np
.
random
.
default_rng
(
utt
.
fetch_seed
())
my_f
(
rng
.
uniform
(
size
=
(
3
,)),
4
,
np
.
int64
([
2
,
2
,
3
]))
my_f
(
rng
.
uniform
(
size
=
(
3
,)),
4
,
np
.
int64
([
2
,
2
,
3
]))
def test_while_scan_taps(self):
    """Check a while Scan with taps ``[-2, -1]`` when only ``ys[-1]`` is used."""
    n_steps = scalar("n_steps", dtype="int64")
    x0 = vector("x0")

    def fibonacci_step(xtm2, xtm1):
        # Fibonacci Sequence
        return xtm1 + xtm2, {}, until(xtm1 >= 34)

    ys, _ = pytensor.scan(
        fibonacci_step,
        outputs_info=[{"initial": x0, "taps": [-2, -1]}],
        n_steps=n_steps,
    )
    # Requesting only the last value is what triggers save memory
    last_y = ys[-1]

    scan_mode = get_default_mode().including("scan")
    f = pytensor.function([n_steps, x0], last_y, mode=scan_mode)

    np.testing.assert_equal(f(n_steps=1000, x0=[1, 1]), 55)
    np.testing.assert_equal(f(n_steps=1, x0=[1, 1]), 2)
    with pytest.raises(AssertionError, match="n_steps > 0"):
        f(n_steps=0, x0=[1, 1])

    # ys_trace is an Alloc that controls the size of the inner buffer,
    # it should have shape[0] == 3, with two entries for the taps and one
    # entry for the intermediate output
    (scan_node,) = [
        node for node in f.maker.fgraph.apply_nodes if isinstance(node.op, Scan)
    ]
    _, ys_trace = scan_node.inputs
    buffer_len = ys_trace.shape[0]
    debug_fn = pytensor.function([n_steps, x0], buffer_len, accept_inplace=True)
    assert debug_fn(n_steps=1000, x0=[1, 1]) == 3
def test_while_scan_map(self):
    """Check a while Scan over a sequence when only ``ys[-1]`` is used."""
    xs = vector("xs")

    def increment_step(x):
        return x + 1, {}, until(x + 1 >= 10)

    ys, _ = pytensor.scan(
        increment_step,
        outputs_info=[None],
        sequences=[xs],
    )
    # Requesting only the last value is what triggers save memory
    last_y = ys[-1]

    scan_mode = get_default_mode().including("scan")
    f = pytensor.function([xs], last_y, mode=scan_mode)

    np.testing.assert_equal(f(xs=np.arange(100, dtype=config.floatX)), 10)
    np.testing.assert_equal(f(xs=[0]), 1)
    with pytest.raises(IndexError):
        f(xs=[])

    # len_ys is a numerical input that controls the shape of the inner buffer
    # It should be 1, as only the last output is needed
    (scan_node,) = [
        node for node in f.maker.fgraph.apply_nodes if isinstance(node.op, Scan)
    ]
    _, _, len_ys = scan_node.inputs
    debug_fn = pytensor.function([xs], len_ys, accept_inplace=True)
    zeros_input = np.zeros((100,), dtype=config.floatX)
    assert debug_fn(xs=zeros_input) == 1
def test_while_scan_taps_and_map(self):
    """Check a while Scan with one recurring and one mapped output.

    Only the last entry of each output is requested.
    """
    x0 = scalar("x0")
    seq = vector("seq")
    n_steps = scalar("n_steps", dtype="int64")

    def step(s, xtm1):
        return (xtm1 + 1, xtm1 + 1 + s), {}, until(xtm1 >= 99)

    # while loop
    [ys, zs], _ = pytensor.scan(
        step,
        sequences=[seq],
        outputs_info=[x0, None],
        n_steps=n_steps,
    )
    # Requesting only the last value is what triggers save memory
    last_y = ys[-1]
    last_z = zs[-1]

    scan_mode = get_default_mode().including("scan")
    f = pytensor.function([x0, seq, n_steps], [last_y, last_z], mode=scan_mode)

    test_seq = np.zeros(200, dtype=config.floatX)
    np.testing.assert_allclose(f(x0=0, seq=test_seq, n_steps=200), 100)
    np.testing.assert_allclose(f(x0=1, seq=test_seq, n_steps=20), 21)
    np.testing.assert_allclose(f(x0=np.e, seq=test_seq, n_steps=1), np.e + 1)
    with pytest.raises(AssertionError, match="n_steps > 0"):
        f(x0=0, seq=test_seq, n_steps=0)

    # Evaluate the shape of ys_trace and len_zs to confirm the rewrite worked correctly.
    # If a MissingInputError is raised, it means the rewrite failed
    (scan_node,) = [
        node for node in f.maker.fgraph.apply_nodes if isinstance(node.op, Scan)
    ]
    _, _, ys_trace, len_zs = scan_node.inputs
    debug_outputs = [ys_trace.shape[0], len_zs]
    debug_fn = pytensor.function([n_steps], debug_outputs, accept_inplace=True)
    stored_ys_steps, stored_zs_steps = debug_fn(n_steps=200)
    assert stored_ys_steps == 2
    assert stored_zs_steps == 1
def
test_inner_replace_dot
():
def
test_inner_replace_dot
():
"""
"""
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论