Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
88cc33b6
提交
88cc33b6
authored
3月 02, 2023
作者:
Ricardo Vieira
提交者:
Ricardo Vieira
4月 04, 2023
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Fix Scan JAX dispatcher
上级
7b609047
全部展开
隐藏空白字符变更
内嵌
并排
正在显示
2 个修改的文件
包含
188 行增加
和
14 行删除
+188
-14
scan.py
pytensor/link/jax/dispatch/scan.py
+0
-0
test_scan.py
tests/link/jax/test_scan.py
+188
-14
没有找到文件。
pytensor/link/jax/dispatch/scan.py
浏览文件 @
88cc33b6
差异被折叠。
点击展开。
tests/link/jax/test_scan.py
浏览文件 @
88cc33b6
import
re
import
numpy
as
np
import
numpy
as
np
import
pytest
import
pytest
from
packaging.version
import
parse
as
version_parse
import
pytensor.tensor
as
at
import
pytensor.tensor
as
at
from
pytensor
import
function
,
shared
from
pytensor.compile
import
get_mode
from
pytensor.configdefaults
import
config
from
pytensor.configdefaults
import
config
from
pytensor.graph.fg
import
FunctionGraph
from
pytensor.graph.fg
import
FunctionGraph
from
pytensor.scan
import
until
from
pytensor.scan.basic
import
scan
from
pytensor.scan.basic
import
scan
from
pytensor.scan.op
import
Scan
from
pytensor.tensor
import
random
from
pytensor.tensor.math
import
gammaln
,
log
from
pytensor.tensor.math
import
gammaln
,
log
from
pytensor.tensor.type
import
ivector
,
lscalar
,
scala
r
from
pytensor.tensor.type
import
lscalar
,
scalar
,
vecto
r
from
tests.link.jax.test_basic
import
compare_jax_and_py
from
tests.link.jax.test_basic
import
compare_jax_and_py
jax
=
pytest
.
importorskip
(
"jax"
)
jax
=
pytest
.
importorskip
(
"jax"
)
@pytest.mark.xfail
(
@pytest.mark.parametrize
(
"view"
,
[
None
,
(
-
1
,),
slice
(
-
2
,
None
,
None
)])
version_parse
(
jax
.
__version__
)
>=
version_parse
(
"0.2.12"
),
def
test_scan_sit_sot
(
view
):
reason
=
"Omnistaging cannot be disabled"
,
x0
=
at
.
scalar
(
"x0"
,
dtype
=
"float64"
)
)
xs
,
_
=
scan
(
def
test_jax_scan_multiple_output
():
lambda
xtm1
:
xtm1
+
1
,
outputs_info
=
[
x0
],
n_steps
=
10
,
)
if
view
:
xs
=
xs
[
view
]
fg
=
FunctionGraph
([
x0
],
[
xs
])
test_input_vals
=
[
np
.
e
]
compare_jax_and_py
(
fg
,
test_input_vals
)
@pytest.mark.parametrize
(
"view"
,
[
None
,
(
-
1
,),
slice
(
-
4
,
-
1
,
None
)])
def
test_scan_mit_sot
(
view
):
x0
=
at
.
vector
(
"x0"
,
dtype
=
"float64"
,
shape
=
(
3
,))
xs
,
_
=
scan
(
lambda
xtm3
,
xtm1
:
xtm3
+
xtm1
+
1
,
outputs_info
=
[{
"initial"
:
x0
,
"taps"
:
[
-
3
,
-
1
]}],
n_steps
=
10
,
)
if
view
:
xs
=
xs
[
view
]
fg
=
FunctionGraph
([
x0
],
[
xs
])
test_input_vals
=
[
np
.
full
((
3
,),
np
.
e
)]
compare_jax_and_py
(
fg
,
test_input_vals
)
@pytest.mark.parametrize
(
"view_x"
,
[
None
,
(
-
1
,),
slice
(
-
4
,
-
1
,
None
)])
@pytest.mark.parametrize
(
"view_y"
,
[
None
,
(
-
1
,),
slice
(
-
4
,
-
1
,
None
)])
def
test_scan_multiple_mit_sot
(
view_x
,
view_y
):
x0
=
at
.
vector
(
"x0"
,
dtype
=
"float64"
,
shape
=
(
3
,))
y0
=
at
.
vector
(
"y0"
,
dtype
=
"float64"
,
shape
=
(
4
,))
def
step
(
xtm3
,
xtm1
,
ytm4
,
ytm2
):
return
xtm3
+
ytm4
+
1
,
xtm1
+
ytm2
+
2
[
xs
,
ys
],
_
=
scan
(
fn
=
step
,
outputs_info
=
[
{
"initial"
:
x0
,
"taps"
:
[
-
3
,
-
1
]},
{
"initial"
:
y0
,
"taps"
:
[
-
4
,
-
2
]},
],
n_steps
=
10
,
)
if
view_x
:
xs
=
xs
[
view_x
]
if
view_y
:
ys
=
ys
[
view_y
]
fg
=
FunctionGraph
([
x0
,
y0
],
[
xs
,
ys
])
test_input_vals
=
[
np
.
full
((
3
,),
np
.
e
),
np
.
full
((
4
,),
np
.
pi
)]
compare_jax_and_py
(
fg
,
test_input_vals
)
@pytest.mark.parametrize
(
"view"
,
[
None
,
(
-
2
,),
slice
(
None
,
None
,
2
)])
def
test_scan_nit_sot
(
view
):
rng
=
np
.
random
.
default_rng
(
seed
=
49
)
xs
=
at
.
vector
(
"x0"
,
dtype
=
"float64"
,
shape
=
(
10
,))
ys
,
_
=
scan
(
lambda
x
:
at
.
exp
(
x
),
outputs_info
=
[
None
],
sequences
=
[
xs
],
)
if
view
:
ys
=
ys
[
view
]
fg
=
FunctionGraph
([
xs
],
[
ys
])
test_input_vals
=
[
rng
.
normal
(
size
=
10
)]
# We need to remove pushout rewrites, or the whole scan would just be
# converted to an Elemwise on xs
jax_fn
,
_
=
compare_jax_and_py
(
fg
,
test_input_vals
,
jax_mode
=
get_mode
(
"JAX"
)
.
excluding
(
"scan_pushout"
)
)
scan_nodes
=
[
node
for
node
in
jax_fn
.
maker
.
fgraph
.
apply_nodes
if
isinstance
(
node
.
op
,
Scan
)
]
assert
len
(
scan_nodes
)
==
1
@pytest.mark.xfail
(
raises
=
NotImplementedError
)
def
test_scan_mit_mot
():
xs
=
at
.
vector
(
"xs"
,
shape
=
(
10
,))
ys
,
_
=
scan
(
lambda
xtm2
,
xtm1
:
(
xtm2
+
xtm1
),
outputs_info
=
[{
"initial"
:
xs
,
"taps"
:
[
-
2
,
-
1
]}],
n_steps
=
10
,
)
grads_wrt_xs
=
at
.
grad
(
ys
.
sum
(),
wrt
=
xs
)
fg
=
FunctionGraph
([
xs
],
[
grads_wrt_xs
])
compare_jax_and_py
(
fg
,
[
np
.
arange
(
10
)])
def
test_scan_update
():
sh_static
=
shared
(
np
.
array
(
0.0
),
name
=
"sh_static"
)
sh_update
=
shared
(
np
.
array
(
1.0
),
name
=
"sh_update"
)
xs
,
update
=
scan
(
lambda
sh_static
,
sh_update
:
(
sh_static
+
sh_update
,
{
sh_update
:
sh_update
*
2
},
),
outputs_info
=
[
None
],
non_sequences
=
[
sh_static
,
sh_update
],
strict
=
True
,
n_steps
=
7
,
)
jax_fn
=
function
([],
xs
,
updates
=
update
,
mode
=
"JAX"
)
np
.
testing
.
assert_array_equal
(
jax_fn
(),
np
.
array
([
1
,
2
,
4
,
8
,
16
,
32
,
64
])
+
0.0
)
sh_static
.
set_value
(
1.0
)
np
.
testing
.
assert_array_equal
(
jax_fn
(),
np
.
array
([
128
,
256
,
512
,
1024
,
2048
,
4096
,
8192
])
+
1.0
)
sh_static
.
set_value
(
2.0
)
sh_update
.
set_value
(
1.0
)
np
.
testing
.
assert_array_equal
(
jax_fn
(),
np
.
array
([
1
,
2
,
4
,
8
,
16
,
32
,
64
])
+
2.0
)
def
test_scan_rng_update
():
rng
=
shared
(
np
.
random
.
default_rng
(
190
),
name
=
"rng"
)
def
update_fn
(
rng
):
new_rng
,
x
=
random
.
normal
(
rng
=
rng
)
.
owner
.
outputs
return
x
,
{
rng
:
new_rng
}
xs
,
update
=
scan
(
update_fn
,
outputs_info
=
[
None
],
non_sequences
=
[
rng
],
strict
=
True
,
n_steps
=
10
,
)
# Without updates
with
pytest
.
warns
(
UserWarning
,
match
=
re
.
escape
(
"[rng] will not be used in the compiled JAX graph"
),
):
jax_fn
=
function
([],
[
xs
],
updates
=
None
,
mode
=
"JAX"
)
res1
,
res2
=
jax_fn
(),
jax_fn
()
assert
np
.
unique
(
res1
)
.
size
==
10
assert
np
.
unique
(
res2
)
.
size
==
10
np
.
testing
.
assert_array_equal
(
res1
,
res2
)
# With updates
with
pytest
.
warns
(
UserWarning
,
match
=
re
.
escape
(
"[rng] will not be used in the compiled JAX graph"
),
):
jax_fn
=
function
([],
[
xs
],
updates
=
update
,
mode
=
"JAX"
)
res1
,
res2
=
jax_fn
(),
jax_fn
()
assert
np
.
unique
(
res1
)
.
size
==
10
assert
np
.
unique
(
res2
)
.
size
==
10
assert
np
.
all
(
np
.
not_equal
(
res1
,
res2
))
@pytest.mark.xfail
(
raises
=
NotImplementedError
)
def
test_scan_while
():
xs
,
_
=
scan
(
lambda
x
:
(
x
+
1
,
until
(
x
<
10
)),
outputs_info
=
[
at
.
zeros
(())],
n_steps
=
100
,
)
fg
=
FunctionGraph
([],
[
xs
])
compare_jax_and_py
(
fg
,
[])
def
test_scan_SEIR
():
"""Test a scan implementation of a SEIR model.
"""Test a scan implementation of a SEIR model.
SEIR model definition:
SEIR model definition:
...
@@ -38,8 +216,8 @@ def test_jax_scan_multiple_output():
...
@@ -38,8 +216,8 @@ def test_jax_scan_multiple_output():
return
binomln
(
n
,
value
)
+
value
*
log
(
p
)
+
(
n
-
value
)
*
log
(
1
-
p
)
return
binomln
(
n
,
value
)
+
value
*
log
(
p
)
+
(
n
-
value
)
*
log
(
1
-
p
)
# sequences
# sequences
at_C
=
ivector
(
"C_t"
)
at_C
=
vector
(
"C_t"
,
dtype
=
"int32"
,
shape
=
(
8
,)
)
at_D
=
ivector
(
"D_t"
)
at_D
=
vector
(
"D_t"
,
dtype
=
"int32"
,
shape
=
(
8
,)
)
# outputs_info (initial conditions)
# outputs_info (initial conditions)
st0
=
lscalar
(
"s_t0"
)
st0
=
lscalar
(
"s_t0"
)
et0
=
lscalar
(
"e_t0"
)
et0
=
lscalar
(
"e_t0"
)
...
@@ -108,11 +286,7 @@ def test_jax_scan_multiple_output():
...
@@ -108,11 +286,7 @@ def test_jax_scan_multiple_output():
compare_jax_and_py
(
out_fg
,
test_input_vals
)
compare_jax_and_py
(
out_fg
,
test_input_vals
)
@pytest.mark.xfail
(
def
test_scan_mitsot_with_nonseq
():
version_parse
(
jax
.
__version__
)
>=
version_parse
(
"0.2.12"
),
reason
=
"Omnistaging cannot be disabled"
,
)
def
test_jax_scan_tap_output
():
a_at
=
scalar
(
"a"
)
a_at
=
scalar
(
"a"
)
def
input_step_fn
(
y_tm1
,
y_tm3
,
a
):
def
input_step_fn
(
y_tm1
,
y_tm3
,
a
):
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论