Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
88cc33b6
提交
88cc33b6
authored
3月 02, 2023
作者:
Ricardo Vieira
提交者:
Ricardo Vieira
4月 04, 2023
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Fix Scan JAX dispatcher
上级
7b609047
显示空白字符变更
内嵌
并排
正在显示
2 个修改的文件
包含
343 行增加
和
135 行删除
+343
-135
scan.py
pytensor/link/jax/dispatch/scan.py
+155
-121
test_scan.py
tests/link/jax/test_scan.py
+188
-14
没有找到文件。
pytensor/link/jax/dispatch/scan.py
浏览文件 @
88cc33b6
import
jax
import
jax
import
jax.numpy
as
jnp
import
jax.numpy
as
jnp
from
pytensor.graph.fg
import
FunctionGraph
from
pytensor.link.jax.dispatch.basic
import
jax_funcify
from
pytensor.link.jax.dispatch.basic
import
jax_funcify
from
pytensor.scan.op
import
Scan
from
pytensor.scan.op
import
Scan
from
pytensor.scan.utils
import
ScanArgs
@jax_funcify.register
(
Scan
)
@jax_funcify.register
(
Scan
)
def
jax_funcify_Scan
(
op
,
**
kwargs
):
def
jax_funcify_Scan
(
op
:
Scan
,
**
kwargs
):
inner_fg
=
FunctionGraph
(
op
.
inputs
,
op
.
outputs
)
info
=
op
.
info
jax_at_inner_func
=
jax_funcify
(
inner_fg
,
**
kwargs
)
def
scan
(
*
outer_inputs
):
if
info
.
as_while
:
scan_args
=
ScanArgs
(
raise
NotImplementedError
(
"While Scan cannot yet be converted to JAX"
)
list
(
outer_inputs
),
[
None
]
*
op
.
info
.
n_outs
,
op
.
inputs
,
op
.
outputs
,
op
.
info
if
info
.
n_mit_mot
:
raise
NotImplementedError
(
"Scan with MIT-MOT (gradients of scan) cannot yet be converted to JAX"
)
)
# `outer_inputs` is a list with the following composite form:
# Optimize inner graph
# [n_steps]
rewriter
=
op
.
mode_instance
.
optimizer
# + outer_in_seqs
rewriter
(
op
.
fgraph
)
# + outer_in_mit_mot
scan_inner_func
=
jax_funcify
(
op
.
fgraph
,
**
kwargs
)
# + outer_in_mit_sot
# + outer_in_sit_sot
def
scan
(
*
outer_inputs
):
# + outer_in_shared
# Extract JAX scan inputs
# + outer_in_nit_sot
outer_inputs
=
list
(
outer_inputs
)
# + outer_in_non_seqs
n_steps
=
outer_inputs
[
0
]
# JAX `length`
n_steps
=
scan_args
.
n_steps
seqs
=
op
.
outer_seqs
(
outer_inputs
)
# JAX `xs`
seqs
=
scan_args
.
outer_in_seqs
mit_sot_init
=
[]
# TODO: mit_mots
for
tap
,
seq
in
zip
(
op
.
info
.
mit_sot_in_slices
,
op
.
outer_mitsot
(
outer_inputs
)):
mit_mot_in_slices
=
[]
init_slice
=
seq
[:
abs
(
min
(
tap
))]
mit_sot_init
.
append
(
init_slice
)
mit_sot_in_slices
=
[]
for
tap
,
seq
in
zip
(
scan_args
.
mit_sot_in_slices
,
scan_args
.
outer_in_mit_sot
):
sit_sot_init
=
[
seq
[
0
]
for
seq
in
op
.
outer_sitsot
(
outer_inputs
)]
neg_taps
=
[
abs
(
t
)
for
t
in
tap
if
t
<
0
]
pos_taps
=
[
abs
(
t
)
for
t
in
tap
if
t
>
0
]
max_neg
=
max
(
neg_taps
)
if
neg_taps
else
0
max_pos
=
max
(
pos_taps
)
if
pos_taps
else
0
init_slice
=
seq
[:
max_neg
+
max_pos
]
mit_sot_in_slices
.
append
(
init_slice
)
sit_sot_in_slices
=
[
seq
[
0
]
for
seq
in
scan_args
.
outer_in_sit_sot
]
init_carry
=
(
init_carry
=
(
mit_mot_in_slices
,
mit_sot_init
,
mit_sot_in_slices
,
sit_sot_init
,
sit_sot_in_slices
,
op
.
outer_shared
(
outer_inputs
),
scan_args
.
outer_in_shared
,
op
.
outer_non_seqs
(
outer_inputs
),
scan_args
.
outer_in_non_seqs
,
)
# JAX `init`
)
def
jax_args_to_inner_func_args
(
carry
,
x
):
"""Convert JAX scan arguments into format expected by scan_inner_func.
scan(carry, x) -> scan_inner_func(seqs, mit_sot, sit_sot, shared, non_seqs)
"""
def
jax_args_to_inner_scan
(
op
,
carry
,
x
):
# `carry` contains all inner taps, shared terms, and non_seqs
# `carry` contains all inner-output taps, non_seqs, and shared
# terms
(
(
inner_in_mit_mot
,
inner_mit_sot
,
inner_in_mit_sot
,
inner_sit_sot
,
inner_in_sit_sot
,
inner_shared
,
inner_in_shared
,
inner_non_seqs
,
inner_in_non_seqs
,
)
=
carry
)
=
carry
# `x` contains the in_seqs
# `x` contains the inner sequences
inner_in_seqs
=
x
inner_seqs
=
x
# `inner_scan_inputs` is a list with the following composite form:
mit_sot_flatten
=
[]
# inner_in_seqs
for
array
,
index
in
zip
(
inner_mit_sot
,
op
.
info
.
mit_sot_in_slices
):
# + sum(inner_in_mit_mot, [])
mit_sot_flatten
.
extend
(
array
[
jnp
.
array
(
index
)])
# + sum(inner_in_mit_sot, [])
# + inner_in_sit_sot
inner_scan_inputs
=
[
# + inner_in_shared
*
inner_seqs
,
# + inner_in_non_seqs
*
mit_sot_flatten
,
inner_in_mit_sot_flatten
=
[]
*
inner_sit_sot
,
for
array
,
index
in
zip
(
inner_in_mit_sot
,
scan_args
.
mit_sot_in_slices
):
*
inner_shared
,
inner_in_mit_sot_flatten
.
extend
(
array
[
jnp
.
array
(
index
)])
*
inner_non_seqs
,
]
inner_scan_inputs
=
sum
(
[
inner_in_seqs
,
inner_in_mit_mot
,
inner_in_mit_sot_flatten
,
inner_in_sit_sot
,
inner_in_shared
,
inner_in_non_seqs
,
],
[],
)
return
inner_scan_inputs
return
inner_scan_inputs
def
inner_scan_outs_to_jax_outs
(
def
inner_func_outs_to_jax_outs
(
op
,
old_carry
,
old_carry
,
inner_scan_outs
,
inner_scan_outs
,
):
):
"""Convert inner_scan_func outputs into format expected by JAX scan.
old_carry + (mit_sot_outs, sit_sot_outs, nit_sot_outs, shared_outs) -> (new_carry, ys)
"""
(
(
inner_in_mit_mot
,
inner_mit_sot
,
inner_in_mit_sot
,
inner_sit_sot
,
inner_in_sit_sot
,
inner_shared
,
inner_in_shared
,
inner_non_seqs
,
inner_in_non_seqs
,
)
=
old_carry
)
=
old_carry
def
update_mit_sot
(
mit_sot
,
new_val
):
inner_mit_sot_outs
=
op
.
inner_mitsot_outs
(
inner_scan_outs
)
return
jnp
.
concatenate
([
mit_sot
[
1
:],
new_val
[
None
,
...
]],
axis
=
0
)
inner_sit_sot_outs
=
op
.
inner_sitsot_outs
(
inner_scan_outs
)
inner_nit_sot_outs
=
op
.
inner_nitsot_outs
(
inner_scan_outs
)
inner_shared_outs
=
op
.
inner_shared_outs
(
inner_scan_outs
)
# Replace the oldest mit_sot tap by the newest value
inner_mit_sot_new
=
[
jnp
.
concatenate
([
old_mit_sot
[
1
:],
new_val
[
None
,
...
]],
axis
=
0
)
for
old_mit_sot
,
new_val
in
zip
(
inner_mit_sot
,
inner_mit_sot_outs
,
)
]
# Nothing needs to be done with sit_sot
inner_sit_sot_new
=
inner_sit_sot_outs
inner_shared_new
=
inner_shared
# Replace old shared inputs by new shared outputs
inner_shared_new
[:
len
(
inner_shared_outs
)]
=
inner_shared_outs
inner_out_mit_sot
=
[
new_carry
=
(
update_mit_sot
(
mit_sot
,
new_val
)
inner_mit_sot_new
,
for
mit_sot
,
new_val
in
zip
(
inner_in_mit_sot
,
inner_scan_outs
)
inner_sit_sot_new
,
inner_shared_new
,
inner_non_seqs
,
)
# Shared variables and non_seqs are not traced
traced_outs
=
[
*
inner_mit_sot_outs
,
*
inner_sit_sot_outs
,
*
inner_nit_sot_outs
,
]
]
# This should contain all inner-output taps, non_seqs, and shared
return
new_carry
,
traced_outs
# terms
if
not
inner_in_sit_sot
:
def
jax_inner_func
(
carry
,
x
):
inner_out_sit_sot
=
[]
inner_args
=
jax_args_to_inner_func_args
(
carry
,
x
)
inner_scan_outs
=
list
(
scan_inner_func
(
*
inner_args
))
new_carry
,
traced_outs
=
inner_func_outs_to_jax_outs
(
carry
,
inner_scan_outs
)
return
new_carry
,
traced_outs
# Extract PyTensor scan outputs
final_carry
,
traces
=
jax
.
lax
.
scan
(
jax_inner_func
,
init_carry
,
seqs
,
length
=
n_steps
)
def
get_partial_traces
(
traces
):
"""Convert JAX scan traces to PyTensor traces.
We need to:
1. Prepend initial states to JAX output traces
2. Slice final traces if Scan was instructed to only keep a portion
"""
init_states
=
mit_sot_init
+
sit_sot_init
+
[
None
]
*
op
.
info
.
n_nit_sot
buffers
=
(
op
.
outer_mitsot
(
outer_inputs
)
+
op
.
outer_sitsot
(
outer_inputs
)
+
op
.
outer_nitsot
(
outer_inputs
)
)
partial_traces
=
[]
for
init_state
,
trace
,
buffer
in
zip
(
init_states
,
traces
,
buffers
):
if
init_state
is
not
None
:
# MIT-SOT and SIT-SOT: The final output should be as long as the input buffer
full_trace
=
jnp
.
concatenate
(
[
jnp
.
atleast_1d
(
init_state
),
jnp
.
atleast_1d
(
trace
)],
axis
=
0
,
)
buffer_size
=
buffer
.
shape
[
0
]
else
:
else
:
inner_out_sit_sot
=
inner_scan_outs
# NIT-SOT: Buffer is just the number of entries that should be returned
new_carry
=
(
full_trace
=
jnp
.
atleast_1d
(
trace
)
inner_in_mit_mot
,
buffer_size
=
buffer
partial_trace
=
full_trace
[
-
buffer_size
:]
partial_traces
.
append
(
partial_trace
)
return
partial_traces
def
get_shared_outs
(
final_carry
):
"""Retrive last state of shared_outs from final_carry.
These outputs cannot be traced in PyTensor Scan
"""
(
inner_out_mit_sot
,
inner_out_mit_sot
,
inner_out_sit_sot
,
inner_out_sit_sot
,
inner_
in
_shared
,
inner_
out
_shared
,
inner_in_non_seqs
,
inner_in_non_seqs
,
)
)
=
final_carry
return
new_carry
shared_outs
=
inner_out_shared
[:
info
.
n_shared_outs
]
return
list
(
shared_outs
)
def
jax_inner_func
(
carry
,
x
):
scan_outs_final
=
get_partial_traces
(
traces
)
+
get_shared_outs
(
final_carry
)
inner_args
=
jax_args_to_inner_scan
(
op
,
carry
,
x
)
inner_scan_outs
=
list
(
jax_at_inner_func
(
*
inner_args
))
new_carry
=
inner_scan_outs_to_jax_outs
(
op
,
carry
,
inner_scan_outs
)
return
new_carry
,
inner_scan_outs
_
,
scan_out
=
jax
.
lax
.
scan
(
jax_inner_func
,
init_carry
,
seqs
,
length
=
n_steps
)
# We need to prepend the initial values so that the JAX output will
# match the raw `Scan` `Op` output and, thus, work with a downstream
# `Subtensor` `Op` introduced by the `scan` helper function.
def
append_scan_out
(
scan_in_part
,
scan_out_part
):
return
jnp
.
concatenate
([
scan_in_part
[:
-
n_steps
],
scan_out_part
],
axis
=
0
)
if
scan_args
.
outer_in_mit_sot
:
scan_out_final
=
[
append_scan_out
(
init
,
out
)
for
init
,
out
in
zip
(
scan_args
.
outer_in_mit_sot
,
scan_out
)
]
elif
scan_args
.
outer_in_sit_sot
:
scan_out_final
=
[
append_scan_out
(
init
,
out
)
for
init
,
out
in
zip
(
scan_args
.
outer_in_sit_sot
,
scan_out
)
]
if
len
(
scan_out_final
)
==
1
:
if
len
(
scan_out
s
_final
)
==
1
:
scan_out
_final
=
scan_out
_final
[
0
]
scan_out
s_final
=
scan_outs
_final
[
0
]
return
scan_out_final
return
scan_out
s
_final
return
scan
return
scan
tests/link/jax/test_scan.py
浏览文件 @
88cc33b6
import
re
import
numpy
as
np
import
numpy
as
np
import
pytest
import
pytest
from
packaging.version
import
parse
as
version_parse
import
pytensor.tensor
as
at
import
pytensor.tensor
as
at
from
pytensor
import
function
,
shared
from
pytensor.compile
import
get_mode
from
pytensor.configdefaults
import
config
from
pytensor.configdefaults
import
config
from
pytensor.graph.fg
import
FunctionGraph
from
pytensor.graph.fg
import
FunctionGraph
from
pytensor.scan
import
until
from
pytensor.scan.basic
import
scan
from
pytensor.scan.basic
import
scan
from
pytensor.scan.op
import
Scan
from
pytensor.tensor
import
random
from
pytensor.tensor.math
import
gammaln
,
log
from
pytensor.tensor.math
import
gammaln
,
log
from
pytensor.tensor.type
import
ivector
,
lscalar
,
scala
r
from
pytensor.tensor.type
import
lscalar
,
scalar
,
vecto
r
from
tests.link.jax.test_basic
import
compare_jax_and_py
from
tests.link.jax.test_basic
import
compare_jax_and_py
jax
=
pytest
.
importorskip
(
"jax"
)
jax
=
pytest
.
importorskip
(
"jax"
)
@pytest.mark.xfail
(
@pytest.mark.parametrize
(
"view"
,
[
None
,
(
-
1
,),
slice
(
-
2
,
None
,
None
)])
version_parse
(
jax
.
__version__
)
>=
version_parse
(
"0.2.12"
),
def
test_scan_sit_sot
(
view
):
reason
=
"Omnistaging cannot be disabled"
,
x0
=
at
.
scalar
(
"x0"
,
dtype
=
"float64"
)
)
xs
,
_
=
scan
(
def
test_jax_scan_multiple_output
():
lambda
xtm1
:
xtm1
+
1
,
outputs_info
=
[
x0
],
n_steps
=
10
,
)
if
view
:
xs
=
xs
[
view
]
fg
=
FunctionGraph
([
x0
],
[
xs
])
test_input_vals
=
[
np
.
e
]
compare_jax_and_py
(
fg
,
test_input_vals
)
@pytest.mark.parametrize
(
"view"
,
[
None
,
(
-
1
,),
slice
(
-
4
,
-
1
,
None
)])
def
test_scan_mit_sot
(
view
):
x0
=
at
.
vector
(
"x0"
,
dtype
=
"float64"
,
shape
=
(
3
,))
xs
,
_
=
scan
(
lambda
xtm3
,
xtm1
:
xtm3
+
xtm1
+
1
,
outputs_info
=
[{
"initial"
:
x0
,
"taps"
:
[
-
3
,
-
1
]}],
n_steps
=
10
,
)
if
view
:
xs
=
xs
[
view
]
fg
=
FunctionGraph
([
x0
],
[
xs
])
test_input_vals
=
[
np
.
full
((
3
,),
np
.
e
)]
compare_jax_and_py
(
fg
,
test_input_vals
)
@pytest.mark.parametrize
(
"view_x"
,
[
None
,
(
-
1
,),
slice
(
-
4
,
-
1
,
None
)])
@pytest.mark.parametrize
(
"view_y"
,
[
None
,
(
-
1
,),
slice
(
-
4
,
-
1
,
None
)])
def
test_scan_multiple_mit_sot
(
view_x
,
view_y
):
x0
=
at
.
vector
(
"x0"
,
dtype
=
"float64"
,
shape
=
(
3
,))
y0
=
at
.
vector
(
"y0"
,
dtype
=
"float64"
,
shape
=
(
4
,))
def
step
(
xtm3
,
xtm1
,
ytm4
,
ytm2
):
return
xtm3
+
ytm4
+
1
,
xtm1
+
ytm2
+
2
[
xs
,
ys
],
_
=
scan
(
fn
=
step
,
outputs_info
=
[
{
"initial"
:
x0
,
"taps"
:
[
-
3
,
-
1
]},
{
"initial"
:
y0
,
"taps"
:
[
-
4
,
-
2
]},
],
n_steps
=
10
,
)
if
view_x
:
xs
=
xs
[
view_x
]
if
view_y
:
ys
=
ys
[
view_y
]
fg
=
FunctionGraph
([
x0
,
y0
],
[
xs
,
ys
])
test_input_vals
=
[
np
.
full
((
3
,),
np
.
e
),
np
.
full
((
4
,),
np
.
pi
)]
compare_jax_and_py
(
fg
,
test_input_vals
)
@pytest.mark.parametrize
(
"view"
,
[
None
,
(
-
2
,),
slice
(
None
,
None
,
2
)])
def
test_scan_nit_sot
(
view
):
rng
=
np
.
random
.
default_rng
(
seed
=
49
)
xs
=
at
.
vector
(
"x0"
,
dtype
=
"float64"
,
shape
=
(
10
,))
ys
,
_
=
scan
(
lambda
x
:
at
.
exp
(
x
),
outputs_info
=
[
None
],
sequences
=
[
xs
],
)
if
view
:
ys
=
ys
[
view
]
fg
=
FunctionGraph
([
xs
],
[
ys
])
test_input_vals
=
[
rng
.
normal
(
size
=
10
)]
# We need to remove pushout rewrites, or the whole scan would just be
# converted to an Elemwise on xs
jax_fn
,
_
=
compare_jax_and_py
(
fg
,
test_input_vals
,
jax_mode
=
get_mode
(
"JAX"
)
.
excluding
(
"scan_pushout"
)
)
scan_nodes
=
[
node
for
node
in
jax_fn
.
maker
.
fgraph
.
apply_nodes
if
isinstance
(
node
.
op
,
Scan
)
]
assert
len
(
scan_nodes
)
==
1
@pytest.mark.xfail
(
raises
=
NotImplementedError
)
def
test_scan_mit_mot
():
xs
=
at
.
vector
(
"xs"
,
shape
=
(
10
,))
ys
,
_
=
scan
(
lambda
xtm2
,
xtm1
:
(
xtm2
+
xtm1
),
outputs_info
=
[{
"initial"
:
xs
,
"taps"
:
[
-
2
,
-
1
]}],
n_steps
=
10
,
)
grads_wrt_xs
=
at
.
grad
(
ys
.
sum
(),
wrt
=
xs
)
fg
=
FunctionGraph
([
xs
],
[
grads_wrt_xs
])
compare_jax_and_py
(
fg
,
[
np
.
arange
(
10
)])
def
test_scan_update
():
sh_static
=
shared
(
np
.
array
(
0.0
),
name
=
"sh_static"
)
sh_update
=
shared
(
np
.
array
(
1.0
),
name
=
"sh_update"
)
xs
,
update
=
scan
(
lambda
sh_static
,
sh_update
:
(
sh_static
+
sh_update
,
{
sh_update
:
sh_update
*
2
},
),
outputs_info
=
[
None
],
non_sequences
=
[
sh_static
,
sh_update
],
strict
=
True
,
n_steps
=
7
,
)
jax_fn
=
function
([],
xs
,
updates
=
update
,
mode
=
"JAX"
)
np
.
testing
.
assert_array_equal
(
jax_fn
(),
np
.
array
([
1
,
2
,
4
,
8
,
16
,
32
,
64
])
+
0.0
)
sh_static
.
set_value
(
1.0
)
np
.
testing
.
assert_array_equal
(
jax_fn
(),
np
.
array
([
128
,
256
,
512
,
1024
,
2048
,
4096
,
8192
])
+
1.0
)
sh_static
.
set_value
(
2.0
)
sh_update
.
set_value
(
1.0
)
np
.
testing
.
assert_array_equal
(
jax_fn
(),
np
.
array
([
1
,
2
,
4
,
8
,
16
,
32
,
64
])
+
2.0
)
def
test_scan_rng_update
():
rng
=
shared
(
np
.
random
.
default_rng
(
190
),
name
=
"rng"
)
def
update_fn
(
rng
):
new_rng
,
x
=
random
.
normal
(
rng
=
rng
)
.
owner
.
outputs
return
x
,
{
rng
:
new_rng
}
xs
,
update
=
scan
(
update_fn
,
outputs_info
=
[
None
],
non_sequences
=
[
rng
],
strict
=
True
,
n_steps
=
10
,
)
# Without updates
with
pytest
.
warns
(
UserWarning
,
match
=
re
.
escape
(
"[rng] will not be used in the compiled JAX graph"
),
):
jax_fn
=
function
([],
[
xs
],
updates
=
None
,
mode
=
"JAX"
)
res1
,
res2
=
jax_fn
(),
jax_fn
()
assert
np
.
unique
(
res1
)
.
size
==
10
assert
np
.
unique
(
res2
)
.
size
==
10
np
.
testing
.
assert_array_equal
(
res1
,
res2
)
# With updates
with
pytest
.
warns
(
UserWarning
,
match
=
re
.
escape
(
"[rng] will not be used in the compiled JAX graph"
),
):
jax_fn
=
function
([],
[
xs
],
updates
=
update
,
mode
=
"JAX"
)
res1
,
res2
=
jax_fn
(),
jax_fn
()
assert
np
.
unique
(
res1
)
.
size
==
10
assert
np
.
unique
(
res2
)
.
size
==
10
assert
np
.
all
(
np
.
not_equal
(
res1
,
res2
))
@pytest.mark.xfail
(
raises
=
NotImplementedError
)
def
test_scan_while
():
xs
,
_
=
scan
(
lambda
x
:
(
x
+
1
,
until
(
x
<
10
)),
outputs_info
=
[
at
.
zeros
(())],
n_steps
=
100
,
)
fg
=
FunctionGraph
([],
[
xs
])
compare_jax_and_py
(
fg
,
[])
def
test_scan_SEIR
():
"""Test a scan implementation of a SEIR model.
"""Test a scan implementation of a SEIR model.
SEIR model definition:
SEIR model definition:
...
@@ -38,8 +216,8 @@ def test_jax_scan_multiple_output():
...
@@ -38,8 +216,8 @@ def test_jax_scan_multiple_output():
return
binomln
(
n
,
value
)
+
value
*
log
(
p
)
+
(
n
-
value
)
*
log
(
1
-
p
)
return
binomln
(
n
,
value
)
+
value
*
log
(
p
)
+
(
n
-
value
)
*
log
(
1
-
p
)
# sequences
# sequences
at_C
=
ivector
(
"C_t"
)
at_C
=
vector
(
"C_t"
,
dtype
=
"int32"
,
shape
=
(
8
,)
)
at_D
=
ivector
(
"D_t"
)
at_D
=
vector
(
"D_t"
,
dtype
=
"int32"
,
shape
=
(
8
,)
)
# outputs_info (initial conditions)
# outputs_info (initial conditions)
st0
=
lscalar
(
"s_t0"
)
st0
=
lscalar
(
"s_t0"
)
et0
=
lscalar
(
"e_t0"
)
et0
=
lscalar
(
"e_t0"
)
...
@@ -108,11 +286,7 @@ def test_jax_scan_multiple_output():
...
@@ -108,11 +286,7 @@ def test_jax_scan_multiple_output():
compare_jax_and_py
(
out_fg
,
test_input_vals
)
compare_jax_and_py
(
out_fg
,
test_input_vals
)
@pytest.mark.xfail
(
def
test_scan_mitsot_with_nonseq
():
version_parse
(
jax
.
__version__
)
>=
version_parse
(
"0.2.12"
),
reason
=
"Omnistaging cannot be disabled"
,
)
def
test_jax_scan_tap_output
():
a_at
=
scalar
(
"a"
)
a_at
=
scalar
(
"a"
)
def
input_step_fn
(
y_tm1
,
y_tm3
,
a
):
def
input_step_fn
(
y_tm1
,
y_tm3
,
a
):
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论