Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
cc8c4992
提交
cc8c4992
authored
11月 28, 2024
作者:
Adv
提交者:
Ricardo Vieira
2月 18, 2025
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Stop using FunctionGraph and tag.test_value in linker tests
Co-authored-by:
Adv
<
adhvaithhundi.221ds003@nitk.edu.in
>
上级
51ea1a0b
全部展开
隐藏空白字符变更
内嵌
并排
正在显示
41 个修改的文件
包含
388 行增加
和
543 行删除
+388
-543
test_basic.py
tests/link/jax/test_basic.py
+24
-27
test_blas.py
tests/link/jax/test_blas.py
+5
-8
test_blockwise.py
tests/link/jax/test_blockwise.py
+1
-3
test_einsum.py
tests/link/jax/test_einsum.py
+2
-5
test_elemwise.py
tests/link/jax/test_elemwise.py
+23
-33
test_extra_ops.py
tests/link/jax/test_extra_ops.py
+12
-28
test_math.py
tests/link/jax/test_math.py
+19
-15
test_nlinalg.py
tests/link/jax/test_nlinalg.py
+8
-17
test_pad.py
tests/link/jax/test_pad.py
+2
-3
test_random.py
tests/link/jax/test_random.py
+0
-0
test_scalar.py
tests/link/jax/test_scalar.py
+56
-70
test_scan.py
tests/link/jax/test_scan.py
+18
-32
test_shape.py
tests/link/jax/test_shape.py
+15
-24
test_slinalg.py
tests/link/jax/test_slinalg.py
+26
-31
test_sort.py
tests/link/jax/test_sort.py
+1
-3
test_sparse.py
tests/link/jax/test_sparse.py
+1
-3
test_subtensor.py
tests/link/jax/test_subtensor.py
+59
-59
test_tensor_basic.py
tests/link/jax/test_tensor_basic.py
+31
-43
test_basic.py
tests/link/numba/test_basic.py
+0
-0
test_blockwise.py
tests/link/numba/test_blockwise.py
+2
-1
test_elemwise.py
tests/link/numba/test_elemwise.py
+0
-0
test_extra_ops.py
tests/link/numba/test_extra_ops.py
+0
-0
test_nlinalg.py
tests/link/numba/test_nlinalg.py
+50
-90
test_pad.py
tests/link/numba/test_pad.py
+2
-3
test_random.py
tests/link/numba/test_random.py
+0
-0
test_scalar.py
tests/link/numba/test_scalar.py
+31
-45
test_scan.py
tests/link/numba/test_scan.py
+0
-0
test_slinalg.py
tests/link/numba/test_slinalg.py
+0
-0
test_sparse.py
tests/link/numba/test_sparse.py
+0
-0
test_subtensor.py
tests/link/numba/test_subtensor.py
+0
-0
test_tensor_basic.py
tests/link/numba/test_tensor_basic.py
+0
-0
test_basic.py
tests/link/pytorch/test_basic.py
+0
-0
test_blas.py
tests/link/pytorch/test_blas.py
+0
-0
test_elemwise.py
tests/link/pytorch/test_elemwise.py
+0
-0
test_extra_ops.py
tests/link/pytorch/test_extra_ops.py
+0
-0
test_math.py
tests/link/pytorch/test_math.py
+0
-0
test_nlinalg.py
tests/link/pytorch/test_nlinalg.py
+0
-0
test_shape.py
tests/link/pytorch/test_shape.py
+0
-0
test_sort.py
tests/link/pytorch/test_sort.py
+0
-0
test_subtensor.py
tests/link/pytorch/test_subtensor.py
+0
-0
test_extra_ops.py
tests/tensor/test_extra_ops.py
+0
-0
没有找到文件。
tests/link/jax/test_basic.py
浏览文件 @
cc8c4992
...
@@ -7,12 +7,12 @@ import pytest
...
@@ -7,12 +7,12 @@ import pytest
from
pytensor.compile.builders
import
OpFromGraph
from
pytensor.compile.builders
import
OpFromGraph
from
pytensor.compile.function
import
function
from
pytensor.compile.function
import
function
from
pytensor.compile.mode
import
JAX
,
Mode
from
pytensor.compile.mode
import
JAX
,
Mode
from
pytensor.compile.sharedvalue
import
SharedVariable
,
shared
from
pytensor.compile.sharedvalue
import
shared
from
pytensor.configdefaults
import
config
from
pytensor.configdefaults
import
config
from
pytensor.graph
import
RewriteDatabaseQuery
from
pytensor.graph
import
RewriteDatabaseQuery
from
pytensor.graph.basic
import
Apply
from
pytensor.graph.basic
import
Apply
,
Variable
from
pytensor.graph.fg
import
FunctionGraph
from
pytensor.graph.fg
import
FunctionGraph
from
pytensor.graph.op
import
Op
,
get_test_value
from
pytensor.graph.op
import
Op
from
pytensor.ifelse
import
ifelse
from
pytensor.ifelse
import
ifelse
from
pytensor.link.jax
import
JAXLinker
from
pytensor.link.jax
import
JAXLinker
from
pytensor.raise_op
import
assert_op
from
pytensor.raise_op
import
assert_op
...
@@ -34,25 +34,28 @@ py_mode = Mode(linker="py", optimizer=None)
...
@@ -34,25 +34,28 @@ py_mode = Mode(linker="py", optimizer=None)
def
compare_jax_and_py
(
def
compare_jax_and_py
(
fgraph
:
FunctionGraph
,
graph_inputs
:
Iterable
[
Variable
],
graph_outputs
:
Variable
|
Iterable
[
Variable
],
test_inputs
:
Iterable
,
test_inputs
:
Iterable
,
*
,
assert_fn
:
Callable
|
None
=
None
,
assert_fn
:
Callable
|
None
=
None
,
must_be_device_array
:
bool
=
True
,
must_be_device_array
:
bool
=
True
,
jax_mode
=
jax_mode
,
jax_mode
=
jax_mode
,
py_mode
=
py_mode
,
py_mode
=
py_mode
,
):
):
"""Function to compare python
graph
output and jax compiled output for testing equality
"""Function to compare python
function
output and jax compiled output for testing equality
In the tests below computational graphs are defined in PyTensor. These graphs are then passed to
The inputs and outputs are then passed to this function which then compiles the given function in both
this function which then compiles the graphs in both jax and python, runs the calculation
jax and python, runs the calculation in both and checks if the results are the same
in both and checks if the results are the same
Parameters
Parameters
----------
----------
fgraph: FunctionGraph
graph_inputs:
PyTensor function Graph object
Symbolic inputs to the graph
outputs:
Symbolic outputs of the graph
test_inputs: iter
test_inputs: iter
Numerical inputs for testing the function
graph
Numerical inputs for testing the function
.
assert_fn: func, opt
assert_fn: func, opt
Assert function used to check for equality between python and jax. If not
Assert function used to check for equality between python and jax. If not
provided uses np.testing.assert_allclose
provided uses np.testing.assert_allclose
...
@@ -68,8 +71,10 @@ def compare_jax_and_py(
...
@@ -68,8 +71,10 @@ def compare_jax_and_py(
if
assert_fn
is
None
:
if
assert_fn
is
None
:
assert_fn
=
partial
(
np
.
testing
.
assert_allclose
,
rtol
=
1e-4
)
assert_fn
=
partial
(
np
.
testing
.
assert_allclose
,
rtol
=
1e-4
)
fn_inputs
=
[
i
for
i
in
fgraph
.
inputs
if
not
isinstance
(
i
,
SharedVariable
)]
if
any
(
inp
.
owner
is
not
None
for
inp
in
graph_inputs
):
pytensor_jax_fn
=
function
(
fn_inputs
,
fgraph
.
outputs
,
mode
=
jax_mode
)
raise
ValueError
(
"Inputs must be root variables"
)
pytensor_jax_fn
=
function
(
graph_inputs
,
graph_outputs
,
mode
=
jax_mode
)
jax_res
=
pytensor_jax_fn
(
*
test_inputs
)
jax_res
=
pytensor_jax_fn
(
*
test_inputs
)
if
must_be_device_array
:
if
must_be_device_array
:
...
@@ -78,10 +83,10 @@ def compare_jax_and_py(
...
@@ -78,10 +83,10 @@ def compare_jax_and_py(
else
:
else
:
assert
isinstance
(
jax_res
,
jax
.
Array
)
assert
isinstance
(
jax_res
,
jax
.
Array
)
pytensor_py_fn
=
function
(
fn_inputs
,
fgraph
.
outputs
,
mode
=
py_mode
)
pytensor_py_fn
=
function
(
graph_inputs
,
graph_
outputs
,
mode
=
py_mode
)
py_res
=
pytensor_py_fn
(
*
test_inputs
)
py_res
=
pytensor_py_fn
(
*
test_inputs
)
if
len
(
fgraph
.
outputs
)
>
1
:
if
isinstance
(
graph_outputs
,
list
|
tuple
)
:
for
j
,
p
in
zip
(
jax_res
,
py_res
,
strict
=
True
):
for
j
,
p
in
zip
(
jax_res
,
py_res
,
strict
=
True
):
assert_fn
(
j
,
p
)
assert_fn
(
j
,
p
)
else
:
else
:
...
@@ -187,16 +192,14 @@ def test_jax_ifelse():
...
@@ -187,16 +192,14 @@ def test_jax_ifelse():
false_vals
=
np
.
r_
[
-
1
,
-
2
,
-
3
]
false_vals
=
np
.
r_
[
-
1
,
-
2
,
-
3
]
x
=
ifelse
(
np
.
array
(
True
),
true_vals
,
false_vals
)
x
=
ifelse
(
np
.
array
(
True
),
true_vals
,
false_vals
)
x_fg
=
FunctionGraph
([],
[
x
])
compare_jax_and_py
(
x_fg
,
[])
compare_jax_and_py
(
[],
[
x
]
,
[])
a
=
dscalar
(
"a"
)
a
=
dscalar
(
"a"
)
a
.
tag
.
test_value
=
np
.
array
(
0.2
,
dtype
=
config
.
floatX
)
a
_test
=
np
.
array
(
0.2
,
dtype
=
config
.
floatX
)
x
=
ifelse
(
a
<
0.5
,
true_vals
,
false_vals
)
x
=
ifelse
(
a
<
0.5
,
true_vals
,
false_vals
)
x_fg
=
FunctionGraph
([
a
],
[
x
])
# I.e. False
compare_jax_and_py
(
x_fg
,
[
get_test_value
(
i
)
for
i
in
x_fg
.
inputs
])
compare_jax_and_py
(
[
a
],
[
x
],
[
a_test
])
def
test_jax_checkandraise
():
def
test_jax_checkandraise
():
...
@@ -209,11 +212,6 @@ def test_jax_checkandraise():
...
@@ -209,11 +212,6 @@ def test_jax_checkandraise():
function
((
p
,),
res
,
mode
=
jax_mode
)
function
((
p
,),
res
,
mode
=
jax_mode
)
def
set_test_value
(
x
,
v
):
x
.
tag
.
test_value
=
v
return
x
def
test_OpFromGraph
():
def
test_OpFromGraph
():
x
,
y
,
z
=
matrices
(
"xyz"
)
x
,
y
,
z
=
matrices
(
"xyz"
)
ofg_1
=
OpFromGraph
([
x
,
y
],
[
x
+
y
],
inline
=
False
)
ofg_1
=
OpFromGraph
([
x
,
y
],
[
x
+
y
],
inline
=
False
)
...
@@ -221,10 +219,9 @@ def test_OpFromGraph():
...
@@ -221,10 +219,9 @@ def test_OpFromGraph():
o1
,
o2
=
ofg_2
(
y
,
z
)
o1
,
o2
=
ofg_2
(
y
,
z
)
out
=
ofg_1
(
x
,
o1
)
+
o2
out
=
ofg_1
(
x
,
o1
)
+
o2
out_fg
=
FunctionGraph
([
x
,
y
,
z
],
[
out
])
xv
=
np
.
ones
((
2
,
2
),
dtype
=
config
.
floatX
)
xv
=
np
.
ones
((
2
,
2
),
dtype
=
config
.
floatX
)
yv
=
np
.
ones
((
2
,
2
),
dtype
=
config
.
floatX
)
*
3
yv
=
np
.
ones
((
2
,
2
),
dtype
=
config
.
floatX
)
*
3
zv
=
np
.
ones
((
2
,
2
),
dtype
=
config
.
floatX
)
*
5
zv
=
np
.
ones
((
2
,
2
),
dtype
=
config
.
floatX
)
*
5
compare_jax_and_py
(
out_fg
,
[
xv
,
yv
,
zv
])
compare_jax_and_py
(
[
x
,
y
,
z
],
[
out
]
,
[
xv
,
yv
,
zv
])
tests/link/jax/test_blas.py
浏览文件 @
cc8c4992
...
@@ -4,8 +4,6 @@ import pytest
...
@@ -4,8 +4,6 @@ import pytest
from
pytensor.compile.function
import
function
from
pytensor.compile.function
import
function
from
pytensor.compile.mode
import
Mode
from
pytensor.compile.mode
import
Mode
from
pytensor.configdefaults
import
config
from
pytensor.configdefaults
import
config
from
pytensor.graph.fg
import
FunctionGraph
from
pytensor.graph.op
import
get_test_value
from
pytensor.graph.rewriting.db
import
RewriteDatabaseQuery
from
pytensor.graph.rewriting.db
import
RewriteDatabaseQuery
from
pytensor.link.jax
import
JAXLinker
from
pytensor.link.jax
import
JAXLinker
from
pytensor.tensor
import
blas
as
pt_blas
from
pytensor.tensor
import
blas
as
pt_blas
...
@@ -16,21 +14,20 @@ from tests.link.jax.test_basic import compare_jax_and_py
...
@@ -16,21 +14,20 @@ from tests.link.jax.test_basic import compare_jax_and_py
def
test_jax_BatchedDot
():
def
test_jax_BatchedDot
():
# tensor3 . tensor3
# tensor3 . tensor3
a
=
tensor3
(
"a"
)
a
=
tensor3
(
"a"
)
a
.
tag
.
test_value
=
(
a
_
test_value
=
(
np
.
linspace
(
-
1
,
1
,
10
*
5
*
3
)
.
astype
(
config
.
floatX
)
.
reshape
((
10
,
5
,
3
))
np
.
linspace
(
-
1
,
1
,
10
*
5
*
3
)
.
astype
(
config
.
floatX
)
.
reshape
((
10
,
5
,
3
))
)
)
b
=
tensor3
(
"b"
)
b
=
tensor3
(
"b"
)
b
.
tag
.
test_value
=
(
b
_
test_value
=
(
np
.
linspace
(
1
,
-
1
,
10
*
3
*
2
)
.
astype
(
config
.
floatX
)
.
reshape
((
10
,
3
,
2
))
np
.
linspace
(
1
,
-
1
,
10
*
3
*
2
)
.
astype
(
config
.
floatX
)
.
reshape
((
10
,
3
,
2
))
)
)
out
=
pt_blas
.
BatchedDot
()(
a
,
b
)
out
=
pt_blas
.
BatchedDot
()(
a
,
b
)
fgraph
=
FunctionGraph
([
a
,
b
],
[
out
])
compare_jax_and_py
([
a
,
b
],
[
out
],
[
a_test_value
,
b_test_value
])
compare_jax_and_py
(
fgraph
,
[
get_test_value
(
i
)
for
i
in
fgraph
.
inputs
])
# A dimension mismatch should raise a TypeError for compatibility
# A dimension mismatch should raise a TypeError for compatibility
inputs
=
[
get_test_value
(
a
)[:
-
1
],
get_test_value
(
b
)
]
inputs
=
[
a_test_value
[:
-
1
],
b_test_value
]
opts
=
RewriteDatabaseQuery
(
include
=
[
None
],
exclude
=
[
"cxx_only"
,
"BlasOpt"
])
opts
=
RewriteDatabaseQuery
(
include
=
[
None
],
exclude
=
[
"cxx_only"
,
"BlasOpt"
])
jax_mode
=
Mode
(
JAXLinker
(),
opts
)
jax_mode
=
Mode
(
JAXLinker
(),
opts
)
pytensor_jax_fn
=
function
(
fgraph
.
inputs
,
fgraph
.
outputs
,
mode
=
jax_mode
)
pytensor_jax_fn
=
function
(
[
a
,
b
],
[
out
]
,
mode
=
jax_mode
)
with
pytest
.
raises
(
TypeError
):
with
pytest
.
raises
(
TypeError
):
pytensor_jax_fn
(
*
inputs
)
pytensor_jax_fn
(
*
inputs
)
tests/link/jax/test_blockwise.py
浏览文件 @
cc8c4992
...
@@ -2,7 +2,6 @@ import numpy as np
...
@@ -2,7 +2,6 @@ import numpy as np
import
pytest
import
pytest
from
pytensor
import
config
from
pytensor
import
config
from
pytensor.graph
import
FunctionGraph
from
pytensor.tensor
import
tensor
from
pytensor.tensor
import
tensor
from
pytensor.tensor.blockwise
import
Blockwise
from
pytensor.tensor.blockwise
import
Blockwise
from
pytensor.tensor.math
import
Dot
,
matmul
from
pytensor.tensor.math
import
Dot
,
matmul
...
@@ -32,8 +31,7 @@ def test_matmul(matmul_op):
...
@@ -32,8 +31,7 @@ def test_matmul(matmul_op):
out
=
matmul_op
(
a
,
b
)
out
=
matmul_op
(
a
,
b
)
assert
isinstance
(
out
.
owner
.
op
,
Blockwise
)
assert
isinstance
(
out
.
owner
.
op
,
Blockwise
)
fg
=
FunctionGraph
([
a
,
b
],
[
out
])
fn
,
_
=
compare_jax_and_py
([
a
,
b
],
[
out
],
test_values
)
fn
,
_
=
compare_jax_and_py
(
fg
,
test_values
)
# Check we are not adding any unnecessary stuff
# Check we are not adding any unnecessary stuff
jaxpr
=
str
(
jax
.
make_jaxpr
(
fn
.
vm
.
jit_fn
)(
*
test_values
))
jaxpr
=
str
(
jax
.
make_jaxpr
(
fn
.
vm
.
jit_fn
)(
*
test_values
))
...
...
tests/link/jax/test_einsum.py
浏览文件 @
cc8c4992
...
@@ -2,7 +2,6 @@ import numpy as np
...
@@ -2,7 +2,6 @@ import numpy as np
import
pytest
import
pytest
import
pytensor.tensor
as
pt
import
pytensor.tensor
as
pt
from
pytensor.graph
import
FunctionGraph
from
tests.link.jax.test_basic
import
compare_jax_and_py
from
tests.link.jax.test_basic
import
compare_jax_and_py
...
@@ -22,8 +21,7 @@ def test_jax_einsum():
...
@@ -22,8 +21,7 @@ def test_jax_einsum():
}
}
x_pt
,
y_pt
,
z_pt
=
(
pt
.
tensor
(
name
,
shape
=
shape
)
for
name
,
shape
in
shapes
.
items
())
x_pt
,
y_pt
,
z_pt
=
(
pt
.
tensor
(
name
,
shape
=
shape
)
for
name
,
shape
in
shapes
.
items
())
out
=
pt
.
einsum
(
subscripts
,
x_pt
,
y_pt
,
z_pt
)
out
=
pt
.
einsum
(
subscripts
,
x_pt
,
y_pt
,
z_pt
)
fg
=
FunctionGraph
([
x_pt
,
y_pt
,
z_pt
],
[
out
])
compare_jax_and_py
([
x_pt
,
y_pt
,
z_pt
],
[
out
],
[
x
,
y
,
z
])
compare_jax_and_py
(
fg
,
[
x
,
y
,
z
])
def
test_ellipsis_einsum
():
def
test_ellipsis_einsum
():
...
@@ -34,5 +32,4 @@ def test_ellipsis_einsum():
...
@@ -34,5 +32,4 @@ def test_ellipsis_einsum():
x_pt
=
pt
.
tensor
(
"x"
,
shape
=
x
.
shape
)
x_pt
=
pt
.
tensor
(
"x"
,
shape
=
x
.
shape
)
y_pt
=
pt
.
tensor
(
"y"
,
shape
=
y
.
shape
)
y_pt
=
pt
.
tensor
(
"y"
,
shape
=
y
.
shape
)
out
=
pt
.
einsum
(
subscripts
,
x_pt
,
y_pt
)
out
=
pt
.
einsum
(
subscripts
,
x_pt
,
y_pt
)
fg
=
FunctionGraph
([
x_pt
,
y_pt
],
[
out
])
compare_jax_and_py
([
x_pt
,
y_pt
],
[
out
],
[
x
,
y
])
compare_jax_and_py
(
fg
,
[
x
,
y
])
tests/link/jax/test_elemwise.py
浏览文件 @
cc8c4992
...
@@ -6,8 +6,6 @@ import pytensor
...
@@ -6,8 +6,6 @@ import pytensor
import
pytensor.tensor
as
pt
import
pytensor.tensor
as
pt
from
pytensor.compile
import
get_mode
from
pytensor.compile
import
get_mode
from
pytensor.configdefaults
import
config
from
pytensor.configdefaults
import
config
from
pytensor.graph.fg
import
FunctionGraph
from
pytensor.graph.op
import
get_test_value
from
pytensor.tensor
import
elemwise
as
pt_elemwise
from
pytensor.tensor
import
elemwise
as
pt_elemwise
from
pytensor.tensor.math
import
all
as
pt_all
from
pytensor.tensor.math
import
all
as
pt_all
from
pytensor.tensor.math
import
prod
from
pytensor.tensor.math
import
prod
...
@@ -26,22 +24,22 @@ def test_jax_Dimshuffle():
...
@@ -26,22 +24,22 @@ def test_jax_Dimshuffle():
a_pt
=
matrix
(
"a"
)
a_pt
=
matrix
(
"a"
)
x
=
a_pt
.
T
x
=
a_pt
.
T
x_fg
=
FunctionGraph
([
a_pt
],
[
x
])
compare_jax_and_py
(
compare_jax_and_py
(
x_fg
,
[
np
.
c_
[[
1.0
,
2.0
],
[
3.0
,
4.0
]]
.
astype
(
config
.
floatX
)])
[
a_pt
],
[
x
],
[
np
.
c_
[[
1.0
,
2.0
],
[
3.0
,
4.0
]]
.
astype
(
config
.
floatX
)]
)
x
=
a_pt
.
dimshuffle
([
0
,
1
,
"x"
])
x
=
a_pt
.
dimshuffle
([
0
,
1
,
"x"
])
x_fg
=
FunctionGraph
([
a_pt
],
[
x
])
compare_jax_and_py
(
compare_jax_and_py
(
x_fg
,
[
np
.
c_
[[
1.0
,
2.0
],
[
3.0
,
4.0
]]
.
astype
(
config
.
floatX
)])
[
a_pt
],
[
x
],
[
np
.
c_
[[
1.0
,
2.0
],
[
3.0
,
4.0
]]
.
astype
(
config
.
floatX
)]
)
a_pt
=
tensor
(
dtype
=
config
.
floatX
,
shape
=
(
None
,
1
))
a_pt
=
tensor
(
dtype
=
config
.
floatX
,
shape
=
(
None
,
1
))
x
=
a_pt
.
dimshuffle
((
0
,))
x
=
a_pt
.
dimshuffle
((
0
,))
x_fg
=
FunctionGraph
([
a_pt
],
[
x
])
compare_jax_and_py
([
a_pt
],
[
x
],
[
np
.
c_
[[
1.0
,
2.0
,
3.0
,
4.0
]]
.
astype
(
config
.
floatX
)])
compare_jax_and_py
(
x_fg
,
[
np
.
c_
[[
1.0
,
2.0
,
3.0
,
4.0
]]
.
astype
(
config
.
floatX
)])
a_pt
=
tensor
(
dtype
=
config
.
floatX
,
shape
=
(
None
,
1
))
a_pt
=
tensor
(
dtype
=
config
.
floatX
,
shape
=
(
None
,
1
))
x
=
pt_elemwise
.
DimShuffle
(
input_ndim
=
2
,
new_order
=
(
0
,))(
a_pt
)
x
=
pt_elemwise
.
DimShuffle
(
input_ndim
=
2
,
new_order
=
(
0
,))(
a_pt
)
x_fg
=
FunctionGraph
([
a_pt
],
[
x
])
compare_jax_and_py
([
a_pt
],
[
x
],
[
np
.
c_
[[
1.0
,
2.0
,
3.0
,
4.0
]]
.
astype
(
config
.
floatX
)])
compare_jax_and_py
(
x_fg
,
[
np
.
c_
[[
1.0
,
2.0
,
3.0
,
4.0
]]
.
astype
(
config
.
floatX
)])
def
test_jax_CAReduce
():
def
test_jax_CAReduce
():
...
@@ -49,64 +47,58 @@ def test_jax_CAReduce():
...
@@ -49,64 +47,58 @@ def test_jax_CAReduce():
a_pt
.
tag
.
test_value
=
np
.
r_
[
1
,
2
,
3
]
.
astype
(
config
.
floatX
)
a_pt
.
tag
.
test_value
=
np
.
r_
[
1
,
2
,
3
]
.
astype
(
config
.
floatX
)
x
=
pt_sum
(
a_pt
,
axis
=
None
)
x
=
pt_sum
(
a_pt
,
axis
=
None
)
x_fg
=
FunctionGraph
([
a_pt
],
[
x
])
compare_jax_and_py
(
x_fg
,
[
np
.
r_
[
1
,
2
,
3
]
.
astype
(
config
.
floatX
)])
compare_jax_and_py
(
[
a_pt
],
[
x
]
,
[
np
.
r_
[
1
,
2
,
3
]
.
astype
(
config
.
floatX
)])
a_pt
=
matrix
(
"a"
)
a_pt
=
matrix
(
"a"
)
a_pt
.
tag
.
test_value
=
np
.
c_
[[
1
,
2
,
3
],
[
1
,
2
,
3
]]
.
astype
(
config
.
floatX
)
a_pt
.
tag
.
test_value
=
np
.
c_
[[
1
,
2
,
3
],
[
1
,
2
,
3
]]
.
astype
(
config
.
floatX
)
x
=
pt_sum
(
a_pt
,
axis
=
0
)
x
=
pt_sum
(
a_pt
,
axis
=
0
)
x_fg
=
FunctionGraph
([
a_pt
],
[
x
])
compare_jax_and_py
(
x_fg
,
[
np
.
c_
[[
1
,
2
,
3
],
[
1
,
2
,
3
]]
.
astype
(
config
.
floatX
)])
compare_jax_and_py
(
[
a_pt
],
[
x
]
,
[
np
.
c_
[[
1
,
2
,
3
],
[
1
,
2
,
3
]]
.
astype
(
config
.
floatX
)])
x
=
pt_sum
(
a_pt
,
axis
=
1
)
x
=
pt_sum
(
a_pt
,
axis
=
1
)
x_fg
=
FunctionGraph
([
a_pt
],
[
x
])
compare_jax_and_py
(
x_fg
,
[
np
.
c_
[[
1
,
2
,
3
],
[
1
,
2
,
3
]]
.
astype
(
config
.
floatX
)])
compare_jax_and_py
(
[
a_pt
],
[
x
]
,
[
np
.
c_
[[
1
,
2
,
3
],
[
1
,
2
,
3
]]
.
astype
(
config
.
floatX
)])
a_pt
=
matrix
(
"a"
)
a_pt
=
matrix
(
"a"
)
a_pt
.
tag
.
test_value
=
np
.
c_
[[
1
,
2
,
3
],
[
1
,
2
,
3
]]
.
astype
(
config
.
floatX
)
a_pt
.
tag
.
test_value
=
np
.
c_
[[
1
,
2
,
3
],
[
1
,
2
,
3
]]
.
astype
(
config
.
floatX
)
x
=
prod
(
a_pt
,
axis
=
0
)
x
=
prod
(
a_pt
,
axis
=
0
)
x_fg
=
FunctionGraph
([
a_pt
],
[
x
])
compare_jax_and_py
(
x_fg
,
[
np
.
c_
[[
1
,
2
,
3
],
[
1
,
2
,
3
]]
.
astype
(
config
.
floatX
)])
compare_jax_and_py
(
[
a_pt
],
[
x
]
,
[
np
.
c_
[[
1
,
2
,
3
],
[
1
,
2
,
3
]]
.
astype
(
config
.
floatX
)])
x
=
pt_all
(
a_pt
)
x
=
pt_all
(
a_pt
)
x_fg
=
FunctionGraph
([
a_pt
],
[
x
])
compare_jax_and_py
(
x_fg
,
[
np
.
c_
[[
1
,
2
,
3
],
[
1
,
2
,
3
]]
.
astype
(
config
.
floatX
)])
compare_jax_and_py
(
[
a_pt
],
[
x
]
,
[
np
.
c_
[[
1
,
2
,
3
],
[
1
,
2
,
3
]]
.
astype
(
config
.
floatX
)])
@pytest.mark.parametrize
(
"axis"
,
[
None
,
0
,
1
])
@pytest.mark.parametrize
(
"axis"
,
[
None
,
0
,
1
])
def
test_softmax
(
axis
):
def
test_softmax
(
axis
):
x
=
matrix
(
"x"
)
x
=
matrix
(
"x"
)
x
.
tag
.
test_value
=
np
.
arange
(
6
,
dtype
=
config
.
floatX
)
.
reshape
(
2
,
3
)
x
_
test_value
=
np
.
arange
(
6
,
dtype
=
config
.
floatX
)
.
reshape
(
2
,
3
)
out
=
softmax
(
x
,
axis
=
axis
)
out
=
softmax
(
x
,
axis
=
axis
)
fgraph
=
FunctionGraph
([
x
],
[
out
])
compare_jax_and_py
([
x
],
[
out
],
[
x_test_value
])
compare_jax_and_py
(
fgraph
,
[
get_test_value
(
i
)
for
i
in
fgraph
.
inputs
])
@pytest.mark.parametrize
(
"axis"
,
[
None
,
0
,
1
])
@pytest.mark.parametrize
(
"axis"
,
[
None
,
0
,
1
])
def
test_logsoftmax
(
axis
):
def
test_logsoftmax
(
axis
):
x
=
matrix
(
"x"
)
x
=
matrix
(
"x"
)
x
.
tag
.
test_value
=
np
.
arange
(
6
,
dtype
=
config
.
floatX
)
.
reshape
(
2
,
3
)
x
_
test_value
=
np
.
arange
(
6
,
dtype
=
config
.
floatX
)
.
reshape
(
2
,
3
)
out
=
log_softmax
(
x
,
axis
=
axis
)
out
=
log_softmax
(
x
,
axis
=
axis
)
fgraph
=
FunctionGraph
([
x
],
[
out
])
compare_jax_and_py
(
fgraph
,
[
get_test_value
(
i
)
for
i
in
fgraph
.
inputs
])
compare_jax_and_py
(
[
x
],
[
out
],
[
x_test_value
])
@pytest.mark.parametrize
(
"axis"
,
[
None
,
0
,
1
])
@pytest.mark.parametrize
(
"axis"
,
[
None
,
0
,
1
])
def
test_softmax_grad
(
axis
):
def
test_softmax_grad
(
axis
):
dy
=
matrix
(
"dy"
)
dy
=
matrix
(
"dy"
)
dy
.
tag
.
test_value
=
np
.
array
([[
1
,
1
,
1
],
[
0
,
0
,
0
]],
dtype
=
config
.
floatX
)
dy
_
test_value
=
np
.
array
([[
1
,
1
,
1
],
[
0
,
0
,
0
]],
dtype
=
config
.
floatX
)
sm
=
matrix
(
"sm"
)
sm
=
matrix
(
"sm"
)
sm
.
tag
.
test_value
=
np
.
arange
(
6
,
dtype
=
config
.
floatX
)
.
reshape
(
2
,
3
)
sm
_
test_value
=
np
.
arange
(
6
,
dtype
=
config
.
floatX
)
.
reshape
(
2
,
3
)
out
=
SoftmaxGrad
(
axis
=
axis
)(
dy
,
sm
)
out
=
SoftmaxGrad
(
axis
=
axis
)(
dy
,
sm
)
fgraph
=
FunctionGraph
([
dy
,
sm
],
[
out
])
compare_jax_and_py
(
fgraph
,
[
get_test_value
(
i
)
for
i
in
fgraph
.
inputs
])
compare_jax_and_py
(
[
dy
,
sm
],
[
out
],
[
dy_test_value
,
sm_test_value
])
@pytest.mark.parametrize
(
"size"
,
[(
10
,
10
),
(
1000
,
1000
)])
@pytest.mark.parametrize
(
"size"
,
[(
10
,
10
),
(
1000
,
1000
)])
...
@@ -134,6 +126,4 @@ def test_logsumexp_benchmark(size, axis, benchmark):
...
@@ -134,6 +126,4 @@ def test_logsumexp_benchmark(size, axis, benchmark):
def
test_multiple_input_multiply
():
def
test_multiple_input_multiply
():
x
,
y
,
z
=
vectors
(
"xyz"
)
x
,
y
,
z
=
vectors
(
"xyz"
)
out
=
pt
.
mul
(
x
,
y
,
z
)
out
=
pt
.
mul
(
x
,
y
,
z
)
compare_jax_and_py
([
x
,
y
,
z
],
[
out
],
test_inputs
=
[[
1.5
],
[
2.5
],
[
3.5
]])
fg
=
FunctionGraph
(
outputs
=
[
out
],
clone
=
False
)
compare_jax_and_py
(
fg
,
[[
1.5
],
[
2.5
],
[
3.5
]])
tests/link/jax/test_extra_ops.py
浏览文件 @
cc8c4992
...
@@ -3,8 +3,6 @@ import pytest
...
@@ -3,8 +3,6 @@ import pytest
import
pytensor.tensor.basic
as
ptb
import
pytensor.tensor.basic
as
ptb
from
pytensor.configdefaults
import
config
from
pytensor.configdefaults
import
config
from
pytensor.graph.fg
import
FunctionGraph
from
pytensor.graph.op
import
get_test_value
from
pytensor.tensor
import
extra_ops
as
pt_extra_ops
from
pytensor.tensor
import
extra_ops
as
pt_extra_ops
from
pytensor.tensor.sort
import
argsort
from
pytensor.tensor.sort
import
argsort
from
pytensor.tensor.type
import
matrix
,
tensor
from
pytensor.tensor.type
import
matrix
,
tensor
...
@@ -19,57 +17,45 @@ def test_extra_ops():
...
@@ -19,57 +17,45 @@ def test_extra_ops():
a_test
=
np
.
arange
(
6
,
dtype
=
config
.
floatX
)
.
reshape
((
3
,
2
))
a_test
=
np
.
arange
(
6
,
dtype
=
config
.
floatX
)
.
reshape
((
3
,
2
))
out
=
pt_extra_ops
.
cumsum
(
a
,
axis
=
0
)
out
=
pt_extra_ops
.
cumsum
(
a
,
axis
=
0
)
fgraph
=
FunctionGraph
([
a
],
[
out
])
compare_jax_and_py
([
a
],
[
out
],
[
a_test
])
compare_jax_and_py
(
fgraph
,
[
a_test
])
out
=
pt_extra_ops
.
cumprod
(
a
,
axis
=
1
)
out
=
pt_extra_ops
.
cumprod
(
a
,
axis
=
1
)
fgraph
=
FunctionGraph
([
a
],
[
out
])
compare_jax_and_py
([
a
],
[
out
],
[
a_test
])
compare_jax_and_py
(
fgraph
,
[
a_test
])
out
=
pt_extra_ops
.
diff
(
a
,
n
=
2
,
axis
=
1
)
out
=
pt_extra_ops
.
diff
(
a
,
n
=
2
,
axis
=
1
)
fgraph
=
FunctionGraph
([
a
],
[
out
])
compare_jax_and_py
([
a
],
[
out
],
[
a_test
])
compare_jax_and_py
(
fgraph
,
[
a_test
])
out
=
pt_extra_ops
.
repeat
(
a
,
(
3
,
3
),
axis
=
1
)
out
=
pt_extra_ops
.
repeat
(
a
,
(
3
,
3
),
axis
=
1
)
fgraph
=
FunctionGraph
([
a
],
[
out
])
compare_jax_and_py
([
a
],
[
out
],
[
a_test
])
compare_jax_and_py
(
fgraph
,
[
a_test
])
c
=
ptb
.
as_tensor
(
5
)
c
=
ptb
.
as_tensor
(
5
)
out
=
pt_extra_ops
.
fill_diagonal
(
a
,
c
)
out
=
pt_extra_ops
.
fill_diagonal
(
a
,
c
)
fgraph
=
FunctionGraph
([
a
],
[
out
])
compare_jax_and_py
([
a
],
[
out
],
[
a_test
])
compare_jax_and_py
(
fgraph
,
[
a_test
])
with
pytest
.
raises
(
NotImplementedError
):
with
pytest
.
raises
(
NotImplementedError
):
out
=
pt_extra_ops
.
fill_diagonal_offset
(
a
,
c
,
c
)
out
=
pt_extra_ops
.
fill_diagonal_offset
(
a
,
c
,
c
)
fgraph
=
FunctionGraph
([
a
],
[
out
])
compare_jax_and_py
([
a
],
[
out
],
[
a_test
])
compare_jax_and_py
(
fgraph
,
[
a_test
])
with
pytest
.
raises
(
NotImplementedError
):
with
pytest
.
raises
(
NotImplementedError
):
out
=
pt_extra_ops
.
Unique
(
axis
=
1
)(
a
)
out
=
pt_extra_ops
.
Unique
(
axis
=
1
)(
a
)
fgraph
=
FunctionGraph
([
a
],
[
out
])
compare_jax_and_py
([
a
],
[
out
],
[
a_test
])
compare_jax_and_py
(
fgraph
,
[
a_test
])
indices
=
np
.
arange
(
np
.
prod
((
3
,
4
)))
indices
=
np
.
arange
(
np
.
prod
((
3
,
4
)))
out
=
pt_extra_ops
.
unravel_index
(
indices
,
(
3
,
4
),
order
=
"C"
)
out
=
pt_extra_ops
.
unravel_index
(
indices
,
(
3
,
4
),
order
=
"C"
)
fgraph
=
FunctionGraph
([],
out
)
compare_jax_and_py
([],
out
,
[],
must_be_device_array
=
False
)
compare_jax_and_py
(
fgraph
,
[
get_test_value
(
i
)
for
i
in
fgraph
.
inputs
],
must_be_device_array
=
False
)
v
=
ptb
.
as_tensor_variable
(
6.0
)
v
=
ptb
.
as_tensor_variable
(
6.0
)
sorted_idx
=
argsort
(
a
.
ravel
())
sorted_idx
=
argsort
(
a
.
ravel
())
out
=
pt_extra_ops
.
searchsorted
(
a
.
ravel
()[
sorted_idx
],
v
)
out
=
pt_extra_ops
.
searchsorted
(
a
.
ravel
()[
sorted_idx
],
v
)
fgraph
=
FunctionGraph
([
a
],
[
out
])
compare_jax_and_py
([
a
],
[
out
],
[
a_test
])
compare_jax_and_py
(
fgraph
,
[
a_test
])
@pytest.mark.xfail
(
reason
=
"Jitted JAX does not support dynamic shapes"
)
@pytest.mark.xfail
(
reason
=
"Jitted JAX does not support dynamic shapes"
)
def
test_bartlett_dynamic_shape
():
def
test_bartlett_dynamic_shape
():
c
=
tensor
(
shape
=
(),
dtype
=
int
)
c
=
tensor
(
shape
=
(),
dtype
=
int
)
out
=
pt_extra_ops
.
bartlett
(
c
)
out
=
pt_extra_ops
.
bartlett
(
c
)
fgraph
=
FunctionGraph
([],
[
out
])
compare_jax_and_py
([],
[
out
],
[
np
.
array
(
5
)])
compare_jax_and_py
(
fgraph
,
[
np
.
array
(
5
)])
@pytest.mark.xfail
(
reason
=
"Jitted JAX does not support dynamic shapes"
)
@pytest.mark.xfail
(
reason
=
"Jitted JAX does not support dynamic shapes"
)
...
@@ -79,8 +65,7 @@ def test_ravel_multi_index_dynamic_shape():
...
@@ -79,8 +65,7 @@ def test_ravel_multi_index_dynamic_shape():
x
=
tensor
(
shape
=
(
None
,),
dtype
=
int
)
x
=
tensor
(
shape
=
(
None
,),
dtype
=
int
)
y
=
tensor
(
shape
=
(
None
,),
dtype
=
int
)
y
=
tensor
(
shape
=
(
None
,),
dtype
=
int
)
out
=
pt_extra_ops
.
ravel_multi_index
((
x
,
y
),
(
3
,
4
))
out
=
pt_extra_ops
.
ravel_multi_index
((
x
,
y
),
(
3
,
4
))
fgraph
=
FunctionGraph
([],
[
out
])
compare_jax_and_py
([],
[
out
],
[
x_test
,
y_test
])
compare_jax_and_py
(
fgraph
,
[
x_test
,
y_test
])
@pytest.mark.xfail
(
reason
=
"Jitted JAX does not support dynamic shapes"
)
@pytest.mark.xfail
(
reason
=
"Jitted JAX does not support dynamic shapes"
)
...
@@ -89,5 +74,4 @@ def test_unique_dynamic_shape():
...
@@ -89,5 +74,4 @@ def test_unique_dynamic_shape():
a_test
=
np
.
arange
(
6
,
dtype
=
config
.
floatX
)
.
reshape
((
3
,
2
))
a_test
=
np
.
arange
(
6
,
dtype
=
config
.
floatX
)
.
reshape
((
3
,
2
))
out
=
pt_extra_ops
.
Unique
()(
a
)
out
=
pt_extra_ops
.
Unique
()(
a
)
fgraph
=
FunctionGraph
([
a
],
[
out
])
compare_jax_and_py
([
a
],
[
out
],
[
a_test
])
compare_jax_and_py
(
fgraph
,
[
a_test
])
tests/link/jax/test_math.py
浏览文件 @
cc8c4992
...
@@ -2,8 +2,6 @@ import numpy as np
...
@@ -2,8 +2,6 @@ import numpy as np
import
pytest
import
pytest
from
pytensor.configdefaults
import
config
from
pytensor.configdefaults
import
config
from
pytensor.graph.fg
import
FunctionGraph
from
pytensor.graph.op
import
get_test_value
from
pytensor.tensor.math
import
Argmax
,
Max
,
maximum
from
pytensor.tensor.math
import
Argmax
,
Max
,
maximum
from
pytensor.tensor.math
import
max
as
pt_max
from
pytensor.tensor.math
import
max
as
pt_max
from
pytensor.tensor.type
import
dvector
,
matrix
,
scalar
,
vector
from
pytensor.tensor.type
import
dvector
,
matrix
,
scalar
,
vector
...
@@ -20,33 +18,39 @@ def test_jax_max_and_argmax():
...
@@ -20,33 +18,39 @@ def test_jax_max_and_argmax():
mx
=
Max
([
0
])(
x
)
mx
=
Max
([
0
])(
x
)
amx
=
Argmax
([
0
])(
x
)
amx
=
Argmax
([
0
])(
x
)
out
=
mx
*
amx
out
=
mx
*
amx
out_fg
=
FunctionGraph
([
x
],
[
out
])
compare_jax_and_py
([
x
],
[
out
],
[
np
.
r_
[
1
,
2
]])
compare_jax_and_py
(
out_fg
,
[
np
.
r_
[
1
,
2
]])
def
test_dot
():
def
test_dot
():
y
=
vector
(
"y"
)
y
=
vector
(
"y"
)
y
.
tag
.
test_value
=
np
.
r_
[
1.0
,
2.0
]
.
astype
(
config
.
floatX
)
y
_
test_value
=
np
.
r_
[
1.0
,
2.0
]
.
astype
(
config
.
floatX
)
x
=
vector
(
"x"
)
x
=
vector
(
"x"
)
x
.
tag
.
test_value
=
np
.
r_
[
3.0
,
4.0
]
.
astype
(
config
.
floatX
)
x
_
test_value
=
np
.
r_
[
3.0
,
4.0
]
.
astype
(
config
.
floatX
)
A
=
matrix
(
"A"
)
A
=
matrix
(
"A"
)
A
.
tag
.
test_value
=
np
.
empty
((
2
,
2
),
dtype
=
config
.
floatX
)
A
_
test_value
=
np
.
empty
((
2
,
2
),
dtype
=
config
.
floatX
)
alpha
=
scalar
(
"alpha"
)
alpha
=
scalar
(
"alpha"
)
alpha
.
tag
.
test_value
=
np
.
array
(
3.0
,
dtype
=
config
.
floatX
)
alpha
_
test_value
=
np
.
array
(
3.0
,
dtype
=
config
.
floatX
)
beta
=
scalar
(
"beta"
)
beta
=
scalar
(
"beta"
)
beta
.
tag
.
test_value
=
np
.
array
(
5.0
,
dtype
=
config
.
floatX
)
beta
_
test_value
=
np
.
array
(
5.0
,
dtype
=
config
.
floatX
)
# This should be converted into a `Gemv` `Op` when the non-JAX compatible
# This should be converted into a `Gemv` `Op` when the non-JAX compatible
# optimizations are turned on; however, when using JAX mode, it should
# optimizations are turned on; however, when using JAX mode, it should
# leave the expression alone.
# leave the expression alone.
out
=
y
.
dot
(
alpha
*
A
)
.
dot
(
x
)
+
beta
*
y
out
=
y
.
dot
(
alpha
*
A
)
.
dot
(
x
)
+
beta
*
y
fgraph
=
FunctionGraph
([
y
,
x
,
A
,
alpha
,
beta
],
[
out
])
compare_jax_and_py
(
compare_jax_and_py
(
fgraph
,
[
get_test_value
(
i
)
for
i
in
fgraph
.
inputs
])
[
y
,
x
,
A
,
alpha
,
beta
],
out
,
[
y_test_value
,
x_test_value
,
A_test_value
,
alpha_test_value
,
beta_test_value
,
],
)
out
=
maximum
(
y
,
x
)
out
=
maximum
(
y
,
x
)
fgraph
=
FunctionGraph
([
y
,
x
],
[
out
])
compare_jax_and_py
([
y
,
x
],
[
out
],
[
y_test_value
,
x_test_value
])
compare_jax_and_py
(
fgraph
,
[
get_test_value
(
i
)
for
i
in
fgraph
.
inputs
])
out
=
pt_max
(
y
)
out
=
pt_max
(
y
)
fgraph
=
FunctionGraph
([
y
],
[
out
])
compare_jax_and_py
([
y
],
[
out
],
[
y_test_value
])
compare_jax_and_py
(
fgraph
,
[
get_test_value
(
i
)
for
i
in
fgraph
.
inputs
])
tests/link/jax/test_nlinalg.py
浏览文件 @
cc8c4992
...
@@ -3,7 +3,6 @@ import pytest
...
@@ -3,7 +3,6 @@ import pytest
from
pytensor.compile.function
import
function
from
pytensor.compile.function
import
function
from
pytensor.configdefaults
import
config
from
pytensor.configdefaults
import
config
from
pytensor.graph.fg
import
FunctionGraph
from
pytensor.tensor
import
nlinalg
as
pt_nlinalg
from
pytensor.tensor
import
nlinalg
as
pt_nlinalg
from
pytensor.tensor.type
import
matrix
from
pytensor.tensor.type
import
matrix
from
tests.link.jax.test_basic
import
compare_jax_and_py
from
tests.link.jax.test_basic
import
compare_jax_and_py
...
@@ -21,41 +20,34 @@ def test_jax_basic_multiout():
...
@@ -21,41 +20,34 @@ def test_jax_basic_multiout():
x
=
matrix
(
"x"
)
x
=
matrix
(
"x"
)
outs
=
pt_nlinalg
.
eig
(
x
)
outs
=
pt_nlinalg
.
eig
(
x
)
out_fg
=
FunctionGraph
([
x
],
outs
)
def
assert_fn
(
x
,
y
):
def
assert_fn
(
x
,
y
):
np
.
testing
.
assert_allclose
(
x
.
astype
(
config
.
floatX
),
y
,
rtol
=
1e-3
)
np
.
testing
.
assert_allclose
(
x
.
astype
(
config
.
floatX
),
y
,
rtol
=
1e-3
)
compare_jax_and_py
(
out_fg
,
[
X
.
astype
(
config
.
floatX
)],
assert_fn
=
assert_fn
)
compare_jax_and_py
(
[
x
],
outs
,
[
X
.
astype
(
config
.
floatX
)],
assert_fn
=
assert_fn
)
outs
=
pt_nlinalg
.
eigh
(
x
)
outs
=
pt_nlinalg
.
eigh
(
x
)
out_fg
=
FunctionGraph
([
x
],
outs
)
compare_jax_and_py
([
x
],
outs
,
[
X
.
astype
(
config
.
floatX
)],
assert_fn
=
assert_fn
)
compare_jax_and_py
(
out_fg
,
[
X
.
astype
(
config
.
floatX
)],
assert_fn
=
assert_fn
)
outs
=
pt_nlinalg
.
qr
(
x
,
mode
=
"full"
)
outs
=
pt_nlinalg
.
qr
(
x
,
mode
=
"full"
)
out_fg
=
FunctionGraph
([
x
],
outs
)
compare_jax_and_py
([
x
],
outs
,
[
X
.
astype
(
config
.
floatX
)],
assert_fn
=
assert_fn
)
compare_jax_and_py
(
out_fg
,
[
X
.
astype
(
config
.
floatX
)],
assert_fn
=
assert_fn
)
outs
=
pt_nlinalg
.
qr
(
x
,
mode
=
"reduced"
)
outs
=
pt_nlinalg
.
qr
(
x
,
mode
=
"reduced"
)
out_fg
=
FunctionGraph
([
x
],
outs
)
compare_jax_and_py
([
x
],
outs
,
[
X
.
astype
(
config
.
floatX
)],
assert_fn
=
assert_fn
)
compare_jax_and_py
(
out_fg
,
[
X
.
astype
(
config
.
floatX
)],
assert_fn
=
assert_fn
)
outs
=
pt_nlinalg
.
svd
(
x
)
outs
=
pt_nlinalg
.
svd
(
x
)
out_fg
=
FunctionGraph
([
x
],
outs
)
compare_jax_and_py
([
x
],
outs
,
[
X
.
astype
(
config
.
floatX
)],
assert_fn
=
assert_fn
)
compare_jax_and_py
(
out_fg
,
[
X
.
astype
(
config
.
floatX
)],
assert_fn
=
assert_fn
)
outs
=
pt_nlinalg
.
slogdet
(
x
)
outs
=
pt_nlinalg
.
slogdet
(
x
)
out_fg
=
FunctionGraph
([
x
],
outs
)
compare_jax_and_py
([
x
],
outs
,
[
X
.
astype
(
config
.
floatX
)],
assert_fn
=
assert_fn
)
compare_jax_and_py
(
out_fg
,
[
X
.
astype
(
config
.
floatX
)],
assert_fn
=
assert_fn
)
def
test_pinv
():
def
test_pinv
():
x
=
matrix
(
"x"
)
x
=
matrix
(
"x"
)
x_inv
=
pt_nlinalg
.
pinv
(
x
)
x_inv
=
pt_nlinalg
.
pinv
(
x
)
fgraph
=
FunctionGraph
([
x
],
[
x_inv
])
x_np
=
np
.
array
([[
1.0
,
2.0
],
[
3.0
,
4.0
]],
dtype
=
config
.
floatX
)
x_np
=
np
.
array
([[
1.0
,
2.0
],
[
3.0
,
4.0
]],
dtype
=
config
.
floatX
)
compare_jax_and_py
(
fgraph
,
[
x_np
])
compare_jax_and_py
(
[
x
],
[
x_inv
]
,
[
x_np
])
def
test_pinv_hermitian
():
def
test_pinv_hermitian
():
...
@@ -94,8 +86,7 @@ def test_kron():
...
@@ -94,8 +86,7 @@ def test_kron():
y
=
matrix
(
"y"
)
y
=
matrix
(
"y"
)
z
=
pt_nlinalg
.
kron
(
x
,
y
)
z
=
pt_nlinalg
.
kron
(
x
,
y
)
fgraph
=
FunctionGraph
([
x
,
y
],
[
z
])
x_np
=
np
.
array
([[
1.0
,
2.0
],
[
3.0
,
4.0
]],
dtype
=
config
.
floatX
)
x_np
=
np
.
array
([[
1.0
,
2.0
],
[
3.0
,
4.0
]],
dtype
=
config
.
floatX
)
y_np
=
np
.
array
([[
1.0
,
2.0
],
[
3.0
,
4.0
]],
dtype
=
config
.
floatX
)
y_np
=
np
.
array
([[
1.0
,
2.0
],
[
3.0
,
4.0
]],
dtype
=
config
.
floatX
)
compare_jax_and_py
(
fgraph
,
[
x_np
,
y_np
])
compare_jax_and_py
(
[
x
,
y
],
[
z
]
,
[
x_np
,
y_np
])
tests/link/jax/test_pad.py
浏览文件 @
cc8c4992
...
@@ -3,7 +3,6 @@ import pytest
...
@@ -3,7 +3,6 @@ import pytest
import
pytensor.tensor
as
pt
import
pytensor.tensor
as
pt
from
pytensor
import
config
from
pytensor
import
config
from
pytensor.graph
import
FunctionGraph
from
pytensor.tensor.pad
import
PadMode
from
pytensor.tensor.pad
import
PadMode
from
tests.link.jax.test_basic
import
compare_jax_and_py
from
tests.link.jax.test_basic
import
compare_jax_and_py
...
@@ -53,10 +52,10 @@ def test_jax_pad(mode: PadMode, kwargs):
...
@@ -53,10 +52,10 @@ def test_jax_pad(mode: PadMode, kwargs):
x
=
np
.
random
.
normal
(
size
=
(
3
,
3
))
x
=
np
.
random
.
normal
(
size
=
(
3
,
3
))
res
=
pt
.
pad
(
x_pt
,
mode
=
mode
,
pad_width
=
3
,
**
kwargs
)
res
=
pt
.
pad
(
x_pt
,
mode
=
mode
,
pad_width
=
3
,
**
kwargs
)
res_fg
=
FunctionGraph
([
x_pt
],
[
res
])
compare_jax_and_py
(
compare_jax_and_py
(
res_fg
,
[
x_pt
],
[
res
],
[
x
],
[
x
],
assert_fn
=
lambda
x
,
y
:
np
.
testing
.
assert_allclose
(
x
,
y
,
rtol
=
RTOL
,
atol
=
ATOL
),
assert_fn
=
lambda
x
,
y
:
np
.
testing
.
assert_allclose
(
x
,
y
,
rtol
=
RTOL
,
atol
=
ATOL
),
py_mode
=
"FAST_RUN"
,
py_mode
=
"FAST_RUN"
,
...
...
tests/link/jax/test_random.py
浏览文件 @
cc8c4992
差异被折叠。
点击展开。
tests/link/jax/test_scalar.py
浏览文件 @
cc8c4992
...
@@ -5,7 +5,6 @@ import pytensor.scalar.basic as ps
...
@@ -5,7 +5,6 @@ import pytensor.scalar.basic as ps
import
pytensor.tensor
as
pt
import
pytensor.tensor
as
pt
from
pytensor.configdefaults
import
config
from
pytensor.configdefaults
import
config
from
pytensor.graph.fg
import
FunctionGraph
from
pytensor.graph.fg
import
FunctionGraph
from
pytensor.graph.op
import
get_test_value
from
pytensor.scalar.basic
import
Composite
from
pytensor.scalar.basic
import
Composite
from
pytensor.tensor
import
as_tensor
from
pytensor.tensor
import
as_tensor
from
pytensor.tensor.elemwise
import
Elemwise
from
pytensor.tensor.elemwise
import
Elemwise
...
@@ -51,20 +50,19 @@ def test_second():
...
@@ -51,20 +50,19 @@ def test_second():
b
=
scalar
(
"b"
)
b
=
scalar
(
"b"
)
out
=
ps
.
second
(
a0
,
b
)
out
=
ps
.
second
(
a0
,
b
)
fgraph
=
FunctionGraph
([
a0
,
b
],
[
out
])
compare_jax_and_py
([
a0
,
b
],
[
out
],
[
10.0
,
5.0
])
compare_jax_and_py
(
fgraph
,
[
10.0
,
5.0
])
a1
=
vector
(
"a1"
)
a1
=
vector
(
"a1"
)
out
=
pt
.
second
(
a1
,
b
)
out
=
pt
.
second
(
a1
,
b
)
fgraph
=
FunctionGraph
([
a1
,
b
],
[
out
])
compare_jax_and_py
([
a1
,
b
],
[
out
],
[
np
.
zeros
([
5
],
dtype
=
config
.
floatX
),
5.0
])
compare_jax_and_py
(
fgraph
,
[
np
.
zeros
([
5
],
dtype
=
config
.
floatX
),
5.0
])
a2
=
matrix
(
"a2"
,
shape
=
(
1
,
None
),
dtype
=
"float64"
)
a2
=
matrix
(
"a2"
,
shape
=
(
1
,
None
),
dtype
=
"float64"
)
b2
=
matrix
(
"b2"
,
shape
=
(
None
,
1
),
dtype
=
"int32"
)
b2
=
matrix
(
"b2"
,
shape
=
(
None
,
1
),
dtype
=
"int32"
)
out
=
pt
.
second
(
a2
,
b2
)
out
=
pt
.
second
(
a2
,
b2
)
fgraph
=
FunctionGraph
([
a2
,
b2
],
[
out
])
compare_jax_and_py
(
compare_jax_and_py
(
fgraph
,
[
np
.
zeros
((
1
,
3
),
dtype
=
"float64"
),
np
.
ones
((
5
,
1
),
dtype
=
"int32"
)]
[
a2
,
b2
],
[
out
],
[
np
.
zeros
((
1
,
3
),
dtype
=
"float64"
),
np
.
ones
((
5
,
1
),
dtype
=
"int32"
)],
)
)
...
@@ -81,11 +79,10 @@ def test_second_constant_scalar():
...
@@ -81,11 +79,10 @@ def test_second_constant_scalar():
def
test_identity
():
def
test_identity
():
a
=
scalar
(
"a"
)
a
=
scalar
(
"a"
)
a
.
tag
.
test_value
=
10
a
_
test_value
=
10
out
=
ps
.
identity
(
a
)
out
=
ps
.
identity
(
a
)
fgraph
=
FunctionGraph
([
a
],
[
out
])
compare_jax_and_py
([
a
],
[
out
],
[
a_test_value
])
compare_jax_and_py
(
fgraph
,
[
get_test_value
(
i
)
for
i
in
fgraph
.
inputs
])
@pytest.mark.parametrize
(
@pytest.mark.parametrize
(
...
@@ -109,13 +106,11 @@ def test_jax_Composite_singe_output(x, y, x_val, y_val):
...
@@ -109,13 +106,11 @@ def test_jax_Composite_singe_output(x, y, x_val, y_val):
out
=
comp_op
(
x
,
y
)
out
=
comp_op
(
x
,
y
)
out_fg
=
FunctionGraph
([
x
,
y
],
[
out
])
test_input_vals
=
[
test_input_vals
=
[
x_val
.
astype
(
config
.
floatX
),
x_val
.
astype
(
config
.
floatX
),
y_val
.
astype
(
config
.
floatX
),
y_val
.
astype
(
config
.
floatX
),
]
]
_
=
compare_jax_and_py
(
out_fg
,
test_input_vals
)
_
=
compare_jax_and_py
(
[
x
,
y
],
[
out
]
,
test_input_vals
)
def
test_jax_Composite_multi_output
():
def
test_jax_Composite_multi_output
():
...
@@ -124,32 +119,28 @@ def test_jax_Composite_multi_output():
...
@@ -124,32 +119,28 @@ def test_jax_Composite_multi_output():
x_s
=
ps
.
float64
(
"xs"
)
x_s
=
ps
.
float64
(
"xs"
)
outs
=
Elemwise
(
Composite
(
inputs
=
[
x_s
],
outputs
=
[
x_s
+
1
,
x_s
-
1
]))(
x
)
outs
=
Elemwise
(
Composite
(
inputs
=
[
x_s
],
outputs
=
[
x_s
+
1
,
x_s
-
1
]))(
x
)
fgraph
=
FunctionGraph
([
x
],
outs
)
compare_jax_and_py
([
x
],
outs
,
[
np
.
arange
(
10
,
dtype
=
config
.
floatX
)])
compare_jax_and_py
(
fgraph
,
[
np
.
arange
(
10
,
dtype
=
config
.
floatX
)])
def
test_erf
():
def
test_erf
():
x
=
scalar
(
"x"
)
x
=
scalar
(
"x"
)
out
=
erf
(
x
)
out
=
erf
(
x
)
fg
=
FunctionGraph
([
x
],
[
out
])
compare_jax_and_py
(
fg
,
[
1.0
])
compare_jax_and_py
(
[
x
],
[
out
]
,
[
1.0
])
def
test_erfc
():
def
test_erfc
():
x
=
scalar
(
"x"
)
x
=
scalar
(
"x"
)
out
=
erfc
(
x
)
out
=
erfc
(
x
)
fg
=
FunctionGraph
([
x
],
[
out
])
compare_jax_and_py
(
fg
,
[
1.0
])
compare_jax_and_py
(
[
x
],
[
out
]
,
[
1.0
])
def
test_erfinv
():
def
test_erfinv
():
x
=
scalar
(
"x"
)
x
=
scalar
(
"x"
)
out
=
erfinv
(
x
)
out
=
erfinv
(
x
)
fg
=
FunctionGraph
([
x
],
[
out
])
compare_jax_and_py
(
fg
,
[
0.95
])
compare_jax_and_py
(
[
x
],
[
out
]
,
[
0.95
])
@pytest.mark.parametrize
(
@pytest.mark.parametrize
(
...
@@ -166,8 +157,7 @@ def test_tfp_ops(op, test_values):
...
@@ -166,8 +157,7 @@ def test_tfp_ops(op, test_values):
inputs
=
[
as_tensor
(
test_value
)
.
type
()
for
test_value
in
test_values
]
inputs
=
[
as_tensor
(
test_value
)
.
type
()
for
test_value
in
test_values
]
output
=
op
(
*
inputs
)
output
=
op
(
*
inputs
)
fg
=
FunctionGraph
(
inputs
,
[
output
])
compare_jax_and_py
(
inputs
,
[
output
],
test_values
)
compare_jax_and_py
(
fg
,
test_values
)
def
test_betaincinv
():
def
test_betaincinv
():
...
@@ -175,9 +165,10 @@ def test_betaincinv():
...
@@ -175,9 +165,10 @@ def test_betaincinv():
b
=
vector
(
"b"
,
dtype
=
"float64"
)
b
=
vector
(
"b"
,
dtype
=
"float64"
)
x
=
vector
(
"x"
,
dtype
=
"float64"
)
x
=
vector
(
"x"
,
dtype
=
"float64"
)
out
=
betaincinv
(
a
,
b
,
x
)
out
=
betaincinv
(
a
,
b
,
x
)
fg
=
FunctionGraph
([
a
,
b
,
x
],
[
out
])
compare_jax_and_py
(
compare_jax_and_py
(
fg
,
[
a
,
b
,
x
],
[
out
],
[
[
np
.
array
([
5.5
,
7.0
]),
np
.
array
([
5.5
,
7.0
]),
np
.
array
([
5.5
,
7.0
]),
np
.
array
([
5.5
,
7.0
]),
...
@@ -190,39 +181,40 @@ def test_gammaincinv():
...
@@ -190,39 +181,40 @@ def test_gammaincinv():
k
=
vector
(
"k"
,
dtype
=
"float64"
)
k
=
vector
(
"k"
,
dtype
=
"float64"
)
x
=
vector
(
"x"
,
dtype
=
"float64"
)
x
=
vector
(
"x"
,
dtype
=
"float64"
)
out
=
gammaincinv
(
k
,
x
)
out
=
gammaincinv
(
k
,
x
)
fg
=
FunctionGraph
([
k
,
x
],
[
out
])
compare_jax_and_py
(
fg
,
[
np
.
array
([
5.5
,
7.0
]),
np
.
array
([
0.25
,
0.7
])])
compare_jax_and_py
(
[
k
,
x
],
[
out
]
,
[
np
.
array
([
5.5
,
7.0
]),
np
.
array
([
0.25
,
0.7
])])
def
test_gammainccinv
():
def
test_gammainccinv
():
k
=
vector
(
"k"
,
dtype
=
"float64"
)
k
=
vector
(
"k"
,
dtype
=
"float64"
)
x
=
vector
(
"x"
,
dtype
=
"float64"
)
x
=
vector
(
"x"
,
dtype
=
"float64"
)
out
=
gammainccinv
(
k
,
x
)
out
=
gammainccinv
(
k
,
x
)
fg
=
FunctionGraph
([
k
,
x
],
[
out
])
compare_jax_and_py
(
fg
,
[
np
.
array
([
5.5
,
7.0
]),
np
.
array
([
0.25
,
0.7
])])
compare_jax_and_py
(
[
k
,
x
],
[
out
]
,
[
np
.
array
([
5.5
,
7.0
]),
np
.
array
([
0.25
,
0.7
])])
def
test_psi
():
def
test_psi
():
x
=
scalar
(
"x"
)
x
=
scalar
(
"x"
)
out
=
psi
(
x
)
out
=
psi
(
x
)
fg
=
FunctionGraph
([
x
],
[
out
])
compare_jax_and_py
(
fg
,
[
3.0
])
compare_jax_and_py
(
[
x
],
[
out
]
,
[
3.0
])
def
test_tri_gamma
():
def
test_tri_gamma
():
x
=
vector
(
"x"
,
dtype
=
"float64"
)
x
=
vector
(
"x"
,
dtype
=
"float64"
)
out
=
tri_gamma
(
x
)
out
=
tri_gamma
(
x
)
fg
=
FunctionGraph
([
x
],
[
out
])
compare_jax_and_py
(
fg
,
[
np
.
array
([
3.0
,
5.0
])])
compare_jax_and_py
(
[
x
],
[
out
]
,
[
np
.
array
([
3.0
,
5.0
])])
def
test_polygamma
():
def
test_polygamma
():
n
=
vector
(
"n"
,
dtype
=
"int32"
)
n
=
vector
(
"n"
,
dtype
=
"int32"
)
x
=
vector
(
"x"
,
dtype
=
"float32"
)
x
=
vector
(
"x"
,
dtype
=
"float32"
)
out
=
polygamma
(
n
,
x
)
out
=
polygamma
(
n
,
x
)
fg
=
FunctionGraph
([
n
,
x
],
[
out
])
compare_jax_and_py
(
compare_jax_and_py
(
fg
,
[
n
,
x
],
[
out
],
[
[
np
.
array
([
0
,
1
,
2
])
.
astype
(
"int32"
),
np
.
array
([
0
,
1
,
2
])
.
astype
(
"int32"
),
np
.
array
([
0.5
,
0.9
,
2.5
])
.
astype
(
"float32"
),
np
.
array
([
0.5
,
0.9
,
2.5
])
.
astype
(
"float32"
),
...
@@ -233,41 +225,34 @@ def test_polygamma():
...
@@ -233,41 +225,34 @@ def test_polygamma():
def
test_log1mexp
():
def
test_log1mexp
():
x
=
vector
(
"x"
)
x
=
vector
(
"x"
)
out
=
log1mexp
(
x
)
out
=
log1mexp
(
x
)
fg
=
FunctionGraph
([
x
],
[
out
])
compare_jax_and_py
(
fg
,
[[
-
1.0
,
-
0.75
,
-
0.5
,
-
0.25
]])
compare_jax_and_py
(
[
x
],
[
out
]
,
[[
-
1.0
,
-
0.75
,
-
0.5
,
-
0.25
]])
def
test_nnet
():
def
test_nnet
():
x
=
vector
(
"x"
)
x
=
vector
(
"x"
)
x
.
tag
.
test_value
=
np
.
r_
[
1.0
,
2.0
]
.
astype
(
config
.
floatX
)
x
_
test_value
=
np
.
r_
[
1.0
,
2.0
]
.
astype
(
config
.
floatX
)
out
=
sigmoid
(
x
)
out
=
sigmoid
(
x
)
fgraph
=
FunctionGraph
([
x
],
[
out
])
compare_jax_and_py
([
x
],
[
out
],
[
x_test_value
])
compare_jax_and_py
(
fgraph
,
[
get_test_value
(
i
)
for
i
in
fgraph
.
inputs
])
out
=
softplus
(
x
)
out
=
softplus
(
x
)
fgraph
=
FunctionGraph
([
x
],
[
out
])
compare_jax_and_py
([
x
],
[
out
],
[
x_test_value
])
compare_jax_and_py
(
fgraph
,
[
get_test_value
(
i
)
for
i
in
fgraph
.
inputs
])
def
test_jax_variadic_Scalar
():
def
test_jax_variadic_Scalar
():
mu
=
vector
(
"mu"
,
dtype
=
config
.
floatX
)
mu
=
vector
(
"mu"
,
dtype
=
config
.
floatX
)
mu
.
tag
.
test_value
=
np
.
r_
[
0.1
,
1.1
]
.
astype
(
config
.
floatX
)
mu
_
test_value
=
np
.
r_
[
0.1
,
1.1
]
.
astype
(
config
.
floatX
)
tau
=
vector
(
"tau"
,
dtype
=
config
.
floatX
)
tau
=
vector
(
"tau"
,
dtype
=
config
.
floatX
)
tau
.
tag
.
test_value
=
np
.
r_
[
1.0
,
2.0
]
.
astype
(
config
.
floatX
)
tau
_
test_value
=
np
.
r_
[
1.0
,
2.0
]
.
astype
(
config
.
floatX
)
res
=
-
tau
*
mu
res
=
-
tau
*
mu
fgraph
=
FunctionGraph
([
mu
,
tau
],
[
res
])
compare_jax_and_py
([
mu
,
tau
],
[
res
],
[
mu_test_value
,
tau_test_value
])
compare_jax_and_py
(
fgraph
,
[
get_test_value
(
i
)
for
i
in
fgraph
.
inputs
])
res
=
-
tau
*
(
tau
-
mu
)
**
2
res
=
-
tau
*
(
tau
-
mu
)
**
2
fgraph
=
FunctionGraph
([
mu
,
tau
],
[
res
])
compare_jax_and_py
([
mu
,
tau
],
[
res
],
[
mu_test_value
,
tau_test_value
])
compare_jax_and_py
(
fgraph
,
[
get_test_value
(
i
)
for
i
in
fgraph
.
inputs
])
def
test_add_scalars
():
def
test_add_scalars
():
...
@@ -275,8 +260,7 @@ def test_add_scalars():
...
@@ -275,8 +260,7 @@ def test_add_scalars():
size
=
x
.
shape
[
0
]
+
x
.
shape
[
0
]
+
x
.
shape
[
1
]
size
=
x
.
shape
[
0
]
+
x
.
shape
[
0
]
+
x
.
shape
[
1
]
out
=
pt
.
ones
(
size
)
.
astype
(
config
.
floatX
)
out
=
pt
.
ones
(
size
)
.
astype
(
config
.
floatX
)
out_fg
=
FunctionGraph
([
x
],
[
out
])
compare_jax_and_py
([
x
],
[
out
],
[
np
.
ones
((
2
,
3
))
.
astype
(
config
.
floatX
)])
compare_jax_and_py
(
out_fg
,
[
np
.
ones
((
2
,
3
))
.
astype
(
config
.
floatX
)])
def
test_mul_scalars
():
def
test_mul_scalars
():
...
@@ -284,8 +268,7 @@ def test_mul_scalars():
...
@@ -284,8 +268,7 @@ def test_mul_scalars():
size
=
x
.
shape
[
0
]
*
x
.
shape
[
0
]
*
x
.
shape
[
1
]
size
=
x
.
shape
[
0
]
*
x
.
shape
[
0
]
*
x
.
shape
[
1
]
out
=
pt
.
ones
(
size
)
.
astype
(
config
.
floatX
)
out
=
pt
.
ones
(
size
)
.
astype
(
config
.
floatX
)
out_fg
=
FunctionGraph
([
x
],
[
out
])
compare_jax_and_py
([
x
],
[
out
],
[
np
.
ones
((
2
,
3
))
.
astype
(
config
.
floatX
)])
compare_jax_and_py
(
out_fg
,
[
np
.
ones
((
2
,
3
))
.
astype
(
config
.
floatX
)])
def
test_div_scalars
():
def
test_div_scalars
():
...
@@ -293,8 +276,7 @@ def test_div_scalars():
...
@@ -293,8 +276,7 @@ def test_div_scalars():
size
=
x
.
shape
[
0
]
//
x
.
shape
[
1
]
size
=
x
.
shape
[
0
]
//
x
.
shape
[
1
]
out
=
pt
.
ones
(
size
)
.
astype
(
config
.
floatX
)
out
=
pt
.
ones
(
size
)
.
astype
(
config
.
floatX
)
out_fg
=
FunctionGraph
([
x
],
[
out
])
compare_jax_and_py
([
x
],
[
out
],
[
np
.
ones
((
12
,
3
))
.
astype
(
config
.
floatX
)])
compare_jax_and_py
(
out_fg
,
[
np
.
ones
((
12
,
3
))
.
astype
(
config
.
floatX
)])
def
test_mod_scalars
():
def
test_mod_scalars
():
...
@@ -302,39 +284,43 @@ def test_mod_scalars():
...
@@ -302,39 +284,43 @@ def test_mod_scalars():
size
=
x
.
shape
[
0
]
%
x
.
shape
[
1
]
size
=
x
.
shape
[
0
]
%
x
.
shape
[
1
]
out
=
pt
.
ones
(
size
)
.
astype
(
config
.
floatX
)
out
=
pt
.
ones
(
size
)
.
astype
(
config
.
floatX
)
out_fg
=
FunctionGraph
([
x
],
[
out
])
compare_jax_and_py
([
x
],
[
out
],
[
np
.
ones
((
12
,
3
))
.
astype
(
config
.
floatX
)])
compare_jax_and_py
(
out_fg
,
[
np
.
ones
((
12
,
3
))
.
astype
(
config
.
floatX
)])
def
test_jax_multioutput
():
def
test_jax_multioutput
():
x
=
vector
(
"x"
)
x
=
vector
(
"x"
)
x
.
tag
.
test_value
=
np
.
r_
[
1.0
,
2.0
]
.
astype
(
config
.
floatX
)
x
_
test_value
=
np
.
r_
[
1.0
,
2.0
]
.
astype
(
config
.
floatX
)
y
=
vector
(
"y"
)
y
=
vector
(
"y"
)
y
.
tag
.
test_value
=
np
.
r_
[
3.0
,
4.0
]
.
astype
(
config
.
floatX
)
y
_
test_value
=
np
.
r_
[
3.0
,
4.0
]
.
astype
(
config
.
floatX
)
w
=
cosh
(
x
**
2
+
y
/
3.0
)
w
=
cosh
(
x
**
2
+
y
/
3.0
)
v
=
cosh
(
x
/
3.0
+
y
**
2
)
v
=
cosh
(
x
/
3.0
+
y
**
2
)
fgraph
=
FunctionGraph
([
x
,
y
],
[
w
,
v
])
compare_jax_and_py
([
x
,
y
],
[
w
,
v
],
[
x_test_value
,
y_test_value
])
compare_jax_and_py
(
fgraph
,
[
get_test_value
(
i
)
for
i
in
fgraph
.
inputs
])
def
test_jax_logp
():
def
test_jax_logp
():
mu
=
vector
(
"mu"
)
mu
=
vector
(
"mu"
)
mu
.
tag
.
test_value
=
np
.
r_
[
0.0
,
0.0
]
.
astype
(
config
.
floatX
)
mu
_
test_value
=
np
.
r_
[
0.0
,
0.0
]
.
astype
(
config
.
floatX
)
tau
=
vector
(
"tau"
)
tau
=
vector
(
"tau"
)
tau
.
tag
.
test_value
=
np
.
r_
[
1.0
,
1.0
]
.
astype
(
config
.
floatX
)
tau
_
test_value
=
np
.
r_
[
1.0
,
1.0
]
.
astype
(
config
.
floatX
)
sigma
=
vector
(
"sigma"
)
sigma
=
vector
(
"sigma"
)
sigma
.
tag
.
test_value
=
(
1.0
/
get_test_value
(
tau
)
)
.
astype
(
config
.
floatX
)
sigma
_test_value
=
(
1.0
/
tau_test_value
)
.
astype
(
config
.
floatX
)
value
=
vector
(
"value"
)
value
=
vector
(
"value"
)
value
.
tag
.
test_value
=
np
.
r_
[
0.1
,
-
10
]
.
astype
(
config
.
floatX
)
value
_
test_value
=
np
.
r_
[
0.1
,
-
10
]
.
astype
(
config
.
floatX
)
logp
=
(
-
tau
*
(
value
-
mu
)
**
2
+
log
(
tau
/
np
.
pi
/
2.0
))
/
2.0
logp
=
(
-
tau
*
(
value
-
mu
)
**
2
+
log
(
tau
/
np
.
pi
/
2.0
))
/
2.0
conditions
=
[
sigma
>
0
]
conditions
=
[
sigma
>
0
]
alltrue
=
pt_all
([
pt_all
(
1
*
val
)
for
val
in
conditions
])
alltrue
=
pt_all
([
pt_all
(
1
*
val
)
for
val
in
conditions
])
normal_logp
=
pt
.
switch
(
alltrue
,
logp
,
-
np
.
inf
)
normal_logp
=
pt
.
switch
(
alltrue
,
logp
,
-
np
.
inf
)
fgraph
=
FunctionGraph
([
mu
,
tau
,
sigma
,
value
],
[
normal_logp
])
compare_jax_and_py
(
[
mu
,
tau
,
sigma
,
value
],
compare_jax_and_py
(
fgraph
,
[
get_test_value
(
i
)
for
i
in
fgraph
.
inputs
])
[
normal_logp
],
[
mu_test_value
,
tau_test_value
,
sigma_test_value
,
value_test_value
,
],
)
tests/link/jax/test_scan.py
浏览文件 @
cc8c4992
...
@@ -7,7 +7,6 @@ import pytensor.tensor as pt
...
@@ -7,7 +7,6 @@ import pytensor.tensor as pt
from
pytensor
import
function
,
shared
from
pytensor
import
function
,
shared
from
pytensor.compile
import
get_mode
from
pytensor.compile
import
get_mode
from
pytensor.configdefaults
import
config
from
pytensor.configdefaults
import
config
from
pytensor.graph.fg
import
FunctionGraph
from
pytensor.scan
import
until
from
pytensor.scan
import
until
from
pytensor.scan.basic
import
scan
from
pytensor.scan.basic
import
scan
from
pytensor.scan.op
import
Scan
from
pytensor.scan.op
import
Scan
...
@@ -30,9 +29,8 @@ def test_scan_sit_sot(view):
...
@@ -30,9 +29,8 @@ def test_scan_sit_sot(view):
)
)
if
view
:
if
view
:
xs
=
xs
[
view
]
xs
=
xs
[
view
]
fg
=
FunctionGraph
([
x0
],
[
xs
])
test_input_vals
=
[
np
.
e
]
test_input_vals
=
[
np
.
e
]
compare_jax_and_py
(
fg
,
test_input_vals
,
jax_mode
=
"JAX"
)
compare_jax_and_py
(
[
x0
],
[
xs
]
,
test_input_vals
,
jax_mode
=
"JAX"
)
@pytest.mark.parametrize
(
"view"
,
[
None
,
(
-
1
,),
slice
(
-
4
,
-
1
,
None
)])
@pytest.mark.parametrize
(
"view"
,
[
None
,
(
-
1
,),
slice
(
-
4
,
-
1
,
None
)])
...
@@ -45,9 +43,8 @@ def test_scan_mit_sot(view):
...
@@ -45,9 +43,8 @@ def test_scan_mit_sot(view):
)
)
if
view
:
if
view
:
xs
=
xs
[
view
]
xs
=
xs
[
view
]
fg
=
FunctionGraph
([
x0
],
[
xs
])
test_input_vals
=
[
np
.
full
((
3
,),
np
.
e
)]
test_input_vals
=
[
np
.
full
((
3
,),
np
.
e
)]
compare_jax_and_py
(
fg
,
test_input_vals
,
jax_mode
=
"JAX"
)
compare_jax_and_py
(
[
x0
],
[
xs
]
,
test_input_vals
,
jax_mode
=
"JAX"
)
@pytest.mark.parametrize
(
"view_x"
,
[
None
,
(
-
1
,),
slice
(
-
4
,
-
1
,
None
)])
@pytest.mark.parametrize
(
"view_x"
,
[
None
,
(
-
1
,),
slice
(
-
4
,
-
1
,
None
)])
...
@@ -72,9 +69,8 @@ def test_scan_multiple_mit_sot(view_x, view_y):
...
@@ -72,9 +69,8 @@ def test_scan_multiple_mit_sot(view_x, view_y):
if
view_y
:
if
view_y
:
ys
=
ys
[
view_y
]
ys
=
ys
[
view_y
]
fg
=
FunctionGraph
([
x0
,
y0
],
[
xs
,
ys
])
test_input_vals
=
[
np
.
full
((
3
,),
np
.
e
),
np
.
full
((
4
,),
np
.
pi
)]
test_input_vals
=
[
np
.
full
((
3
,),
np
.
e
),
np
.
full
((
4
,),
np
.
pi
)]
compare_jax_and_py
(
fg
,
test_input_vals
,
jax_mode
=
"JAX"
)
compare_jax_and_py
(
[
x0
,
y0
],
[
xs
,
ys
]
,
test_input_vals
,
jax_mode
=
"JAX"
)
@pytest.mark.parametrize
(
"view"
,
[
None
,
(
-
2
,),
slice
(
None
,
None
,
2
)])
@pytest.mark.parametrize
(
"view"
,
[
None
,
(
-
2
,),
slice
(
None
,
None
,
2
)])
...
@@ -90,12 +86,11 @@ def test_scan_nit_sot(view):
...
@@ -90,12 +86,11 @@ def test_scan_nit_sot(view):
)
)
if
view
:
if
view
:
ys
=
ys
[
view
]
ys
=
ys
[
view
]
fg
=
FunctionGraph
([
xs
],
[
ys
])
test_input_vals
=
[
rng
.
normal
(
size
=
10
)]
test_input_vals
=
[
rng
.
normal
(
size
=
10
)]
# We need to remove pushout rewrites, or the whole scan would just be
# We need to remove pushout rewrites, or the whole scan would just be
# converted to an Elemwise on xs
# converted to an Elemwise on xs
jax_fn
,
_
=
compare_jax_and_py
(
jax_fn
,
_
=
compare_jax_and_py
(
fg
,
test_input_vals
,
jax_mode
=
get_mode
(
"JAX"
)
.
excluding
(
"scan_pushout"
)
[
xs
],
[
ys
]
,
test_input_vals
,
jax_mode
=
get_mode
(
"JAX"
)
.
excluding
(
"scan_pushout"
)
)
)
scan_nodes
=
[
scan_nodes
=
[
node
for
node
in
jax_fn
.
maker
.
fgraph
.
apply_nodes
if
isinstance
(
node
.
op
,
Scan
)
node
for
node
in
jax_fn
.
maker
.
fgraph
.
apply_nodes
if
isinstance
(
node
.
op
,
Scan
)
...
@@ -112,8 +107,7 @@ def test_scan_mit_mot():
...
@@ -112,8 +107,7 @@ def test_scan_mit_mot():
n_steps
=
10
,
n_steps
=
10
,
)
)
grads_wrt_xs
=
pt
.
grad
(
ys
.
sum
(),
wrt
=
xs
)
grads_wrt_xs
=
pt
.
grad
(
ys
.
sum
(),
wrt
=
xs
)
fg
=
FunctionGraph
([
xs
],
[
grads_wrt_xs
])
compare_jax_and_py
([
xs
],
[
grads_wrt_xs
],
[
np
.
arange
(
10
)])
compare_jax_and_py
(
fg
,
[
np
.
arange
(
10
)])
def
test_scan_update
():
def
test_scan_update
():
...
@@ -192,8 +186,7 @@ def test_scan_while():
...
@@ -192,8 +186,7 @@ def test_scan_while():
n_steps
=
100
,
n_steps
=
100
,
)
)
fg
=
FunctionGraph
([],
[
xs
])
compare_jax_and_py
([],
[
xs
],
[])
compare_jax_and_py
(
fg
,
[])
def
test_scan_SEIR
():
def
test_scan_SEIR
():
...
@@ -257,11 +250,6 @@ def test_scan_SEIR():
...
@@ -257,11 +250,6 @@ def test_scan_SEIR():
logp_c_all
.
name
=
"C_t_logp"
logp_c_all
.
name
=
"C_t_logp"
logp_d_all
.
name
=
"D_t_logp"
logp_d_all
.
name
=
"D_t_logp"
out_fg
=
FunctionGraph
(
[
at_C
,
at_D
,
st0
,
et0
,
it0
,
logp_c
,
logp_d
,
beta
,
gamma
,
delta
],
[
st
,
et
,
it
,
logp_c_all
,
logp_d_all
],
)
s0
,
e0
,
i0
=
100
,
50
,
25
s0
,
e0
,
i0
=
100
,
50
,
25
logp_c0
=
np
.
array
(
0.0
,
dtype
=
config
.
floatX
)
logp_c0
=
np
.
array
(
0.0
,
dtype
=
config
.
floatX
)
logp_d0
=
np
.
array
(
0.0
,
dtype
=
config
.
floatX
)
logp_d0
=
np
.
array
(
0.0
,
dtype
=
config
.
floatX
)
...
@@ -283,7 +271,12 @@ def test_scan_SEIR():
...
@@ -283,7 +271,12 @@ def test_scan_SEIR():
gamma_val
,
gamma_val
,
delta_val
,
delta_val
,
]
]
compare_jax_and_py
(
out_fg
,
test_input_vals
,
jax_mode
=
"JAX"
)
compare_jax_and_py
(
[
at_C
,
at_D
,
st0
,
et0
,
it0
,
logp_c
,
logp_d
,
beta
,
gamma
,
delta
],
[
st
,
et
,
it
,
logp_c_all
,
logp_d_all
],
test_input_vals
,
jax_mode
=
"JAX"
,
)
def
test_scan_mitsot_with_nonseq
():
def
test_scan_mitsot_with_nonseq
():
...
@@ -313,10 +306,8 @@ def test_scan_mitsot_with_nonseq():
...
@@ -313,10 +306,8 @@ def test_scan_mitsot_with_nonseq():
y_scan_pt
.
name
=
"y"
y_scan_pt
.
name
=
"y"
y_scan_pt
.
owner
.
inputs
[
0
]
.
name
=
"y_all"
y_scan_pt
.
owner
.
inputs
[
0
]
.
name
=
"y_all"
out_fg
=
FunctionGraph
([
a_pt
],
[
y_scan_pt
])
test_input_vals
=
[
np
.
array
(
10.0
)
.
astype
(
config
.
floatX
)]
test_input_vals
=
[
np
.
array
(
10.0
)
.
astype
(
config
.
floatX
)]
compare_jax_and_py
(
out_fg
,
test_input_vals
,
jax_mode
=
"JAX"
)
compare_jax_and_py
(
[
a_pt
],
[
y_scan_pt
]
,
test_input_vals
,
jax_mode
=
"JAX"
)
@pytest.mark.parametrize
(
"x0_func"
,
[
dvector
,
dmatrix
])
@pytest.mark.parametrize
(
"x0_func"
,
[
dvector
,
dmatrix
])
...
@@ -343,9 +334,8 @@ def test_nd_scan_sit_sot(x0_func, A_func):
...
@@ -343,9 +334,8 @@ def test_nd_scan_sit_sot(x0_func, A_func):
)
)
A_val
=
np
.
eye
(
k
,
dtype
=
config
.
floatX
)
A_val
=
np
.
eye
(
k
,
dtype
=
config
.
floatX
)
fg
=
FunctionGraph
([
x0
,
A
],
[
xs
])
test_input_vals
=
[
x0_val
,
A_val
]
test_input_vals
=
[
x0_val
,
A_val
]
compare_jax_and_py
(
fg
,
test_input_vals
,
jax_mode
=
"JAX"
)
compare_jax_and_py
(
[
x0
,
A
],
[
xs
]
,
test_input_vals
,
jax_mode
=
"JAX"
)
def
test_nd_scan_sit_sot_with_seq
():
def
test_nd_scan_sit_sot_with_seq
():
...
@@ -366,9 +356,8 @@ def test_nd_scan_sit_sot_with_seq():
...
@@ -366,9 +356,8 @@ def test_nd_scan_sit_sot_with_seq():
x_val
=
np
.
arange
(
n_steps
*
k
,
dtype
=
config
.
floatX
)
.
reshape
(
n_steps
,
k
)
x_val
=
np
.
arange
(
n_steps
*
k
,
dtype
=
config
.
floatX
)
.
reshape
(
n_steps
,
k
)
A_val
=
np
.
eye
(
k
,
dtype
=
config
.
floatX
)
A_val
=
np
.
eye
(
k
,
dtype
=
config
.
floatX
)
fg
=
FunctionGraph
([
x
,
A
],
[
xs
])
test_input_vals
=
[
x_val
,
A_val
]
test_input_vals
=
[
x_val
,
A_val
]
compare_jax_and_py
(
fg
,
test_input_vals
,
jax_mode
=
"JAX"
)
compare_jax_and_py
(
[
x
,
A
],
[
xs
]
,
test_input_vals
,
jax_mode
=
"JAX"
)
def
test_nd_scan_mit_sot
():
def
test_nd_scan_mit_sot
():
...
@@ -384,13 +373,12 @@ def test_nd_scan_mit_sot():
...
@@ -384,13 +373,12 @@ def test_nd_scan_mit_sot():
n_steps
=
10
,
n_steps
=
10
,
)
)
fg
=
FunctionGraph
([
x0
,
A
,
B
],
[
xs
])
x0_val
=
np
.
arange
(
9
,
dtype
=
config
.
floatX
)
.
reshape
(
3
,
3
)
x0_val
=
np
.
arange
(
9
,
dtype
=
config
.
floatX
)
.
reshape
(
3
,
3
)
A_val
=
np
.
eye
(
3
,
dtype
=
config
.
floatX
)
A_val
=
np
.
eye
(
3
,
dtype
=
config
.
floatX
)
B_val
=
np
.
eye
(
3
,
dtype
=
config
.
floatX
)
B_val
=
np
.
eye
(
3
,
dtype
=
config
.
floatX
)
test_input_vals
=
[
x0_val
,
A_val
,
B_val
]
test_input_vals
=
[
x0_val
,
A_val
,
B_val
]
compare_jax_and_py
(
fg
,
test_input_vals
,
jax_mode
=
"JAX"
)
compare_jax_and_py
(
[
x0
,
A
,
B
],
[
xs
]
,
test_input_vals
,
jax_mode
=
"JAX"
)
def
test_nd_scan_sit_sot_with_carry
():
def
test_nd_scan_sit_sot_with_carry
():
...
@@ -409,12 +397,11 @@ def test_nd_scan_sit_sot_with_carry():
...
@@ -409,12 +397,11 @@ def test_nd_scan_sit_sot_with_carry():
mode
=
get_mode
(
"JAX"
),
mode
=
get_mode
(
"JAX"
),
)
)
fg
=
FunctionGraph
([
x0
,
A
],
xs
)
x0_val
=
np
.
arange
(
3
,
dtype
=
config
.
floatX
)
x0_val
=
np
.
arange
(
3
,
dtype
=
config
.
floatX
)
A_val
=
np
.
eye
(
3
,
dtype
=
config
.
floatX
)
A_val
=
np
.
eye
(
3
,
dtype
=
config
.
floatX
)
test_input_vals
=
[
x0_val
,
A_val
]
test_input_vals
=
[
x0_val
,
A_val
]
compare_jax_and_py
(
fg
,
test_input_vals
,
jax_mode
=
"JAX"
)
compare_jax_and_py
(
[
x0
,
A
],
xs
,
test_input_vals
,
jax_mode
=
"JAX"
)
def
test_default_mode_excludes_incompatible_rewrites
():
def
test_default_mode_excludes_incompatible_rewrites
():
...
@@ -422,8 +409,7 @@ def test_default_mode_excludes_incompatible_rewrites():
...
@@ -422,8 +409,7 @@ def test_default_mode_excludes_incompatible_rewrites():
A
=
matrix
(
"A"
)
A
=
matrix
(
"A"
)
B
=
matrix
(
"B"
)
B
=
matrix
(
"B"
)
out
,
_
=
scan
(
lambda
a
,
b
:
a
@
b
,
outputs_info
=
[
A
],
non_sequences
=
[
B
],
n_steps
=
2
)
out
,
_
=
scan
(
lambda
a
,
b
:
a
@
b
,
outputs_info
=
[
A
],
non_sequences
=
[
B
],
n_steps
=
2
)
fg
=
FunctionGraph
([
A
,
B
],
[
out
])
compare_jax_and_py
([
A
,
B
],
[
out
],
[
np
.
eye
(
3
),
np
.
eye
(
3
)],
jax_mode
=
"JAX"
)
compare_jax_and_py
(
fg
,
[
np
.
eye
(
3
),
np
.
eye
(
3
)],
jax_mode
=
"JAX"
)
def
test_dynamic_sequence_length
():
def
test_dynamic_sequence_length
():
...
...
tests/link/jax/test_shape.py
浏览文件 @
cc8c4992
...
@@ -4,7 +4,6 @@ import pytest
...
@@ -4,7 +4,6 @@ import pytest
import
pytensor.tensor
as
pt
import
pytensor.tensor
as
pt
from
pytensor.compile.ops
import
DeepCopyOp
,
ViewOp
from
pytensor.compile.ops
import
DeepCopyOp
,
ViewOp
from
pytensor.configdefaults
import
config
from
pytensor.configdefaults
import
config
from
pytensor.graph.fg
import
FunctionGraph
from
pytensor.tensor.shape
import
Shape
,
Shape_i
,
Unbroadcast
,
reshape
from
pytensor.tensor.shape
import
Shape
,
Shape_i
,
Unbroadcast
,
reshape
from
pytensor.tensor.type
import
iscalar
,
vector
from
pytensor.tensor.type
import
iscalar
,
vector
from
tests.link.jax.test_basic
import
compare_jax_and_py
from
tests.link.jax.test_basic
import
compare_jax_and_py
...
@@ -13,29 +12,27 @@ from tests.link.jax.test_basic import compare_jax_and_py
...
@@ -13,29 +12,27 @@ from tests.link.jax.test_basic import compare_jax_and_py
def
test_jax_shape_ops
():
def
test_jax_shape_ops
():
x_np
=
np
.
zeros
((
20
,
3
))
x_np
=
np
.
zeros
((
20
,
3
))
x
=
Shape
()(
pt
.
as_tensor_variable
(
x_np
))
x
=
Shape
()(
pt
.
as_tensor_variable
(
x_np
))
x_fg
=
FunctionGraph
([],
[
x
])
compare_jax_and_py
(
x_fg
,
[],
must_be_device_array
=
False
)
compare_jax_and_py
(
[],
[
x
]
,
[],
must_be_device_array
=
False
)
x
=
Shape_i
(
1
)(
pt
.
as_tensor_variable
(
x_np
))
x
=
Shape_i
(
1
)(
pt
.
as_tensor_variable
(
x_np
))
x_fg
=
FunctionGraph
([],
[
x
])
compare_jax_and_py
(
x_fg
,
[],
must_be_device_array
=
False
)
compare_jax_and_py
(
[],
[
x
]
,
[],
must_be_device_array
=
False
)
def
test_jax_specify_shape
():
def
test_jax_specify_shape
():
in_pt
=
pt
.
matrix
(
"in"
)
in_pt
=
pt
.
matrix
(
"in"
)
x
=
pt
.
specify_shape
(
in_pt
,
(
4
,
None
))
x
=
pt
.
specify_shape
(
in_pt
,
(
4
,
None
))
x_fg
=
FunctionGraph
([
in_pt
],
[
x
])
compare_jax_and_py
([
in_pt
],
[
x
],
[
np
.
ones
((
4
,
5
))
.
astype
(
config
.
floatX
)])
compare_jax_and_py
(
x_fg
,
[
np
.
ones
((
4
,
5
))
.
astype
(
config
.
floatX
)])
# When used to assert two arrays have similar shapes
# When used to assert two arrays have similar shapes
in_pt
=
pt
.
matrix
(
"in"
)
in_pt
=
pt
.
matrix
(
"in"
)
shape_pt
=
pt
.
matrix
(
"shape"
)
shape_pt
=
pt
.
matrix
(
"shape"
)
x
=
pt
.
specify_shape
(
in_pt
,
shape_pt
.
shape
)
x
=
pt
.
specify_shape
(
in_pt
,
shape_pt
.
shape
)
x_fg
=
FunctionGraph
([
in_pt
,
shape_pt
],
[
x
])
compare_jax_and_py
(
compare_jax_and_py
(
x_fg
,
[
in_pt
,
shape_pt
],
[
x
],
[
np
.
ones
((
4
,
5
))
.
astype
(
config
.
floatX
),
np
.
ones
((
4
,
5
))
.
astype
(
config
.
floatX
)],
[
np
.
ones
((
4
,
5
))
.
astype
(
config
.
floatX
),
np
.
ones
((
4
,
5
))
.
astype
(
config
.
floatX
)],
)
)
...
@@ -43,20 +40,17 @@ def test_jax_specify_shape():
...
@@ -43,20 +40,17 @@ def test_jax_specify_shape():
def
test_jax_Reshape_constant
():
def
test_jax_Reshape_constant
():
a
=
vector
(
"a"
)
a
=
vector
(
"a"
)
x
=
reshape
(
a
,
(
2
,
2
))
x
=
reshape
(
a
,
(
2
,
2
))
x_fg
=
FunctionGraph
([
a
],
[
x
])
compare_jax_and_py
([
a
],
[
x
],
[
np
.
r_
[
1.0
,
2.0
,
3.0
,
4.0
]
.
astype
(
config
.
floatX
)])
compare_jax_and_py
(
x_fg
,
[
np
.
r_
[
1.0
,
2.0
,
3.0
,
4.0
]
.
astype
(
config
.
floatX
)])
def
test_jax_Reshape_concrete_shape
():
def
test_jax_Reshape_concrete_shape
():
"""JAX should compile when a concrete value is passed for the `shape` parameter."""
"""JAX should compile when a concrete value is passed for the `shape` parameter."""
a
=
vector
(
"a"
)
a
=
vector
(
"a"
)
x
=
reshape
(
a
,
a
.
shape
)
x
=
reshape
(
a
,
a
.
shape
)
x_fg
=
FunctionGraph
([
a
],
[
x
])
compare_jax_and_py
([
a
],
[
x
],
[
np
.
r_
[
1.0
,
2.0
,
3.0
,
4.0
]
.
astype
(
config
.
floatX
)])
compare_jax_and_py
(
x_fg
,
[
np
.
r_
[
1.0
,
2.0
,
3.0
,
4.0
]
.
astype
(
config
.
floatX
)])
x
=
reshape
(
a
,
(
a
.
shape
[
0
]
//
2
,
a
.
shape
[
0
]
//
2
))
x
=
reshape
(
a
,
(
a
.
shape
[
0
]
//
2
,
a
.
shape
[
0
]
//
2
))
x_fg
=
FunctionGraph
([
a
],
[
x
])
compare_jax_and_py
([
a
],
[
x
],
[
np
.
r_
[
1.0
,
2.0
,
3.0
,
4.0
]
.
astype
(
config
.
floatX
)])
compare_jax_and_py
(
x_fg
,
[
np
.
r_
[
1.0
,
2.0
,
3.0
,
4.0
]
.
astype
(
config
.
floatX
)])
@pytest.mark.xfail
(
@pytest.mark.xfail
(
...
@@ -66,23 +60,20 @@ def test_jax_Reshape_shape_graph_input():
...
@@ -66,23 +60,20 @@ def test_jax_Reshape_shape_graph_input():
a
=
vector
(
"a"
)
a
=
vector
(
"a"
)
shape_pt
=
iscalar
(
"b"
)
shape_pt
=
iscalar
(
"b"
)
x
=
reshape
(
a
,
(
shape_pt
,
shape_pt
))
x
=
reshape
(
a
,
(
shape_pt
,
shape_pt
))
x_fg
=
FunctionGraph
([
a
,
shape_pt
],
[
x
])
compare_jax_and_py
(
compare_jax_and_py
(
x_fg
,
[
np
.
r_
[
1.0
,
2.0
,
3.0
,
4.0
]
.
astype
(
config
.
floatX
),
2
])
[
a
,
shape_pt
],
[
x
],
[
np
.
r_
[
1.0
,
2.0
,
3.0
,
4.0
]
.
astype
(
config
.
floatX
),
2
]
)
def
test_jax_compile_ops
():
def
test_jax_compile_ops
():
x
=
DeepCopyOp
()(
pt
.
as_tensor_variable
(
1.1
))
x
=
DeepCopyOp
()(
pt
.
as_tensor_variable
(
1.1
))
x_fg
=
FunctionGraph
([],
[
x
])
compare_jax_and_py
([],
[
x
],
[])
compare_jax_and_py
(
x_fg
,
[])
x_np
=
np
.
zeros
((
20
,
1
,
1
))
x_np
=
np
.
zeros
((
20
,
1
,
1
))
x
=
Unbroadcast
(
0
,
2
)(
pt
.
as_tensor_variable
(
x_np
))
x
=
Unbroadcast
(
0
,
2
)(
pt
.
as_tensor_variable
(
x_np
))
x_fg
=
FunctionGraph
([],
[
x
])
compare_jax_and_py
(
x_fg
,
[])
compare_jax_and_py
(
[],
[
x
]
,
[])
x
=
ViewOp
()(
pt
.
as_tensor_variable
(
x_np
))
x
=
ViewOp
()(
pt
.
as_tensor_variable
(
x_np
))
x_fg
=
FunctionGraph
([],
[
x
])
compare_jax_and_py
(
x_fg
,
[])
compare_jax_and_py
(
[],
[
x
]
,
[])
tests/link/jax/test_slinalg.py
浏览文件 @
cc8c4992
...
@@ -6,7 +6,6 @@ import pytest
...
@@ -6,7 +6,6 @@ import pytest
import
pytensor.tensor
as
pt
import
pytensor.tensor
as
pt
from
pytensor.configdefaults
import
config
from
pytensor.configdefaults
import
config
from
pytensor.graph.fg
import
FunctionGraph
from
pytensor.tensor
import
nlinalg
as
pt_nlinalg
from
pytensor.tensor
import
nlinalg
as
pt_nlinalg
from
pytensor.tensor
import
slinalg
as
pt_slinalg
from
pytensor.tensor
import
slinalg
as
pt_slinalg
from
pytensor.tensor
import
subtensor
as
pt_subtensor
from
pytensor.tensor
import
subtensor
as
pt_subtensor
...
@@ -30,13 +29,11 @@ def test_jax_basic():
...
@@ -30,13 +29,11 @@ def test_jax_basic():
out
=
pt_subtensor
.
inc_subtensor
(
out
[
0
,
1
],
2.0
)
out
=
pt_subtensor
.
inc_subtensor
(
out
[
0
,
1
],
2.0
)
out
=
out
[:
5
,
:
3
]
out
=
out
[:
5
,
:
3
]
out_fg
=
FunctionGraph
([
x
,
y
],
[
out
])
test_input_vals
=
[
test_input_vals
=
[
np
.
tile
(
np
.
arange
(
10
),
(
10
,
1
))
.
astype
(
config
.
floatX
),
np
.
tile
(
np
.
arange
(
10
),
(
10
,
1
))
.
astype
(
config
.
floatX
),
np
.
tile
(
np
.
arange
(
10
,
20
),
(
10
,
1
))
.
astype
(
config
.
floatX
),
np
.
tile
(
np
.
arange
(
10
,
20
),
(
10
,
1
))
.
astype
(
config
.
floatX
),
]
]
_
,
[
jax_res
]
=
compare_jax_and_py
(
out_fg
,
test_input_vals
)
_
,
[
jax_res
]
=
compare_jax_and_py
(
[
x
,
y
],
[
out
]
,
test_input_vals
)
# Confirm that the `Subtensor` slice operations are correct
# Confirm that the `Subtensor` slice operations are correct
assert
jax_res
.
shape
==
(
5
,
3
)
assert
jax_res
.
shape
==
(
5
,
3
)
...
@@ -46,19 +43,17 @@ def test_jax_basic():
...
@@ -46,19 +43,17 @@ def test_jax_basic():
assert
jax_res
[
0
,
1
]
==
-
8.0
assert
jax_res
[
0
,
1
]
==
-
8.0
out
=
clip
(
x
,
y
,
5
)
out
=
clip
(
x
,
y
,
5
)
out_fg
=
FunctionGraph
([
x
,
y
],
[
out
])
compare_jax_and_py
([
x
,
y
],
[
out
],
test_input_vals
)
compare_jax_and_py
(
out_fg
,
test_input_vals
)
out
=
pt
.
diagonal
(
x
,
0
)
out
=
pt
.
diagonal
(
x
,
0
)
out_fg
=
FunctionGraph
([
x
],
[
out
])
compare_jax_and_py
(
compare_jax_and_py
(
out_fg
,
[
np
.
arange
(
10
*
10
)
.
reshape
((
10
,
10
))
.
astype
(
config
.
floatX
)]
[
x
],
[
out
]
,
[
np
.
arange
(
10
*
10
)
.
reshape
((
10
,
10
))
.
astype
(
config
.
floatX
)]
)
)
out
=
pt_slinalg
.
cholesky
(
x
)
out
=
pt_slinalg
.
cholesky
(
x
)
out_fg
=
FunctionGraph
([
x
],
[
out
])
compare_jax_and_py
(
compare_jax_and_py
(
out_fg
,
[
x
],
[
out
],
[
[
(
np
.
eye
(
10
)
+
rng
.
standard_normal
(
size
=
(
10
,
10
))
*
0.01
)
.
astype
(
(
np
.
eye
(
10
)
+
rng
.
standard_normal
(
size
=
(
10
,
10
))
*
0.01
)
.
astype
(
config
.
floatX
config
.
floatX
...
@@ -68,9 +63,9 @@ def test_jax_basic():
...
@@ -68,9 +63,9 @@ def test_jax_basic():
# not sure why this isn't working yet with lower=False
# not sure why this isn't working yet with lower=False
out
=
pt_slinalg
.
Cholesky
(
lower
=
False
)(
x
)
out
=
pt_slinalg
.
Cholesky
(
lower
=
False
)(
x
)
out_fg
=
FunctionGraph
([
x
],
[
out
])
compare_jax_and_py
(
compare_jax_and_py
(
out_fg
,
[
x
],
[
out
],
[
[
(
np
.
eye
(
10
)
+
rng
.
standard_normal
(
size
=
(
10
,
10
))
*
0.01
)
.
astype
(
(
np
.
eye
(
10
)
+
rng
.
standard_normal
(
size
=
(
10
,
10
))
*
0.01
)
.
astype
(
config
.
floatX
config
.
floatX
...
@@ -79,9 +74,9 @@ def test_jax_basic():
...
@@ -79,9 +74,9 @@ def test_jax_basic():
)
)
out
=
pt_slinalg
.
solve
(
x
,
b
)
out
=
pt_slinalg
.
solve
(
x
,
b
)
out_fg
=
FunctionGraph
([
x
,
b
],
[
out
])
compare_jax_and_py
(
compare_jax_and_py
(
out_fg
,
[
x
,
b
],
[
out
],
[
[
np
.
eye
(
10
)
.
astype
(
config
.
floatX
),
np
.
eye
(
10
)
.
astype
(
config
.
floatX
),
np
.
arange
(
10
)
.
astype
(
config
.
floatX
),
np
.
arange
(
10
)
.
astype
(
config
.
floatX
),
...
@@ -89,19 +84,17 @@ def test_jax_basic():
...
@@ -89,19 +84,17 @@ def test_jax_basic():
)
)
out
=
pt
.
diag
(
b
)
out
=
pt
.
diag
(
b
)
out_fg
=
FunctionGraph
([
b
],
[
out
])
compare_jax_and_py
([
b
],
[
out
],
[
np
.
arange
(
10
)
.
astype
(
config
.
floatX
)])
compare_jax_and_py
(
out_fg
,
[
np
.
arange
(
10
)
.
astype
(
config
.
floatX
)])
out
=
pt_nlinalg
.
det
(
x
)
out
=
pt_nlinalg
.
det
(
x
)
out_fg
=
FunctionGraph
([
x
],
[
out
])
compare_jax_and_py
(
compare_jax_and_py
(
out_fg
,
[
np
.
arange
(
10
*
10
)
.
reshape
((
10
,
10
))
.
astype
(
config
.
floatX
)]
[
x
],
[
out
]
,
[
np
.
arange
(
10
*
10
)
.
reshape
((
10
,
10
))
.
astype
(
config
.
floatX
)]
)
)
out
=
pt_nlinalg
.
matrix_inverse
(
x
)
out
=
pt_nlinalg
.
matrix_inverse
(
x
)
out_fg
=
FunctionGraph
([
x
],
[
out
])
compare_jax_and_py
(
compare_jax_and_py
(
out_fg
,
[
x
],
[
out
],
[
[
(
np
.
eye
(
10
)
+
rng
.
standard_normal
(
size
=
(
10
,
10
))
*
0.01
)
.
astype
(
(
np
.
eye
(
10
)
+
rng
.
standard_normal
(
size
=
(
10
,
10
))
*
0.01
)
.
astype
(
config
.
floatX
config
.
floatX
...
@@ -124,9 +117,9 @@ def test_jax_SolveTriangular(trans, lower, check_finite):
...
@@ -124,9 +117,9 @@ def test_jax_SolveTriangular(trans, lower, check_finite):
lower
=
lower
,
lower
=
lower
,
check_finite
=
check_finite
,
check_finite
=
check_finite
,
)
)
out_fg
=
FunctionGraph
([
x
,
b
],
[
out
])
compare_jax_and_py
(
compare_jax_and_py
(
out_fg
,
[
x
,
b
],
[
out
],
[
[
np
.
eye
(
10
)
.
astype
(
config
.
floatX
),
np
.
eye
(
10
)
.
astype
(
config
.
floatX
),
np
.
arange
(
10
)
.
astype
(
config
.
floatX
),
np
.
arange
(
10
)
.
astype
(
config
.
floatX
),
...
@@ -141,10 +134,10 @@ def test_jax_block_diag():
...
@@ -141,10 +134,10 @@ def test_jax_block_diag():
D
=
matrix
(
"D"
)
D
=
matrix
(
"D"
)
out
=
pt_slinalg
.
block_diag
(
A
,
B
,
C
,
D
)
out
=
pt_slinalg
.
block_diag
(
A
,
B
,
C
,
D
)
out_fg
=
FunctionGraph
([
A
,
B
,
C
,
D
],
[
out
])
compare_jax_and_py
(
compare_jax_and_py
(
out_fg
,
[
A
,
B
,
C
,
D
],
[
out
],
[
[
np
.
random
.
normal
(
size
=
(
5
,
5
))
.
astype
(
config
.
floatX
),
np
.
random
.
normal
(
size
=
(
5
,
5
))
.
astype
(
config
.
floatX
),
np
.
random
.
normal
(
size
=
(
3
,
3
))
.
astype
(
config
.
floatX
),
np
.
random
.
normal
(
size
=
(
3
,
3
))
.
astype
(
config
.
floatX
),
...
@@ -158,9 +151,10 @@ def test_jax_block_diag_blockwise():
...
@@ -158,9 +151,10 @@ def test_jax_block_diag_blockwise():
A
=
pt
.
tensor3
(
"A"
)
A
=
pt
.
tensor3
(
"A"
)
B
=
pt
.
tensor3
(
"B"
)
B
=
pt
.
tensor3
(
"B"
)
out
=
pt_slinalg
.
block_diag
(
A
,
B
)
out
=
pt_slinalg
.
block_diag
(
A
,
B
)
out_fg
=
FunctionGraph
([
A
,
B
],
[
out
])
compare_jax_and_py
(
compare_jax_and_py
(
out_fg
,
[
A
,
B
],
[
out
],
[
[
np
.
random
.
normal
(
size
=
(
5
,
5
,
5
))
.
astype
(
config
.
floatX
),
np
.
random
.
normal
(
size
=
(
5
,
5
,
5
))
.
astype
(
config
.
floatX
),
np
.
random
.
normal
(
size
=
(
5
,
3
,
3
))
.
astype
(
config
.
floatX
),
np
.
random
.
normal
(
size
=
(
5
,
3
,
3
))
.
astype
(
config
.
floatX
),
...
@@ -174,11 +168,11 @@ def test_jax_eigvalsh(lower):
...
@@ -174,11 +168,11 @@ def test_jax_eigvalsh(lower):
B
=
matrix
(
"B"
)
B
=
matrix
(
"B"
)
out
=
pt_slinalg
.
eigvalsh
(
A
,
B
,
lower
=
lower
)
out
=
pt_slinalg
.
eigvalsh
(
A
,
B
,
lower
=
lower
)
out_fg
=
FunctionGraph
([
A
,
B
],
[
out
])
with
pytest
.
raises
(
NotImplementedError
):
with
pytest
.
raises
(
NotImplementedError
):
compare_jax_and_py
(
compare_jax_and_py
(
out_fg
,
[
A
,
B
],
[
out
],
[
[
np
.
array
(
np
.
array
(
[[
6
,
3
,
1
,
5
],
[
3
,
0
,
5
,
1
],
[
1
,
5
,
6
,
2
],
[
5
,
1
,
2
,
2
]]
[[
6
,
3
,
1
,
5
],
[
3
,
0
,
5
,
1
],
[
1
,
5
,
6
,
2
],
[
5
,
1
,
2
,
2
]]
...
@@ -189,7 +183,8 @@ def test_jax_eigvalsh(lower):
...
@@ -189,7 +183,8 @@ def test_jax_eigvalsh(lower):
],
],
)
)
compare_jax_and_py
(
compare_jax_and_py
(
out_fg
,
[
A
,
B
],
[
out
],
[
[
np
.
array
([[
6
,
3
,
1
,
5
],
[
3
,
0
,
5
,
1
],
[
1
,
5
,
6
,
2
],
[
5
,
1
,
2
,
2
]])
.
astype
(
np
.
array
([[
6
,
3
,
1
,
5
],
[
3
,
0
,
5
,
1
],
[
1
,
5
,
6
,
2
],
[
5
,
1
,
2
,
2
]])
.
astype
(
config
.
floatX
config
.
floatX
...
@@ -207,11 +202,11 @@ def test_jax_solve_discrete_lyapunov(
...
@@ -207,11 +202,11 @@ def test_jax_solve_discrete_lyapunov(
A
=
pt
.
tensor
(
name
=
"A"
,
shape
=
shape
)
A
=
pt
.
tensor
(
name
=
"A"
,
shape
=
shape
)
B
=
pt
.
tensor
(
name
=
"B"
,
shape
=
shape
)
B
=
pt
.
tensor
(
name
=
"B"
,
shape
=
shape
)
out
=
pt_slinalg
.
solve_discrete_lyapunov
(
A
,
B
,
method
=
method
)
out
=
pt_slinalg
.
solve_discrete_lyapunov
(
A
,
B
,
method
=
method
)
out_fg
=
FunctionGraph
([
A
,
B
],
[
out
])
atol
=
rtol
=
1e-8
if
config
.
floatX
==
"float64"
else
1e-3
atol
=
rtol
=
1e-8
if
config
.
floatX
==
"float64"
else
1e-3
compare_jax_and_py
(
compare_jax_and_py
(
out_fg
,
[
A
,
B
],
[
out
],
[
[
np
.
random
.
normal
(
size
=
shape
)
.
astype
(
config
.
floatX
),
np
.
random
.
normal
(
size
=
shape
)
.
astype
(
config
.
floatX
),
np
.
random
.
normal
(
size
=
shape
)
.
astype
(
config
.
floatX
),
np
.
random
.
normal
(
size
=
shape
)
.
astype
(
config
.
floatX
),
...
...
tests/link/jax/test_sort.py
浏览文件 @
cc8c4992
import
numpy
as
np
import
numpy
as
np
import
pytest
import
pytest
from
pytensor.graph
import
FunctionGraph
from
pytensor.tensor
import
matrix
from
pytensor.tensor
import
matrix
from
pytensor.tensor.sort
import
argsort
,
sort
from
pytensor.tensor.sort
import
argsort
,
sort
from
tests.link.jax.test_basic
import
compare_jax_and_py
from
tests.link.jax.test_basic
import
compare_jax_and_py
...
@@ -12,6 +11,5 @@ from tests.link.jax.test_basic import compare_jax_and_py
...
@@ -12,6 +11,5 @@ from tests.link.jax.test_basic import compare_jax_and_py
def
test_sort
(
func
,
axis
):
def
test_sort
(
func
,
axis
):
x
=
matrix
(
"x"
,
shape
=
(
2
,
2
),
dtype
=
"float64"
)
x
=
matrix
(
"x"
,
shape
=
(
2
,
2
),
dtype
=
"float64"
)
out
=
func
(
x
,
axis
=
axis
)
out
=
func
(
x
,
axis
=
axis
)
fgraph
=
FunctionGraph
([
x
],
[
out
])
arr
=
np
.
array
([[
1.0
,
4.0
],
[
5.0
,
2.0
]])
arr
=
np
.
array
([[
1.0
,
4.0
],
[
5.0
,
2.0
]])
compare_jax_and_py
(
fgraph
,
[
arr
])
compare_jax_and_py
(
[
x
],
[
out
]
,
[
arr
])
tests/link/jax/test_sparse.py
浏览文件 @
cc8c4992
...
@@ -5,7 +5,6 @@ import scipy.sparse
...
@@ -5,7 +5,6 @@ import scipy.sparse
import
pytensor.sparse
as
ps
import
pytensor.sparse
as
ps
import
pytensor.tensor
as
pt
import
pytensor.tensor
as
pt
from
pytensor
import
function
from
pytensor
import
function
from
pytensor.graph
import
FunctionGraph
from
tests.link.jax.test_basic
import
compare_jax_and_py
from
tests.link.jax.test_basic
import
compare_jax_and_py
...
@@ -50,8 +49,7 @@ def test_sparse_dot_constant_sparse(x_type, y_type, op):
...
@@ -50,8 +49,7 @@ def test_sparse_dot_constant_sparse(x_type, y_type, op):
test_values
.
append
(
y_test
)
test_values
.
append
(
y_test
)
dot_pt
=
op
(
x_pt
,
y_pt
)
dot_pt
=
op
(
x_pt
,
y_pt
)
fgraph
=
FunctionGraph
(
inputs
,
[
dot_pt
])
compare_jax_and_py
(
inputs
,
[
dot_pt
],
test_values
,
jax_mode
=
"JAX"
)
compare_jax_and_py
(
fgraph
,
test_values
,
jax_mode
=
"JAX"
)
def
test_sparse_dot_non_const_raises
():
def
test_sparse_dot_non_const_raises
():
...
...
tests/link/jax/test_subtensor.py
浏览文件 @
cc8c4992
...
@@ -21,55 +21,55 @@ def test_jax_Subtensor_constant():
...
@@ -21,55 +21,55 @@ def test_jax_Subtensor_constant():
# Basic indices
# Basic indices
out_pt
=
x_pt
[
1
,
2
,
0
]
out_pt
=
x_pt
[
1
,
2
,
0
]
assert
isinstance
(
out_pt
.
owner
.
op
,
pt_subtensor
.
Subtensor
)
assert
isinstance
(
out_pt
.
owner
.
op
,
pt_subtensor
.
Subtensor
)
out_fg
=
FunctionGraph
([
x_pt
],
[
out_pt
])
compare_jax_and_py
(
out_fg
,
[
x_np
])
compare_jax_and_py
(
[
x_pt
],
[
out_pt
]
,
[
x_np
])
out_pt
=
x_pt
[
1
:,
1
,
:]
out_pt
=
x_pt
[
1
:,
1
,
:]
assert
isinstance
(
out_pt
.
owner
.
op
,
pt_subtensor
.
Subtensor
)
assert
isinstance
(
out_pt
.
owner
.
op
,
pt_subtensor
.
Subtensor
)
out_fg
=
FunctionGraph
([
x_pt
],
[
out_pt
])
compare_jax_and_py
(
out_fg
,
[
x_np
])
compare_jax_and_py
(
[
x_pt
],
[
out_pt
]
,
[
x_np
])
out_pt
=
x_pt
[:
2
,
1
,
:]
out_pt
=
x_pt
[:
2
,
1
,
:]
assert
isinstance
(
out_pt
.
owner
.
op
,
pt_subtensor
.
Subtensor
)
assert
isinstance
(
out_pt
.
owner
.
op
,
pt_subtensor
.
Subtensor
)
out_fg
=
FunctionGraph
([
x_pt
],
[
out_pt
])
compare_jax_and_py
(
out_fg
,
[
x_np
])
compare_jax_and_py
(
[
x_pt
],
[
out_pt
]
,
[
x_np
])
out_pt
=
x_pt
[
1
:
2
,
1
,
:]
out_pt
=
x_pt
[
1
:
2
,
1
,
:]
assert
isinstance
(
out_pt
.
owner
.
op
,
pt_subtensor
.
Subtensor
)
assert
isinstance
(
out_pt
.
owner
.
op
,
pt_subtensor
.
Subtensor
)
out_fg
=
FunctionGraph
([
x_pt
],
[
out_pt
])
compare_jax_and_py
(
out_fg
,
[
x_np
])
compare_jax_and_py
(
[
x_pt
],
[
out_pt
]
,
[
x_np
])
# Advanced indexing
# Advanced indexing
out_pt
=
pt_subtensor
.
advanced_subtensor1
(
x_pt
,
[
1
,
2
])
out_pt
=
pt_subtensor
.
advanced_subtensor1
(
x_pt
,
[
1
,
2
])
assert
isinstance
(
out_pt
.
owner
.
op
,
pt_subtensor
.
AdvancedSubtensor1
)
assert
isinstance
(
out_pt
.
owner
.
op
,
pt_subtensor
.
AdvancedSubtensor1
)
out_fg
=
FunctionGraph
([
x_pt
],
[
out_pt
])
compare_jax_and_py
(
out_fg
,
[
x_np
])
compare_jax_and_py
(
[
x_pt
],
[
out_pt
]
,
[
x_np
])
out_pt
=
x_pt
[[
1
,
2
],
[
2
,
3
]]
out_pt
=
x_pt
[[
1
,
2
],
[
2
,
3
]]
assert
isinstance
(
out_pt
.
owner
.
op
,
pt_subtensor
.
AdvancedSubtensor
)
assert
isinstance
(
out_pt
.
owner
.
op
,
pt_subtensor
.
AdvancedSubtensor
)
out_fg
=
FunctionGraph
([
x_pt
],
[
out_pt
])
compare_jax_and_py
(
out_fg
,
[
x_np
])
compare_jax_and_py
(
[
x_pt
],
[
out_pt
]
,
[
x_np
])
# Advanced and basic indexing
# Advanced and basic indexing
out_pt
=
x_pt
[[
1
,
2
],
:]
out_pt
=
x_pt
[[
1
,
2
],
:]
assert
isinstance
(
out_pt
.
owner
.
op
,
pt_subtensor
.
AdvancedSubtensor
)
assert
isinstance
(
out_pt
.
owner
.
op
,
pt_subtensor
.
AdvancedSubtensor
)
out_fg
=
FunctionGraph
([
x_pt
],
[
out_pt
])
compare_jax_and_py
(
out_fg
,
[
x_np
])
compare_jax_and_py
(
[
x_pt
],
[
out_pt
]
,
[
x_np
])
out_pt
=
x_pt
[[
1
,
2
],
:,
[
3
,
4
]]
out_pt
=
x_pt
[[
1
,
2
],
:,
[
3
,
4
]]
assert
isinstance
(
out_pt
.
owner
.
op
,
pt_subtensor
.
AdvancedSubtensor
)
assert
isinstance
(
out_pt
.
owner
.
op
,
pt_subtensor
.
AdvancedSubtensor
)
out_fg
=
FunctionGraph
([
x_pt
],
[
out_pt
])
compare_jax_and_py
(
out_fg
,
[
x_np
])
compare_jax_and_py
(
[
x_pt
],
[
out_pt
]
,
[
x_np
])
# Flipping
# Flipping
out_pt
=
x_pt
[::
-
1
]
out_pt
=
x_pt
[::
-
1
]
out_fg
=
FunctionGraph
([
x_pt
],
[
out_pt
])
compare_jax_and_py
(
out_fg
,
[
x_np
])
compare_jax_and_py
(
[
x_pt
],
[
out_pt
]
,
[
x_np
])
# Boolean indexing should work if indexes are constant
# Boolean indexing should work if indexes are constant
out_pt
=
x_pt
[
np
.
random
.
binomial
(
1
,
0.5
,
size
=
(
3
,
4
,
5
))
.
astype
(
bool
)]
out_pt
=
x_pt
[
np
.
random
.
binomial
(
1
,
0.5
,
size
=
(
3
,
4
,
5
))
.
astype
(
bool
)]
out_fg
=
FunctionGraph
([
x_pt
],
[
out_pt
])
compare_jax_and_py
(
out_fg
,
[
x_np
])
compare_jax_and_py
(
[
x_pt
],
[
out_pt
]
,
[
x_np
])
@pytest.mark.xfail
(
reason
=
"`a` should be specified as static when JIT-compiling"
)
@pytest.mark.xfail
(
reason
=
"`a` should be specified as static when JIT-compiling"
)
...
@@ -78,8 +78,8 @@ def test_jax_Subtensor_dynamic():
...
@@ -78,8 +78,8 @@ def test_jax_Subtensor_dynamic():
x
=
pt
.
arange
(
3
)
x
=
pt
.
arange
(
3
)
out_pt
=
x
[:
a
]
out_pt
=
x
[:
a
]
assert
isinstance
(
out_pt
.
owner
.
op
,
pt_subtensor
.
Subtensor
)
assert
isinstance
(
out_pt
.
owner
.
op
,
pt_subtensor
.
Subtensor
)
out_fg
=
FunctionGraph
([
a
],
[
out_pt
])
compare_jax_and_py
(
out_fg
,
[
1
])
compare_jax_and_py
(
[
a
],
[
out_pt
]
,
[
1
])
def
test_jax_Subtensor_dynamic_boolean_mask
():
def
test_jax_Subtensor_dynamic_boolean_mask
():
...
@@ -90,11 +90,9 @@ def test_jax_Subtensor_dynamic_boolean_mask():
...
@@ -90,11 +90,9 @@ def test_jax_Subtensor_dynamic_boolean_mask():
out_pt
=
x_pt
[
x_pt
<
0
]
out_pt
=
x_pt
[
x_pt
<
0
]
assert
isinstance
(
out_pt
.
owner
.
op
,
pt_subtensor
.
AdvancedSubtensor
)
assert
isinstance
(
out_pt
.
owner
.
op
,
pt_subtensor
.
AdvancedSubtensor
)
out_fg
=
FunctionGraph
([
x_pt
],
[
out_pt
])
x_pt_test
=
np
.
arange
(
-
5
,
5
)
x_pt_test
=
np
.
arange
(
-
5
,
5
)
with
pytest
.
raises
(
NonConcreteBooleanIndexError
):
with
pytest
.
raises
(
NonConcreteBooleanIndexError
):
compare_jax_and_py
(
out_fg
,
[
x_pt_test
])
compare_jax_and_py
(
[
x_pt
],
[
out_pt
]
,
[
x_pt_test
])
def
test_jax_Subtensor_boolean_mask_reexpressible
():
def
test_jax_Subtensor_boolean_mask_reexpressible
():
...
@@ -110,8 +108,10 @@ def test_jax_Subtensor_boolean_mask_reexpressible():
...
@@ -110,8 +108,10 @@ def test_jax_Subtensor_boolean_mask_reexpressible():
"""
"""
x_pt
=
pt
.
matrix
(
"x"
)
x_pt
=
pt
.
matrix
(
"x"
)
out_pt
=
x_pt
[
x_pt
<
0
]
.
sum
()
out_pt
=
x_pt
[
x_pt
<
0
]
.
sum
()
out_fg
=
FunctionGraph
([
x_pt
],
[
out_pt
])
compare_jax_and_py
(
out_fg
,
[
np
.
arange
(
25
)
.
reshape
(
5
,
5
)
.
astype
(
config
.
floatX
)])
compare_jax_and_py
(
[
x_pt
],
[
out_pt
],
[
np
.
arange
(
25
)
.
reshape
(
5
,
5
)
.
astype
(
config
.
floatX
)]
)
def
test_boolean_indexing_sum_not_applicable
():
def
test_boolean_indexing_sum_not_applicable
():
...
@@ -136,19 +136,19 @@ def test_jax_IncSubtensor():
...
@@ -136,19 +136,19 @@ def test_jax_IncSubtensor():
st_pt
=
pt
.
as_tensor_variable
(
np
.
array
(
-
10.0
,
dtype
=
config
.
floatX
))
st_pt
=
pt
.
as_tensor_variable
(
np
.
array
(
-
10.0
,
dtype
=
config
.
floatX
))
out_pt
=
pt_subtensor
.
set_subtensor
(
x_pt
[
1
,
2
,
3
],
st_pt
)
out_pt
=
pt_subtensor
.
set_subtensor
(
x_pt
[
1
,
2
,
3
],
st_pt
)
assert
isinstance
(
out_pt
.
owner
.
op
,
pt_subtensor
.
IncSubtensor
)
assert
isinstance
(
out_pt
.
owner
.
op
,
pt_subtensor
.
IncSubtensor
)
out_fg
=
FunctionGraph
([],
[
out_pt
])
compare_jax_and_py
(
out_fg
,
[])
compare_jax_and_py
(
[],
[
out_pt
]
,
[])
st_pt
=
pt
.
as_tensor_variable
(
np
.
r_
[
-
1.0
,
0.0
]
.
astype
(
config
.
floatX
))
st_pt
=
pt
.
as_tensor_variable
(
np
.
r_
[
-
1.0
,
0.0
]
.
astype
(
config
.
floatX
))
out_pt
=
pt_subtensor
.
set_subtensor
(
x_pt
[:
2
,
0
,
0
],
st_pt
)
out_pt
=
pt_subtensor
.
set_subtensor
(
x_pt
[:
2
,
0
,
0
],
st_pt
)
assert
isinstance
(
out_pt
.
owner
.
op
,
pt_subtensor
.
IncSubtensor
)
assert
isinstance
(
out_pt
.
owner
.
op
,
pt_subtensor
.
IncSubtensor
)
out_fg
=
FunctionGraph
([],
[
out_pt
])
compare_jax_and_py
(
out_fg
,
[])
compare_jax_and_py
(
[],
[
out_pt
]
,
[])
out_pt
=
pt_subtensor
.
set_subtensor
(
x_pt
[
0
,
1
:
3
,
0
],
st_pt
)
out_pt
=
pt_subtensor
.
set_subtensor
(
x_pt
[
0
,
1
:
3
,
0
],
st_pt
)
assert
isinstance
(
out_pt
.
owner
.
op
,
pt_subtensor
.
IncSubtensor
)
assert
isinstance
(
out_pt
.
owner
.
op
,
pt_subtensor
.
IncSubtensor
)
out_fg
=
FunctionGraph
([],
[
out_pt
])
compare_jax_and_py
(
out_fg
,
[])
compare_jax_and_py
(
[],
[
out_pt
]
,
[])
# "Set" advanced indices
# "Set" advanced indices
st_pt
=
pt
.
as_tensor_variable
(
st_pt
=
pt
.
as_tensor_variable
(
...
@@ -156,39 +156,39 @@ def test_jax_IncSubtensor():
...
@@ -156,39 +156,39 @@ def test_jax_IncSubtensor():
)
)
out_pt
=
pt_subtensor
.
set_subtensor
(
x_pt
[
np
.
r_
[
0
,
2
]],
st_pt
)
out_pt
=
pt_subtensor
.
set_subtensor
(
x_pt
[
np
.
r_
[
0
,
2
]],
st_pt
)
assert
isinstance
(
out_pt
.
owner
.
op
,
pt_subtensor
.
AdvancedIncSubtensor
)
assert
isinstance
(
out_pt
.
owner
.
op
,
pt_subtensor
.
AdvancedIncSubtensor
)
out_fg
=
FunctionGraph
([],
[
out_pt
])
compare_jax_and_py
(
out_fg
,
[])
compare_jax_and_py
(
[],
[
out_pt
]
,
[])
st_pt
=
pt
.
as_tensor_variable
(
np
.
r_
[
-
1.0
,
0.0
]
.
astype
(
config
.
floatX
))
st_pt
=
pt
.
as_tensor_variable
(
np
.
r_
[
-
1.0
,
0.0
]
.
astype
(
config
.
floatX
))
out_pt
=
pt_subtensor
.
set_subtensor
(
x_pt
[[
0
,
2
],
0
,
0
],
st_pt
)
out_pt
=
pt_subtensor
.
set_subtensor
(
x_pt
[[
0
,
2
],
0
,
0
],
st_pt
)
assert
isinstance
(
out_pt
.
owner
.
op
,
pt_subtensor
.
AdvancedIncSubtensor
)
assert
isinstance
(
out_pt
.
owner
.
op
,
pt_subtensor
.
AdvancedIncSubtensor
)
out_fg
=
FunctionGraph
([],
[
out_pt
])
compare_jax_and_py
(
out_fg
,
[])
compare_jax_and_py
(
[],
[
out_pt
]
,
[])
# "Set" boolean indices
# "Set" boolean indices
mask_pt
=
pt
.
constant
(
x_np
>
0
)
mask_pt
=
pt
.
constant
(
x_np
>
0
)
out_pt
=
pt_subtensor
.
set_subtensor
(
x_pt
[
mask_pt
],
0.0
)
out_pt
=
pt_subtensor
.
set_subtensor
(
x_pt
[
mask_pt
],
0.0
)
assert
isinstance
(
out_pt
.
owner
.
op
,
pt_subtensor
.
AdvancedIncSubtensor
)
assert
isinstance
(
out_pt
.
owner
.
op
,
pt_subtensor
.
AdvancedIncSubtensor
)
out_fg
=
FunctionGraph
([],
[
out_pt
])
compare_jax_and_py
(
out_fg
,
[])
compare_jax_and_py
(
[],
[
out_pt
]
,
[])
# "Increment" basic indices
# "Increment" basic indices
st_pt
=
pt
.
as_tensor_variable
(
np
.
array
(
-
10.0
,
dtype
=
config
.
floatX
))
st_pt
=
pt
.
as_tensor_variable
(
np
.
array
(
-
10.0
,
dtype
=
config
.
floatX
))
out_pt
=
pt_subtensor
.
inc_subtensor
(
x_pt
[
1
,
2
,
3
],
st_pt
)
out_pt
=
pt_subtensor
.
inc_subtensor
(
x_pt
[
1
,
2
,
3
],
st_pt
)
assert
isinstance
(
out_pt
.
owner
.
op
,
pt_subtensor
.
IncSubtensor
)
assert
isinstance
(
out_pt
.
owner
.
op
,
pt_subtensor
.
IncSubtensor
)
out_fg
=
FunctionGraph
([],
[
out_pt
])
compare_jax_and_py
(
out_fg
,
[])
compare_jax_and_py
(
[],
[
out_pt
]
,
[])
st_pt
=
pt
.
as_tensor_variable
(
np
.
r_
[
-
1.0
,
0.0
]
.
astype
(
config
.
floatX
))
st_pt
=
pt
.
as_tensor_variable
(
np
.
r_
[
-
1.0
,
0.0
]
.
astype
(
config
.
floatX
))
out_pt
=
pt_subtensor
.
inc_subtensor
(
x_pt
[:
2
,
0
,
0
],
st_pt
)
out_pt
=
pt_subtensor
.
inc_subtensor
(
x_pt
[:
2
,
0
,
0
],
st_pt
)
assert
isinstance
(
out_pt
.
owner
.
op
,
pt_subtensor
.
IncSubtensor
)
assert
isinstance
(
out_pt
.
owner
.
op
,
pt_subtensor
.
IncSubtensor
)
out_fg
=
FunctionGraph
([],
[
out_pt
])
compare_jax_and_py
(
out_fg
,
[])
compare_jax_and_py
(
[],
[
out_pt
]
,
[])
out_pt
=
pt_subtensor
.
set_subtensor
(
x_pt
[
0
,
1
:
3
,
0
],
st_pt
)
out_pt
=
pt_subtensor
.
set_subtensor
(
x_pt
[
0
,
1
:
3
,
0
],
st_pt
)
assert
isinstance
(
out_pt
.
owner
.
op
,
pt_subtensor
.
IncSubtensor
)
assert
isinstance
(
out_pt
.
owner
.
op
,
pt_subtensor
.
IncSubtensor
)
out_fg
=
FunctionGraph
([],
[
out_pt
])
compare_jax_and_py
(
out_fg
,
[])
compare_jax_and_py
(
[],
[
out_pt
]
,
[])
# "Increment" advanced indices
# "Increment" advanced indices
st_pt
=
pt
.
as_tensor_variable
(
st_pt
=
pt
.
as_tensor_variable
(
...
@@ -196,33 +196,33 @@ def test_jax_IncSubtensor():
...
@@ -196,33 +196,33 @@ def test_jax_IncSubtensor():
)
)
out_pt
=
pt_subtensor
.
inc_subtensor
(
x_pt
[
np
.
r_
[
0
,
2
]],
st_pt
)
out_pt
=
pt_subtensor
.
inc_subtensor
(
x_pt
[
np
.
r_
[
0
,
2
]],
st_pt
)
assert
isinstance
(
out_pt
.
owner
.
op
,
pt_subtensor
.
AdvancedIncSubtensor
)
assert
isinstance
(
out_pt
.
owner
.
op
,
pt_subtensor
.
AdvancedIncSubtensor
)
out_fg
=
FunctionGraph
([],
[
out_pt
])
compare_jax_and_py
(
out_fg
,
[])
compare_jax_and_py
(
[],
[
out_pt
]
,
[])
st_pt
=
pt
.
as_tensor_variable
(
np
.
r_
[
-
1.0
,
0.0
]
.
astype
(
config
.
floatX
))
st_pt
=
pt
.
as_tensor_variable
(
np
.
r_
[
-
1.0
,
0.0
]
.
astype
(
config
.
floatX
))
out_pt
=
pt_subtensor
.
inc_subtensor
(
x_pt
[[
0
,
2
],
0
,
0
],
st_pt
)
out_pt
=
pt_subtensor
.
inc_subtensor
(
x_pt
[[
0
,
2
],
0
,
0
],
st_pt
)
assert
isinstance
(
out_pt
.
owner
.
op
,
pt_subtensor
.
AdvancedIncSubtensor
)
assert
isinstance
(
out_pt
.
owner
.
op
,
pt_subtensor
.
AdvancedIncSubtensor
)
out_fg
=
FunctionGraph
([],
[
out_pt
])
compare_jax_and_py
(
out_fg
,
[])
compare_jax_and_py
(
[],
[
out_pt
]
,
[])
# "Increment" boolean indices
# "Increment" boolean indices
mask_pt
=
pt
.
constant
(
x_np
>
0
)
mask_pt
=
pt
.
constant
(
x_np
>
0
)
out_pt
=
pt_subtensor
.
set_subtensor
(
x_pt
[
mask_pt
],
1.0
)
out_pt
=
pt_subtensor
.
set_subtensor
(
x_pt
[
mask_pt
],
1.0
)
assert
isinstance
(
out_pt
.
owner
.
op
,
pt_subtensor
.
AdvancedIncSubtensor
)
assert
isinstance
(
out_pt
.
owner
.
op
,
pt_subtensor
.
AdvancedIncSubtensor
)
out_fg
=
FunctionGraph
([],
[
out_pt
])
compare_jax_and_py
(
out_fg
,
[])
compare_jax_and_py
(
[],
[
out_pt
]
,
[])
st_pt
=
pt
.
as_tensor_variable
(
x_np
[[
0
,
2
],
0
,
:
3
])
st_pt
=
pt
.
as_tensor_variable
(
x_np
[[
0
,
2
],
0
,
:
3
])
out_pt
=
pt_subtensor
.
set_subtensor
(
x_pt
[[
0
,
2
],
0
,
:
3
],
st_pt
)
out_pt
=
pt_subtensor
.
set_subtensor
(
x_pt
[[
0
,
2
],
0
,
:
3
],
st_pt
)
assert
isinstance
(
out_pt
.
owner
.
op
,
pt_subtensor
.
AdvancedIncSubtensor
)
assert
isinstance
(
out_pt
.
owner
.
op
,
pt_subtensor
.
AdvancedIncSubtensor
)
out_fg
=
FunctionGraph
([],
[
out_pt
])
compare_jax_and_py
(
out_fg
,
[])
compare_jax_and_py
(
[],
[
out_pt
]
,
[])
st_pt
=
pt
.
as_tensor_variable
(
x_np
[[
0
,
2
],
0
,
:
3
])
st_pt
=
pt
.
as_tensor_variable
(
x_np
[[
0
,
2
],
0
,
:
3
])
out_pt
=
pt_subtensor
.
inc_subtensor
(
x_pt
[[
0
,
2
],
0
,
:
3
],
st_pt
)
out_pt
=
pt_subtensor
.
inc_subtensor
(
x_pt
[[
0
,
2
],
0
,
:
3
],
st_pt
)
assert
isinstance
(
out_pt
.
owner
.
op
,
pt_subtensor
.
AdvancedIncSubtensor
)
assert
isinstance
(
out_pt
.
owner
.
op
,
pt_subtensor
.
AdvancedIncSubtensor
)
out_fg
=
FunctionGraph
([],
[
out_pt
])
compare_jax_and_py
(
out_fg
,
[])
compare_jax_and_py
(
[],
[
out_pt
]
,
[])
def
test_jax_IncSubtensor_boolean_indexing_reexpressible
():
def
test_jax_IncSubtensor_boolean_indexing_reexpressible
():
...
@@ -243,14 +243,14 @@ def test_jax_IncSubtensor_boolean_indexing_reexpressible():
...
@@ -243,14 +243,14 @@ def test_jax_IncSubtensor_boolean_indexing_reexpressible():
mask_pt
=
pt
.
as_tensor
(
x_pt
)
>
0
mask_pt
=
pt
.
as_tensor
(
x_pt
)
>
0
out_pt
=
pt_subtensor
.
set_subtensor
(
x_pt
[
mask_pt
],
0.0
)
out_pt
=
pt_subtensor
.
set_subtensor
(
x_pt
[
mask_pt
],
0.0
)
assert
isinstance
(
out_pt
.
owner
.
op
,
pt_subtensor
.
AdvancedIncSubtensor
)
assert
isinstance
(
out_pt
.
owner
.
op
,
pt_subtensor
.
AdvancedIncSubtensor
)
out_fg
=
FunctionGraph
([
x_pt
],
[
out_pt
])
compare_jax_and_py
(
out_fg
,
[
x_np
])
compare_jax_and_py
(
[
x_pt
],
[
out_pt
]
,
[
x_np
])
mask_pt
=
pt
.
as_tensor
(
x_pt
)
>
0
mask_pt
=
pt
.
as_tensor
(
x_pt
)
>
0
out_pt
=
pt_subtensor
.
inc_subtensor
(
x_pt
[
mask_pt
],
1.0
)
out_pt
=
pt_subtensor
.
inc_subtensor
(
x_pt
[
mask_pt
],
1.0
)
assert
isinstance
(
out_pt
.
owner
.
op
,
pt_subtensor
.
AdvancedIncSubtensor
)
assert
isinstance
(
out_pt
.
owner
.
op
,
pt_subtensor
.
AdvancedIncSubtensor
)
out_fg
=
FunctionGraph
([
x_pt
],
[
out_pt
])
compare_jax_and_py
(
out_fg
,
[
x_np
])
compare_jax_and_py
(
[
x_pt
],
[
out_pt
]
,
[
x_np
])
def
test_boolean_indexing_set_or_inc_not_applicable
():
def
test_boolean_indexing_set_or_inc_not_applicable
():
...
...
tests/link/jax/test_tensor_basic.py
浏览文件 @
cc8c4992
...
@@ -10,8 +10,6 @@ from jax import errors
...
@@ -10,8 +10,6 @@ from jax import errors
import
pytensor
import
pytensor
import
pytensor.tensor.basic
as
ptb
import
pytensor.tensor.basic
as
ptb
from
pytensor.configdefaults
import
config
from
pytensor.configdefaults
import
config
from
pytensor.graph.fg
import
FunctionGraph
from
pytensor.graph.op
import
get_test_value
from
pytensor.tensor.type
import
iscalar
,
matrix
,
scalar
,
vector
from
pytensor.tensor.type
import
iscalar
,
matrix
,
scalar
,
vector
from
tests.link.jax.test_basic
import
compare_jax_and_py
from
tests.link.jax.test_basic
import
compare_jax_and_py
from
tests.tensor.test_basic
import
check_alloc_runtime_broadcast
from
tests.tensor.test_basic
import
check_alloc_runtime_broadcast
...
@@ -19,38 +17,31 @@ from tests.tensor.test_basic import check_alloc_runtime_broadcast
...
@@ -19,38 +17,31 @@ from tests.tensor.test_basic import check_alloc_runtime_broadcast
def
test_jax_Alloc
():
def
test_jax_Alloc
():
x
=
ptb
.
alloc
(
0.0
,
2
,
3
)
x
=
ptb
.
alloc
(
0.0
,
2
,
3
)
x_fg
=
FunctionGraph
([],
[
x
])
_
,
[
jax_res
]
=
compare_jax_and_py
(
x_fg
,
[])
_
,
[
jax_res
]
=
compare_jax_and_py
(
[],
[
x
]
,
[])
assert
jax_res
.
shape
==
(
2
,
3
)
assert
jax_res
.
shape
==
(
2
,
3
)
x
=
ptb
.
alloc
(
1.1
,
2
,
3
)
x
=
ptb
.
alloc
(
1.1
,
2
,
3
)
x_fg
=
FunctionGraph
([],
[
x
])
compare_jax_and_py
(
x_fg
,
[])
compare_jax_and_py
(
[],
[
x
]
,
[])
x
=
ptb
.
AllocEmpty
(
"float32"
)(
2
,
3
)
x
=
ptb
.
AllocEmpty
(
"float32"
)(
2
,
3
)
x_fg
=
FunctionGraph
([],
[
x
])
def
compare_shape_dtype
(
x
,
y
):
def
compare_shape_dtype
(
x
,
y
):
(
x
,)
=
x
np
.
testing
.
assert_array_equal
(
x
,
y
,
strict
=
True
)
(
y
,)
=
y
return
x
.
shape
==
y
.
shape
and
x
.
dtype
==
y
.
dtype
compare_jax_and_py
(
x_fg
,
[],
assert_fn
=
compare_shape_dtype
)
compare_jax_and_py
(
[],
[
x
]
,
[],
assert_fn
=
compare_shape_dtype
)
a
=
scalar
(
"a"
)
a
=
scalar
(
"a"
)
x
=
ptb
.
alloc
(
a
,
20
)
x
=
ptb
.
alloc
(
a
,
20
)
x_fg
=
FunctionGraph
([
a
],
[
x
])
compare_jax_and_py
(
x_fg
,
[
10.0
])
compare_jax_and_py
(
[
a
],
[
x
]
,
[
10.0
])
a
=
vector
(
"a"
)
a
=
vector
(
"a"
)
x
=
ptb
.
alloc
(
a
,
20
,
10
)
x
=
ptb
.
alloc
(
a
,
20
,
10
)
x_fg
=
FunctionGraph
([
a
],
[
x
])
compare_jax_and_py
(
x_fg
,
[
np
.
ones
(
10
,
dtype
=
config
.
floatX
)])
compare_jax_and_py
(
[
a
],
[
x
]
,
[
np
.
ones
(
10
,
dtype
=
config
.
floatX
)])
def
test_alloc_runtime_broadcast
():
def
test_alloc_runtime_broadcast
():
...
@@ -59,34 +50,31 @@ def test_alloc_runtime_broadcast():
...
@@ -59,34 +50,31 @@ def test_alloc_runtime_broadcast():
def
test_jax_MakeVector
():
def
test_jax_MakeVector
():
x
=
ptb
.
make_vector
(
1
,
2
,
3
)
x
=
ptb
.
make_vector
(
1
,
2
,
3
)
x_fg
=
FunctionGraph
([],
[
x
])
compare_jax_and_py
(
x_fg
,
[])
compare_jax_and_py
(
[],
[
x
]
,
[])
def
test_arange
():
def
test_arange
():
out
=
ptb
.
arange
(
1
,
10
,
2
)
out
=
ptb
.
arange
(
1
,
10
,
2
)
fgraph
=
FunctionGraph
([],
[
out
])
compare_jax_and_py
(
fgraph
,
[])
compare_jax_and_py
(
[],
[
out
]
,
[])
def
test_arange_of_shape
():
def
test_arange_of_shape
():
x
=
vector
(
"x"
)
x
=
vector
(
"x"
)
out
=
ptb
.
arange
(
1
,
x
.
shape
[
-
1
],
2
)
out
=
ptb
.
arange
(
1
,
x
.
shape
[
-
1
],
2
)
fgraph
=
FunctionGraph
([
x
],
[
out
])
compare_jax_and_py
([
x
],
[
out
],
[
np
.
zeros
((
5
,))],
jax_mode
=
"JAX"
)
compare_jax_and_py
(
fgraph
,
[
np
.
zeros
((
5
,))],
jax_mode
=
"JAX"
)
def
test_arange_nonconcrete
():
def
test_arange_nonconcrete
():
"""JAX cannot JIT-compile `jax.numpy.arange` when arguments are not concrete values."""
"""JAX cannot JIT-compile `jax.numpy.arange` when arguments are not concrete values."""
a
=
scalar
(
"a"
)
a
=
scalar
(
"a"
)
a
.
tag
.
test_value
=
10
a
_
test_value
=
10
out
=
ptb
.
arange
(
a
)
out
=
ptb
.
arange
(
a
)
with
pytest
.
raises
(
NotImplementedError
):
with
pytest
.
raises
(
NotImplementedError
):
fgraph
=
FunctionGraph
([
a
],
[
out
])
compare_jax_and_py
([
a
],
[
out
],
[
a_test_value
])
compare_jax_and_py
(
fgraph
,
[
get_test_value
(
i
)
for
i
in
fgraph
.
inputs
])
def
test_jax_Join
():
def
test_jax_Join
():
...
@@ -94,16 +82,17 @@ def test_jax_Join():
...
@@ -94,16 +82,17 @@ def test_jax_Join():
b
=
matrix
(
"b"
)
b
=
matrix
(
"b"
)
x
=
ptb
.
join
(
0
,
a
,
b
)
x
=
ptb
.
join
(
0
,
a
,
b
)
x_fg
=
FunctionGraph
([
a
,
b
],
[
x
])
compare_jax_and_py
(
compare_jax_and_py
(
x_fg
,
[
a
,
b
],
[
x
],
[
[
np
.
c_
[[
1.0
,
2.0
,
3.0
]]
.
astype
(
config
.
floatX
),
np
.
c_
[[
1.0
,
2.0
,
3.0
]]
.
astype
(
config
.
floatX
),
np
.
c_
[[
4.0
,
5.0
,
6.0
]]
.
astype
(
config
.
floatX
),
np
.
c_
[[
4.0
,
5.0
,
6.0
]]
.
astype
(
config
.
floatX
),
],
],
)
)
compare_jax_and_py
(
compare_jax_and_py
(
x_fg
,
[
a
,
b
],
[
x
],
[
[
np
.
c_
[[
1.0
,
2.0
,
3.0
]]
.
astype
(
config
.
floatX
),
np
.
c_
[[
1.0
,
2.0
,
3.0
]]
.
astype
(
config
.
floatX
),
np
.
c_
[[
4.0
,
5.0
]]
.
astype
(
config
.
floatX
),
np
.
c_
[[
4.0
,
5.0
]]
.
astype
(
config
.
floatX
),
...
@@ -111,16 +100,17 @@ def test_jax_Join():
...
@@ -111,16 +100,17 @@ def test_jax_Join():
)
)
x
=
ptb
.
join
(
1
,
a
,
b
)
x
=
ptb
.
join
(
1
,
a
,
b
)
x_fg
=
FunctionGraph
([
a
,
b
],
[
x
])
compare_jax_and_py
(
compare_jax_and_py
(
x_fg
,
[
a
,
b
],
[
x
],
[
[
np
.
c_
[[
1.0
,
2.0
,
3.0
]]
.
astype
(
config
.
floatX
),
np
.
c_
[[
1.0
,
2.0
,
3.0
]]
.
astype
(
config
.
floatX
),
np
.
c_
[[
4.0
,
5.0
,
6.0
]]
.
astype
(
config
.
floatX
),
np
.
c_
[[
4.0
,
5.0
,
6.0
]]
.
astype
(
config
.
floatX
),
],
],
)
)
compare_jax_and_py
(
compare_jax_and_py
(
x_fg
,
[
a
,
b
],
[
x
],
[
[
np
.
c_
[[
1.0
,
2.0
],
[
3.0
,
4.0
]]
.
astype
(
config
.
floatX
),
np
.
c_
[[
1.0
,
2.0
],
[
3.0
,
4.0
]]
.
astype
(
config
.
floatX
),
np
.
c_
[[
5.0
,
6.0
]]
.
astype
(
config
.
floatX
),
np
.
c_
[[
5.0
,
6.0
]]
.
astype
(
config
.
floatX
),
...
@@ -132,9 +122,9 @@ class TestJaxSplit:
...
@@ -132,9 +122,9 @@ class TestJaxSplit:
def
test_basic
(
self
):
def
test_basic
(
self
):
a
=
matrix
(
"a"
)
a
=
matrix
(
"a"
)
a_splits
=
ptb
.
split
(
a
,
splits_size
=
[
1
,
2
,
3
],
n_splits
=
3
,
axis
=
0
)
a_splits
=
ptb
.
split
(
a
,
splits_size
=
[
1
,
2
,
3
],
n_splits
=
3
,
axis
=
0
)
fg
=
FunctionGraph
([
a
],
a_splits
)
compare_jax_and_py
(
compare_jax_and_py
(
fg
,
[
a
],
a_splits
,
[
[
np
.
zeros
((
6
,
4
))
.
astype
(
config
.
floatX
),
np
.
zeros
((
6
,
4
))
.
astype
(
config
.
floatX
),
],
],
...
@@ -142,9 +132,9 @@ class TestJaxSplit:
...
@@ -142,9 +132,9 @@ class TestJaxSplit:
a
=
matrix
(
"a"
,
shape
=
(
6
,
None
))
a
=
matrix
(
"a"
,
shape
=
(
6
,
None
))
a_splits
=
ptb
.
split
(
a
,
splits_size
=
[
2
,
a
.
shape
[
0
]
-
2
],
n_splits
=
2
,
axis
=
0
)
a_splits
=
ptb
.
split
(
a
,
splits_size
=
[
2
,
a
.
shape
[
0
]
-
2
],
n_splits
=
2
,
axis
=
0
)
fg
=
FunctionGraph
([
a
],
a_splits
)
compare_jax_and_py
(
compare_jax_and_py
(
fg
,
[
a
],
a_splits
,
[
[
np
.
zeros
((
6
,
4
))
.
astype
(
config
.
floatX
),
np
.
zeros
((
6
,
4
))
.
astype
(
config
.
floatX
),
],
],
...
@@ -207,15 +197,14 @@ class TestJaxSplit:
...
@@ -207,15 +197,14 @@ class TestJaxSplit:
def
test_jax_eye
():
def
test_jax_eye
():
"""Tests jaxification of the Eye operator"""
"""Tests jaxification of the Eye operator"""
out
=
ptb
.
eye
(
3
)
out
=
ptb
.
eye
(
3
)
out_fg
=
FunctionGraph
([],
[
out
])
compare_jax_and_py
(
out_fg
,
[])
compare_jax_and_py
(
[],
[
out
]
,
[])
def
test_tri
():
def
test_tri
():
out
=
ptb
.
tri
(
10
,
10
,
0
)
out
=
ptb
.
tri
(
10
,
10
,
0
)
fgraph
=
FunctionGraph
([],
[
out
])
compare_jax_and_py
(
fgraph
,
[])
compare_jax_and_py
(
[],
[
out
]
,
[])
@pytest.mark.skipif
(
@pytest.mark.skipif
(
...
@@ -230,14 +219,13 @@ def test_tri_nonconcrete():
...
@@ -230,14 +219,13 @@ def test_tri_nonconcrete():
scalar
(
"n"
,
dtype
=
"int64"
),
scalar
(
"n"
,
dtype
=
"int64"
),
scalar
(
"k"
,
dtype
=
"int64"
),
scalar
(
"k"
,
dtype
=
"int64"
),
)
)
m
.
tag
.
test_value
=
10
m
_
test_value
=
10
n
.
tag
.
test_value
=
10
n
_
test_value
=
10
k
.
tag
.
test_value
=
0
k
_
test_value
=
0
out
=
ptb
.
tri
(
m
,
n
,
k
)
out
=
ptb
.
tri
(
m
,
n
,
k
)
# The actual error the user will see should be jax.errors.ConcretizationTypeError, but
# The actual error the user will see should be jax.errors.ConcretizationTypeError, but
# the error handler raises an Attribute error first, so that's what this test needs to pass
# the error handler raises an Attribute error first, so that's what this test needs to pass
with
pytest
.
raises
(
AttributeError
):
with
pytest
.
raises
(
AttributeError
):
fgraph
=
FunctionGraph
([
m
,
n
,
k
],
[
out
])
compare_jax_and_py
([
m
,
n
,
k
],
[
out
],
[
m_test_value
,
n_test_value
,
k_test_value
])
compare_jax_and_py
(
fgraph
,
[
get_test_value
(
i
)
for
i
in
fgraph
.
inputs
])
tests/link/numba/test_basic.py
浏览文件 @
cc8c4992
差异被折叠。
点击展开。
tests/link/numba/test_blockwise.py
浏览文件 @
cc8c4992
...
@@ -27,7 +27,8 @@ def test_blockwise(core_op, shape_opt):
...
@@ -27,7 +27,8 @@ def test_blockwise(core_op, shape_opt):
)
)
x_test
=
np
.
eye
(
3
)
*
np
.
arange
(
1
,
6
)[:,
None
,
None
]
x_test
=
np
.
eye
(
3
)
*
np
.
arange
(
1
,
6
)[:,
None
,
None
]
compare_numba_and_py
(
compare_numba_and_py
(
([
x
],
outs
),
[
x
],
outs
,
[
x_test
],
[
x_test
],
numba_mode
=
mode
,
numba_mode
=
mode
,
eval_obj_mode
=
False
,
eval_obj_mode
=
False
,
...
...
tests/link/numba/test_elemwise.py
浏览文件 @
cc8c4992
差异被折叠。
点击展开。
tests/link/numba/test_extra_ops.py
浏览文件 @
cc8c4992
差异被折叠。
点击展开。
tests/link/numba/test_nlinalg.py
浏览文件 @
cc8c4992
...
@@ -4,11 +4,8 @@ import numpy as np
...
@@ -4,11 +4,8 @@ import numpy as np
import
pytest
import
pytest
import
pytensor.tensor
as
pt
import
pytensor.tensor
as
pt
from
pytensor.compile.sharedvalue
import
SharedVariable
from
pytensor.graph.basic
import
Constant
from
pytensor.graph.fg
import
FunctionGraph
from
pytensor.tensor
import
nlinalg
from
pytensor.tensor
import
nlinalg
from
tests.link.numba.test_basic
import
compare_numba_and_py
,
set_test_value
from
tests.link.numba.test_basic
import
compare_numba_and_py
rng
=
np
.
random
.
default_rng
(
42849
)
rng
=
np
.
random
.
default_rng
(
42849
)
...
@@ -18,14 +15,14 @@ rng = np.random.default_rng(42849)
...
@@ -18,14 +15,14 @@ rng = np.random.default_rng(42849)
"x, exc"
,
"x, exc"
,
[
[
(
(
set_test_value
(
(
pt
.
dmatrix
(),
pt
.
dmatrix
(),
(
lambda
x
:
x
.
T
.
dot
(
x
))(
rng
.
random
(
size
=
(
3
,
3
))
.
astype
(
"float64"
)),
(
lambda
x
:
x
.
T
.
dot
(
x
))(
rng
.
random
(
size
=
(
3
,
3
))
.
astype
(
"float64"
)),
),
),
None
,
None
,
),
),
(
(
set_test_value
(
(
pt
.
lmatrix
(),
pt
.
lmatrix
(),
(
lambda
x
:
x
.
T
.
dot
(
x
))(
rng
.
poisson
(
size
=
(
3
,
3
))
.
astype
(
"int64"
)),
(
lambda
x
:
x
.
T
.
dot
(
x
))(
rng
.
poisson
(
size
=
(
3
,
3
))
.
astype
(
"int64"
)),
),
),
...
@@ -34,18 +31,15 @@ rng = np.random.default_rng(42849)
...
@@ -34,18 +31,15 @@ rng = np.random.default_rng(42849)
],
],
)
)
def
test_Det
(
x
,
exc
):
def
test_Det
(
x
,
exc
):
x
,
test_x
=
x
g
=
nlinalg
.
Det
()(
x
)
g
=
nlinalg
.
Det
()(
x
)
g_fg
=
FunctionGraph
(
outputs
=
[
g
])
cm
=
contextlib
.
suppress
()
if
exc
is
None
else
pytest
.
warns
(
exc
)
cm
=
contextlib
.
suppress
()
if
exc
is
None
else
pytest
.
warns
(
exc
)
with
cm
:
with
cm
:
compare_numba_and_py
(
compare_numba_and_py
(
g_fg
,
[
x
],
[
g
,
i
.
tag
.
test_value
[
test_x
],
for
i
in
g_fg
.
inputs
if
not
isinstance
(
i
,
SharedVariable
|
Constant
)
],
)
)
...
@@ -53,14 +47,14 @@ def test_Det(x, exc):
...
@@ -53,14 +47,14 @@ def test_Det(x, exc):
"x, exc"
,
"x, exc"
,
[
[
(
(
set_test_value
(
(
pt
.
dmatrix
(),
pt
.
dmatrix
(),
(
lambda
x
:
x
.
T
.
dot
(
x
))(
rng
.
random
(
size
=
(
3
,
3
))
.
astype
(
"float64"
)),
(
lambda
x
:
x
.
T
.
dot
(
x
))(
rng
.
random
(
size
=
(
3
,
3
))
.
astype
(
"float64"
)),
),
),
None
,
None
,
),
),
(
(
set_test_value
(
(
pt
.
lmatrix
(),
pt
.
lmatrix
(),
(
lambda
x
:
x
.
T
.
dot
(
x
))(
rng
.
poisson
(
size
=
(
3
,
3
))
.
astype
(
"int64"
)),
(
lambda
x
:
x
.
T
.
dot
(
x
))(
rng
.
poisson
(
size
=
(
3
,
3
))
.
astype
(
"int64"
)),
),
),
...
@@ -69,18 +63,15 @@ def test_Det(x, exc):
...
@@ -69,18 +63,15 @@ def test_Det(x, exc):
],
],
)
)
def
test_SLogDet
(
x
,
exc
):
def
test_SLogDet
(
x
,
exc
):
x
,
test_x
=
x
g
=
nlinalg
.
SLogDet
()(
x
)
g
=
nlinalg
.
SLogDet
()(
x
)
g_fg
=
FunctionGraph
(
outputs
=
g
)
cm
=
contextlib
.
suppress
()
if
exc
is
None
else
pytest
.
warns
(
exc
)
cm
=
contextlib
.
suppress
()
if
exc
is
None
else
pytest
.
warns
(
exc
)
with
cm
:
with
cm
:
compare_numba_and_py
(
compare_numba_and_py
(
g_fg
,
[
x
],
[
g
,
i
.
tag
.
test_value
[
test_x
],
for
i
in
g_fg
.
inputs
if
not
isinstance
(
i
,
SharedVariable
|
Constant
)
],
)
)
...
@@ -112,21 +103,21 @@ y = np.array(
...
@@ -112,21 +103,21 @@ y = np.array(
"x, exc"
,
"x, exc"
,
[
[
(
(
set_test_value
(
(
pt
.
dmatrix
(),
pt
.
dmatrix
(),
(
lambda
x
:
x
.
T
.
dot
(
x
))(
x
),
(
lambda
x
:
x
.
T
.
dot
(
x
))(
x
),
),
),
None
,
None
,
),
),
(
(
set_test_value
(
(
pt
.
dmatrix
(),
pt
.
dmatrix
(),
(
lambda
x
:
x
.
T
.
dot
(
x
))(
y
),
(
lambda
x
:
x
.
T
.
dot
(
x
))(
y
),
),
),
None
,
None
,
),
),
(
(
set_test_value
(
(
pt
.
lmatrix
(),
pt
.
lmatrix
(),
(
lambda
x
:
x
.
T
.
dot
(
x
))(
(
lambda
x
:
x
.
T
.
dot
(
x
))(
rng
.
integers
(
1
,
10
,
size
=
(
3
,
3
))
.
astype
(
"int64"
)
rng
.
integers
(
1
,
10
,
size
=
(
3
,
3
))
.
astype
(
"int64"
)
...
@@ -137,22 +128,15 @@ y = np.array(
...
@@ -137,22 +128,15 @@ y = np.array(
],
],
)
)
def
test_Eig
(
x
,
exc
):
def
test_Eig
(
x
,
exc
):
x
,
test_x
=
x
g
=
nlinalg
.
Eig
()(
x
)
g
=
nlinalg
.
Eig
()(
x
)
if
isinstance
(
g
,
list
):
g_fg
=
FunctionGraph
(
outputs
=
g
)
else
:
g_fg
=
FunctionGraph
(
outputs
=
[
g
])
cm
=
contextlib
.
suppress
()
if
exc
is
None
else
pytest
.
warns
(
exc
)
cm
=
contextlib
.
suppress
()
if
exc
is
None
else
pytest
.
warns
(
exc
)
with
cm
:
with
cm
:
compare_numba_and_py
(
compare_numba_and_py
(
g_fg
,
[
x
],
[
g
,
i
.
tag
.
test_value
[
test_x
],
for
i
in
g_fg
.
inputs
if
not
isinstance
(
i
,
SharedVariable
|
Constant
)
],
)
)
...
@@ -160,7 +144,7 @@ def test_Eig(x, exc):
...
@@ -160,7 +144,7 @@ def test_Eig(x, exc):
"x, uplo, exc"
,
"x, uplo, exc"
,
[
[
(
(
set_test_value
(
(
pt
.
dmatrix
(),
pt
.
dmatrix
(),
(
lambda
x
:
x
.
T
.
dot
(
x
))(
rng
.
random
(
size
=
(
3
,
3
))
.
astype
(
"float64"
)),
(
lambda
x
:
x
.
T
.
dot
(
x
))(
rng
.
random
(
size
=
(
3
,
3
))
.
astype
(
"float64"
)),
),
),
...
@@ -168,7 +152,7 @@ def test_Eig(x, exc):
...
@@ -168,7 +152,7 @@ def test_Eig(x, exc):
None
,
None
,
),
),
(
(
set_test_value
(
(
pt
.
lmatrix
(),
pt
.
lmatrix
(),
(
lambda
x
:
x
.
T
.
dot
(
x
))(
(
lambda
x
:
x
.
T
.
dot
(
x
))(
rng
.
integers
(
1
,
10
,
size
=
(
3
,
3
))
.
astype
(
"int64"
)
rng
.
integers
(
1
,
10
,
size
=
(
3
,
3
))
.
astype
(
"int64"
)
...
@@ -180,22 +164,15 @@ def test_Eig(x, exc):
...
@@ -180,22 +164,15 @@ def test_Eig(x, exc):
],
],
)
)
def
test_Eigh
(
x
,
uplo
,
exc
):
def
test_Eigh
(
x
,
uplo
,
exc
):
x
,
test_x
=
x
g
=
nlinalg
.
Eigh
(
uplo
)(
x
)
g
=
nlinalg
.
Eigh
(
uplo
)(
x
)
if
isinstance
(
g
,
list
):
g_fg
=
FunctionGraph
(
outputs
=
g
)
else
:
g_fg
=
FunctionGraph
(
outputs
=
[
g
])
cm
=
contextlib
.
suppress
()
if
exc
is
None
else
pytest
.
warns
(
exc
)
cm
=
contextlib
.
suppress
()
if
exc
is
None
else
pytest
.
warns
(
exc
)
with
cm
:
with
cm
:
compare_numba_and_py
(
compare_numba_and_py
(
g_fg
,
[
x
],
[
g
,
i
.
tag
.
test_value
[
test_x
],
for
i
in
g_fg
.
inputs
if
not
isinstance
(
i
,
SharedVariable
|
Constant
)
],
)
)
...
@@ -204,7 +181,7 @@ def test_Eigh(x, uplo, exc):
...
@@ -204,7 +181,7 @@ def test_Eigh(x, uplo, exc):
[
[
(
(
nlinalg
.
MatrixInverse
,
nlinalg
.
MatrixInverse
,
set_test_value
(
(
pt
.
dmatrix
(),
pt
.
dmatrix
(),
(
lambda
x
:
x
.
T
.
dot
(
x
))(
rng
.
random
(
size
=
(
3
,
3
))
.
astype
(
"float64"
)),
(
lambda
x
:
x
.
T
.
dot
(
x
))(
rng
.
random
(
size
=
(
3
,
3
))
.
astype
(
"float64"
)),
),
),
...
@@ -213,7 +190,7 @@ def test_Eigh(x, uplo, exc):
...
@@ -213,7 +190,7 @@ def test_Eigh(x, uplo, exc):
),
),
(
(
nlinalg
.
MatrixInverse
,
nlinalg
.
MatrixInverse
,
set_test_value
(
(
pt
.
lmatrix
(),
pt
.
lmatrix
(),
(
lambda
x
:
x
.
T
.
dot
(
x
))(
(
lambda
x
:
x
.
T
.
dot
(
x
))(
rng
.
integers
(
1
,
10
,
size
=
(
3
,
3
))
.
astype
(
"int64"
)
rng
.
integers
(
1
,
10
,
size
=
(
3
,
3
))
.
astype
(
"int64"
)
...
@@ -224,7 +201,7 @@ def test_Eigh(x, uplo, exc):
...
@@ -224,7 +201,7 @@ def test_Eigh(x, uplo, exc):
),
),
(
(
nlinalg
.
MatrixPinv
,
nlinalg
.
MatrixPinv
,
set_test_value
(
(
pt
.
dmatrix
(),
pt
.
dmatrix
(),
(
lambda
x
:
x
.
T
.
dot
(
x
))(
rng
.
random
(
size
=
(
3
,
3
))
.
astype
(
"float64"
)),
(
lambda
x
:
x
.
T
.
dot
(
x
))(
rng
.
random
(
size
=
(
3
,
3
))
.
astype
(
"float64"
)),
),
),
...
@@ -233,7 +210,7 @@ def test_Eigh(x, uplo, exc):
...
@@ -233,7 +210,7 @@ def test_Eigh(x, uplo, exc):
),
),
(
(
nlinalg
.
MatrixPinv
,
nlinalg
.
MatrixPinv
,
set_test_value
(
(
pt
.
lmatrix
(),
pt
.
lmatrix
(),
(
lambda
x
:
x
.
T
.
dot
(
x
))(
(
lambda
x
:
x
.
T
.
dot
(
x
))(
rng
.
integers
(
1
,
10
,
size
=
(
3
,
3
))
.
astype
(
"int64"
)
rng
.
integers
(
1
,
10
,
size
=
(
3
,
3
))
.
astype
(
"int64"
)
...
@@ -245,18 +222,15 @@ def test_Eigh(x, uplo, exc):
...
@@ -245,18 +222,15 @@ def test_Eigh(x, uplo, exc):
],
],
)
)
def
test_matrix_inverses
(
op
,
x
,
exc
,
op_args
):
def
test_matrix_inverses
(
op
,
x
,
exc
,
op_args
):
x
,
test_x
=
x
g
=
op
(
*
op_args
)(
x
)
g
=
op
(
*
op_args
)(
x
)
g_fg
=
FunctionGraph
(
outputs
=
[
g
])
cm
=
contextlib
.
suppress
()
if
exc
is
None
else
pytest
.
warns
(
exc
)
cm
=
contextlib
.
suppress
()
if
exc
is
None
else
pytest
.
warns
(
exc
)
with
cm
:
with
cm
:
compare_numba_and_py
(
compare_numba_and_py
(
g_fg
,
[
x
],
[
g
,
i
.
tag
.
test_value
[
test_x
],
for
i
in
g_fg
.
inputs
if
not
isinstance
(
i
,
SharedVariable
|
Constant
)
],
)
)
...
@@ -264,7 +238,7 @@ def test_matrix_inverses(op, x, exc, op_args):
...
@@ -264,7 +238,7 @@ def test_matrix_inverses(op, x, exc, op_args):
"x, mode, exc"
,
"x, mode, exc"
,
[
[
(
(
set_test_value
(
(
pt
.
dmatrix
(),
pt
.
dmatrix
(),
(
lambda
x
:
x
.
T
.
dot
(
x
))(
rng
.
random
(
size
=
(
3
,
3
))
.
astype
(
"float64"
)),
(
lambda
x
:
x
.
T
.
dot
(
x
))(
rng
.
random
(
size
=
(
3
,
3
))
.
astype
(
"float64"
)),
),
),
...
@@ -272,7 +246,7 @@ def test_matrix_inverses(op, x, exc, op_args):
...
@@ -272,7 +246,7 @@ def test_matrix_inverses(op, x, exc, op_args):
None
,
None
,
),
),
(
(
set_test_value
(
(
pt
.
dmatrix
(),
pt
.
dmatrix
(),
(
lambda
x
:
x
.
T
.
dot
(
x
))(
rng
.
random
(
size
=
(
3
,
3
))
.
astype
(
"float64"
)),
(
lambda
x
:
x
.
T
.
dot
(
x
))(
rng
.
random
(
size
=
(
3
,
3
))
.
astype
(
"float64"
)),
),
),
...
@@ -280,7 +254,7 @@ def test_matrix_inverses(op, x, exc, op_args):
...
@@ -280,7 +254,7 @@ def test_matrix_inverses(op, x, exc, op_args):
None
,
None
,
),
),
(
(
set_test_value
(
(
pt
.
lmatrix
(),
pt
.
lmatrix
(),
(
lambda
x
:
x
.
T
.
dot
(
x
))(
(
lambda
x
:
x
.
T
.
dot
(
x
))(
rng
.
integers
(
1
,
10
,
size
=
(
3
,
3
))
.
astype
(
"int64"
)
rng
.
integers
(
1
,
10
,
size
=
(
3
,
3
))
.
astype
(
"int64"
)
...
@@ -290,7 +264,7 @@ def test_matrix_inverses(op, x, exc, op_args):
...
@@ -290,7 +264,7 @@ def test_matrix_inverses(op, x, exc, op_args):
None
,
None
,
),
),
(
(
set_test_value
(
(
pt
.
lmatrix
(),
pt
.
lmatrix
(),
(
lambda
x
:
x
.
T
.
dot
(
x
))(
(
lambda
x
:
x
.
T
.
dot
(
x
))(
rng
.
integers
(
1
,
10
,
size
=
(
3
,
3
))
.
astype
(
"int64"
)
rng
.
integers
(
1
,
10
,
size
=
(
3
,
3
))
.
astype
(
"int64"
)
...
@@ -302,22 +276,15 @@ def test_matrix_inverses(op, x, exc, op_args):
...
@@ -302,22 +276,15 @@ def test_matrix_inverses(op, x, exc, op_args):
],
],
)
)
def
test_QRFull
(
x
,
mode
,
exc
):
def
test_QRFull
(
x
,
mode
,
exc
):
x
,
test_x
=
x
g
=
nlinalg
.
QRFull
(
mode
)(
x
)
g
=
nlinalg
.
QRFull
(
mode
)(
x
)
if
isinstance
(
g
,
list
):
g_fg
=
FunctionGraph
(
outputs
=
g
)
else
:
g_fg
=
FunctionGraph
(
outputs
=
[
g
])
cm
=
contextlib
.
suppress
()
if
exc
is
None
else
pytest
.
warns
(
exc
)
cm
=
contextlib
.
suppress
()
if
exc
is
None
else
pytest
.
warns
(
exc
)
with
cm
:
with
cm
:
compare_numba_and_py
(
compare_numba_and_py
(
g_fg
,
[
x
],
[
g
,
i
.
tag
.
test_value
[
test_x
],
for
i
in
g_fg
.
inputs
if
not
isinstance
(
i
,
SharedVariable
|
Constant
)
],
)
)
...
@@ -325,7 +292,7 @@ def test_QRFull(x, mode, exc):
...
@@ -325,7 +292,7 @@ def test_QRFull(x, mode, exc):
"x, full_matrices, compute_uv, exc"
,
"x, full_matrices, compute_uv, exc"
,
[
[
(
(
set_test_value
(
(
pt
.
dmatrix
(),
pt
.
dmatrix
(),
(
lambda
x
:
x
.
T
.
dot
(
x
))(
rng
.
random
(
size
=
(
3
,
3
))
.
astype
(
"float64"
)),
(
lambda
x
:
x
.
T
.
dot
(
x
))(
rng
.
random
(
size
=
(
3
,
3
))
.
astype
(
"float64"
)),
),
),
...
@@ -334,7 +301,7 @@ def test_QRFull(x, mode, exc):
...
@@ -334,7 +301,7 @@ def test_QRFull(x, mode, exc):
None
,
None
,
),
),
(
(
set_test_value
(
(
pt
.
dmatrix
(),
pt
.
dmatrix
(),
(
lambda
x
:
x
.
T
.
dot
(
x
))(
rng
.
random
(
size
=
(
3
,
3
))
.
astype
(
"float64"
)),
(
lambda
x
:
x
.
T
.
dot
(
x
))(
rng
.
random
(
size
=
(
3
,
3
))
.
astype
(
"float64"
)),
),
),
...
@@ -343,7 +310,7 @@ def test_QRFull(x, mode, exc):
...
@@ -343,7 +310,7 @@ def test_QRFull(x, mode, exc):
None
,
None
,
),
),
(
(
set_test_value
(
(
pt
.
lmatrix
(),
pt
.
lmatrix
(),
(
lambda
x
:
x
.
T
.
dot
(
x
))(
(
lambda
x
:
x
.
T
.
dot
(
x
))(
rng
.
integers
(
1
,
10
,
size
=
(
3
,
3
))
.
astype
(
"int64"
)
rng
.
integers
(
1
,
10
,
size
=
(
3
,
3
))
.
astype
(
"int64"
)
...
@@ -354,7 +321,7 @@ def test_QRFull(x, mode, exc):
...
@@ -354,7 +321,7 @@ def test_QRFull(x, mode, exc):
None
,
None
,
),
),
(
(
set_test_value
(
(
pt
.
lmatrix
(),
pt
.
lmatrix
(),
(
lambda
x
:
x
.
T
.
dot
(
x
))(
(
lambda
x
:
x
.
T
.
dot
(
x
))(
rng
.
integers
(
1
,
10
,
size
=
(
3
,
3
))
.
astype
(
"int64"
)
rng
.
integers
(
1
,
10
,
size
=
(
3
,
3
))
.
astype
(
"int64"
)
...
@@ -367,20 +334,13 @@ def test_QRFull(x, mode, exc):
...
@@ -367,20 +334,13 @@ def test_QRFull(x, mode, exc):
],
],
)
)
def
test_SVD
(
x
,
full_matrices
,
compute_uv
,
exc
):
def
test_SVD
(
x
,
full_matrices
,
compute_uv
,
exc
):
x
,
test_x
=
x
g
=
nlinalg
.
SVD
(
full_matrices
,
compute_uv
)(
x
)
g
=
nlinalg
.
SVD
(
full_matrices
,
compute_uv
)(
x
)
if
isinstance
(
g
,
list
):
g_fg
=
FunctionGraph
(
outputs
=
g
)
else
:
g_fg
=
FunctionGraph
(
outputs
=
[
g
])
cm
=
contextlib
.
suppress
()
if
exc
is
None
else
pytest
.
warns
(
exc
)
cm
=
contextlib
.
suppress
()
if
exc
is
None
else
pytest
.
warns
(
exc
)
with
cm
:
with
cm
:
compare_numba_and_py
(
compare_numba_and_py
(
g_fg
,
[
x
],
[
g
,
i
.
tag
.
test_value
[
test_x
],
for
i
in
g_fg
.
inputs
if
not
isinstance
(
i
,
SharedVariable
|
Constant
)
],
)
)
tests/link/numba/test_pad.py
浏览文件 @
cc8c4992
...
@@ -3,7 +3,6 @@ import pytest
...
@@ -3,7 +3,6 @@ import pytest
import
pytensor.tensor
as
pt
import
pytensor.tensor
as
pt
from
pytensor
import
config
from
pytensor
import
config
from
pytensor.graph
import
FunctionGraph
from
pytensor.tensor.pad
import
PadMode
from
pytensor.tensor.pad
import
PadMode
from
tests.link.numba.test_basic
import
compare_numba_and_py
from
tests.link.numba.test_basic
import
compare_numba_and_py
...
@@ -58,10 +57,10 @@ def test_numba_pad(mode: PadMode, kwargs):
...
@@ -58,10 +57,10 @@ def test_numba_pad(mode: PadMode, kwargs):
x
=
np
.
random
.
normal
(
size
=
(
3
,
3
))
x
=
np
.
random
.
normal
(
size
=
(
3
,
3
))
res
=
pt
.
pad
(
x_pt
,
mode
=
mode
,
pad_width
=
3
,
**
kwargs
)
res
=
pt
.
pad
(
x_pt
,
mode
=
mode
,
pad_width
=
3
,
**
kwargs
)
res_fg
=
FunctionGraph
([
x_pt
],
[
res
])
compare_numba_and_py
(
compare_numba_and_py
(
res_fg
,
[
x_pt
],
[
res
],
[
x
],
[
x
],
assert_fn
=
lambda
x
,
y
:
np
.
testing
.
assert_allclose
(
x
,
y
,
rtol
=
RTOL
,
atol
=
ATOL
),
assert_fn
=
lambda
x
,
y
:
np
.
testing
.
assert_allclose
(
x
,
y
,
rtol
=
RTOL
,
atol
=
ATOL
),
py_mode
=
"FAST_RUN"
,
py_mode
=
"FAST_RUN"
,
...
...
tests/link/numba/test_random.py
浏览文件 @
cc8c4992
差异被折叠。
点击展开。
tests/link/numba/test_scalar.py
浏览文件 @
cc8c4992
...
@@ -5,13 +5,10 @@ import pytensor.scalar as ps
...
@@ -5,13 +5,10 @@ import pytensor.scalar as ps
import
pytensor.scalar.basic
as
psb
import
pytensor.scalar.basic
as
psb
import
pytensor.tensor
as
pt
import
pytensor.tensor
as
pt
from
pytensor
import
config
from
pytensor
import
config
from
pytensor.compile.sharedvalue
import
SharedVariable
from
pytensor.graph.basic
import
Constant
from
pytensor.graph.fg
import
FunctionGraph
from
pytensor.scalar.basic
import
Composite
from
pytensor.scalar.basic
import
Composite
from
pytensor.tensor
import
tensor
from
pytensor.tensor
import
tensor
from
pytensor.tensor.elemwise
import
Elemwise
from
pytensor.tensor.elemwise
import
Elemwise
from
tests.link.numba.test_basic
import
compare_numba_and_py
,
set_test_value
from
tests.link.numba.test_basic
import
compare_numba_and_py
rng
=
np
.
random
.
default_rng
(
42849
)
rng
=
np
.
random
.
default_rng
(
42849
)
...
@@ -21,48 +18,43 @@ rng = np.random.default_rng(42849)
...
@@ -21,48 +18,43 @@ rng = np.random.default_rng(42849)
"x, y"
,
"x, y"
,
[
[
(
(
set_test_value
(
pt
.
lvector
(),
np
.
arange
(
4
,
dtype
=
"int64"
)),
(
pt
.
lvector
(),
np
.
arange
(
4
,
dtype
=
"int64"
)),
set_test_value
(
pt
.
dvector
(),
np
.
arange
(
4
,
dtype
=
"float64"
)),
(
pt
.
dvector
(),
np
.
arange
(
4
,
dtype
=
"float64"
)),
),
),
(
(
set_test_value
(
pt
.
dmatrix
(),
np
.
arange
(
4
,
dtype
=
"float64"
)
.
reshape
((
2
,
2
))),
(
pt
.
dmatrix
(),
np
.
arange
(
4
,
dtype
=
"float64"
)
.
reshape
((
2
,
2
))),
set_test_value
(
pt
.
lscalar
(),
np
.
array
(
4
,
dtype
=
"int64"
)),
(
pt
.
lscalar
(),
np
.
array
(
4
,
dtype
=
"int64"
)),
),
),
],
],
)
)
def
test_Second
(
x
,
y
):
def
test_Second
(
x
,
y
):
x
,
x_test
=
x
y
,
y_test
=
y
# We use the `Elemwise`-wrapped version of `Second`
# We use the `Elemwise`-wrapped version of `Second`
g
=
pt
.
second
(
x
,
y
)
g
=
pt
.
second
(
x
,
y
)
g_fg
=
FunctionGraph
(
outputs
=
[
g
])
compare_numba_and_py
(
compare_numba_and_py
(
g_fg
,
[
x
,
y
],
[
g
,
i
.
tag
.
test_value
[
x_test
,
y_test
],
for
i
in
g_fg
.
inputs
if
not
isinstance
(
i
,
SharedVariable
|
Constant
)
],
)
)
@pytest.mark.parametrize
(
@pytest.mark.parametrize
(
"v, min, max"
,
"v, min, max"
,
[
[
(
set_test_value
(
pt
.
scalar
(),
np
.
array
(
10
,
dtype
=
config
.
floatX
)),
3.0
,
7.0
),
((
pt
.
scalar
(),
np
.
array
(
10
,
dtype
=
config
.
floatX
)),
3.0
,
7.0
),
(
set_test_value
(
pt
.
scalar
(),
np
.
array
(
1
,
dtype
=
config
.
floatX
)),
3.0
,
7.0
),
((
pt
.
scalar
(),
np
.
array
(
1
,
dtype
=
config
.
floatX
)),
3.0
,
7.0
),
(
set_test_value
(
pt
.
scalar
(),
np
.
array
(
10
,
dtype
=
config
.
floatX
)),
7.0
,
3.0
),
((
pt
.
scalar
(),
np
.
array
(
10
,
dtype
=
config
.
floatX
)),
7.0
,
3.0
),
],
],
)
)
def
test_Clip
(
v
,
min
,
max
):
def
test_Clip
(
v
,
min
,
max
):
v
,
v_test
=
v
g
=
ps
.
clip
(
v
,
min
,
max
)
g
=
ps
.
clip
(
v
,
min
,
max
)
g_fg
=
FunctionGraph
(
outputs
=
[
g
])
compare_numba_and_py
(
compare_numba_and_py
(
g_fg
,
[
v
],
[
[
g
],
i
.
tag
.
test_value
[
v_test
],
for
i
in
g_fg
.
inputs
if
not
isinstance
(
i
,
SharedVariable
|
Constant
)
],
)
)
...
@@ -100,46 +92,39 @@ def test_Clip(v, min, max):
...
@@ -100,46 +92,39 @@ def test_Clip(v, min, max):
def
test_Composite
(
inputs
,
input_values
,
scalar_fn
):
def
test_Composite
(
inputs
,
input_values
,
scalar_fn
):
composite_inputs
=
[
ps
.
ScalarType
(
config
.
floatX
)(
name
=
i
.
name
)
for
i
in
inputs
]
composite_inputs
=
[
ps
.
ScalarType
(
config
.
floatX
)(
name
=
i
.
name
)
for
i
in
inputs
]
comp_op
=
Elemwise
(
Composite
(
composite_inputs
,
[
scalar_fn
(
*
composite_inputs
)]))
comp_op
=
Elemwise
(
Composite
(
composite_inputs
,
[
scalar_fn
(
*
composite_inputs
)]))
out_fg
=
FunctionGraph
(
inputs
,
[
comp_op
(
*
inputs
)])
compare_numba_and_py
(
inputs
,
[
comp_op
(
*
inputs
)],
input_values
)
compare_numba_and_py
(
out_fg
,
input_values
)
@pytest.mark.parametrize
(
@pytest.mark.parametrize
(
"v, dtype"
,
"v, dtype"
,
[
[
(
set_test_value
(
pt
.
fscalar
(),
np
.
array
(
1.0
,
dtype
=
"float32"
)),
psb
.
float64
),
((
pt
.
fscalar
(),
np
.
array
(
1.0
,
dtype
=
"float32"
)),
psb
.
float64
),
(
set_test_value
(
pt
.
dscalar
(),
np
.
array
(
1.0
,
dtype
=
"float64"
)),
psb
.
float32
),
((
pt
.
dscalar
(),
np
.
array
(
1.0
,
dtype
=
"float64"
)),
psb
.
float32
),
],
],
)
)
def
test_Cast
(
v
,
dtype
):
def
test_Cast
(
v
,
dtype
):
v
,
v_test
=
v
g
=
psb
.
Cast
(
dtype
)(
v
)
g
=
psb
.
Cast
(
dtype
)(
v
)
g_fg
=
FunctionGraph
(
outputs
=
[
g
])
compare_numba_and_py
(
compare_numba_and_py
(
g_fg
,
[
v
],
[
[
g
],
i
.
tag
.
test_value
[
v_test
],
for
i
in
g_fg
.
inputs
if
not
isinstance
(
i
,
SharedVariable
|
Constant
)
],
)
)
@pytest.mark.parametrize
(
@pytest.mark.parametrize
(
"v, dtype"
,
"v, dtype"
,
[
[
(
set_test_value
(
pt
.
iscalar
(),
np
.
array
(
10
,
dtype
=
"int32"
)),
psb
.
float64
),
((
pt
.
iscalar
(),
np
.
array
(
10
,
dtype
=
"int32"
)),
psb
.
float64
),
],
],
)
)
def
test_reciprocal
(
v
,
dtype
):
def
test_reciprocal
(
v
,
dtype
):
v
,
v_test
=
v
g
=
psb
.
reciprocal
(
v
)
g
=
psb
.
reciprocal
(
v
)
g_fg
=
FunctionGraph
(
outputs
=
[
g
])
compare_numba_and_py
(
compare_numba_and_py
(
g_fg
,
[
v
],
[
[
g
],
i
.
tag
.
test_value
[
v_test
],
for
i
in
g_fg
.
inputs
if
not
isinstance
(
i
,
SharedVariable
|
Constant
)
],
)
)
...
@@ -156,6 +141,7 @@ def test_isnan(composite):
...
@@ -156,6 +141,7 @@ def test_isnan(composite):
out
=
pt
.
isnan
(
x
)
out
=
pt
.
isnan
(
x
)
compare_numba_and_py
(
compare_numba_and_py
(
([
x
],
[
out
]),
[
x
],
[
out
],
[
np
.
array
([
1
,
0
],
dtype
=
"float64"
)],
[
np
.
array
([
1
,
0
],
dtype
=
"float64"
)],
)
)
tests/link/numba/test_scan.py
浏览文件 @
cc8c4992
差异被折叠。
点击展开。
tests/link/numba/test_slinalg.py
浏览文件 @
cc8c4992
差异被折叠。
点击展开。
tests/link/numba/test_sparse.py
浏览文件 @
cc8c4992
差异被折叠。
点击展开。
tests/link/numba/test_subtensor.py
浏览文件 @
cc8c4992
差异被折叠。
点击展开。
tests/link/numba/test_tensor_basic.py
浏览文件 @
cc8c4992
差异被折叠。
点击展开。
tests/link/pytorch/test_basic.py
浏览文件 @
cc8c4992
差异被折叠。
点击展开。
tests/link/pytorch/test_blas.py
浏览文件 @
cc8c4992
差异被折叠。
点击展开。
tests/link/pytorch/test_elemwise.py
浏览文件 @
cc8c4992
差异被折叠。
点击展开。
tests/link/pytorch/test_extra_ops.py
浏览文件 @
cc8c4992
差异被折叠。
点击展开。
tests/link/pytorch/test_math.py
浏览文件 @
cc8c4992
差异被折叠。
点击展开。
tests/link/pytorch/test_nlinalg.py
浏览文件 @
cc8c4992
差异被折叠。
点击展开。
tests/link/pytorch/test_shape.py
浏览文件 @
cc8c4992
差异被折叠。
点击展开。
tests/link/pytorch/test_sort.py
浏览文件 @
cc8c4992
差异被折叠。
点击展开。
tests/link/pytorch/test_subtensor.py
浏览文件 @
cc8c4992
差异被折叠。
点击展开。
tests/tensor/test_extra_ops.py
浏览文件 @
cc8c4992
差异被折叠。
点击展开。
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论