Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
cc8c4992
提交
cc8c4992
authored
11月 28, 2024
作者:
Adv
提交者:
Ricardo Vieira
2月 18, 2025
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Stop using FunctionGraph and tag.test_value in linker tests
Co-authored-by:
Adv
<
adhvaithhundi.221ds003@nitk.edu.in
>
上级
51ea1a0b
隐藏空白字符变更
内嵌
并排
正在显示
41 个修改的文件
包含
1098 行增加
和
1597 行删除
+1098
-1597
test_basic.py
tests/link/jax/test_basic.py
+24
-27
test_blas.py
tests/link/jax/test_blas.py
+5
-8
test_blockwise.py
tests/link/jax/test_blockwise.py
+1
-3
test_einsum.py
tests/link/jax/test_einsum.py
+2
-5
test_elemwise.py
tests/link/jax/test_elemwise.py
+23
-33
test_extra_ops.py
tests/link/jax/test_extra_ops.py
+12
-28
test_math.py
tests/link/jax/test_math.py
+19
-15
test_nlinalg.py
tests/link/jax/test_nlinalg.py
+8
-17
test_pad.py
tests/link/jax/test_pad.py
+2
-3
test_random.py
tests/link/jax/test_random.py
+48
-56
test_scalar.py
tests/link/jax/test_scalar.py
+56
-70
test_scan.py
tests/link/jax/test_scan.py
+18
-32
test_shape.py
tests/link/jax/test_shape.py
+15
-24
test_slinalg.py
tests/link/jax/test_slinalg.py
+26
-31
test_sort.py
tests/link/jax/test_sort.py
+1
-3
test_sparse.py
tests/link/jax/test_sparse.py
+1
-3
test_subtensor.py
tests/link/jax/test_subtensor.py
+59
-59
test_tensor_basic.py
tests/link/jax/test_tensor_basic.py
+31
-43
test_basic.py
tests/link/numba/test_basic.py
+130
-181
test_blockwise.py
tests/link/numba/test_blockwise.py
+2
-1
test_elemwise.py
tests/link/numba/test_elemwise.py
+88
-152
test_extra_ops.py
tests/link/numba/test_extra_ops.py
+107
-170
test_nlinalg.py
tests/link/numba/test_nlinalg.py
+50
-90
test_pad.py
tests/link/numba/test_pad.py
+2
-3
test_random.py
tests/link/numba/test_random.py
+66
-85
test_scalar.py
tests/link/numba/test_scalar.py
+31
-45
test_scan.py
tests/link/numba/test_scan.py
+18
-28
test_slinalg.py
tests/link/numba/test_slinalg.py
+11
-9
test_sparse.py
tests/link/numba/test_sparse.py
+1
-1
test_subtensor.py
tests/link/numba/test_subtensor.py
+14
-25
test_tensor_basic.py
tests/link/numba/test_tensor_basic.py
+94
-173
test_basic.py
tests/link/pytorch/test_basic.py
+45
-42
test_blas.py
tests/link/pytorch/test_blas.py
+2
-3
test_elemwise.py
tests/link/pytorch/test_elemwise.py
+23
-27
test_extra_ops.py
tests/link/pytorch/test_extra_ops.py
+12
-13
test_math.py
tests/link/pytorch/test_math.py
+6
-5
test_nlinalg.py
tests/link/pytorch/test_nlinalg.py
+7
-14
test_shape.py
tests/link/pytorch/test_shape.py
+13
-15
test_sort.py
tests/link/pytorch/test_sort.py
+1
-3
test_subtensor.py
tests/link/pytorch/test_subtensor.py
+24
-47
test_extra_ops.py
tests/tensor/test_extra_ops.py
+0
-5
没有找到文件。
tests/link/jax/test_basic.py
浏览文件 @
cc8c4992
...
@@ -7,12 +7,12 @@ import pytest
...
@@ -7,12 +7,12 @@ import pytest
from
pytensor.compile.builders
import
OpFromGraph
from
pytensor.compile.builders
import
OpFromGraph
from
pytensor.compile.function
import
function
from
pytensor.compile.function
import
function
from
pytensor.compile.mode
import
JAX
,
Mode
from
pytensor.compile.mode
import
JAX
,
Mode
from
pytensor.compile.sharedvalue
import
SharedVariable
,
shared
from
pytensor.compile.sharedvalue
import
shared
from
pytensor.configdefaults
import
config
from
pytensor.configdefaults
import
config
from
pytensor.graph
import
RewriteDatabaseQuery
from
pytensor.graph
import
RewriteDatabaseQuery
from
pytensor.graph.basic
import
Apply
from
pytensor.graph.basic
import
Apply
,
Variable
from
pytensor.graph.fg
import
FunctionGraph
from
pytensor.graph.fg
import
FunctionGraph
from
pytensor.graph.op
import
Op
,
get_test_value
from
pytensor.graph.op
import
Op
from
pytensor.ifelse
import
ifelse
from
pytensor.ifelse
import
ifelse
from
pytensor.link.jax
import
JAXLinker
from
pytensor.link.jax
import
JAXLinker
from
pytensor.raise_op
import
assert_op
from
pytensor.raise_op
import
assert_op
...
@@ -34,25 +34,28 @@ py_mode = Mode(linker="py", optimizer=None)
...
@@ -34,25 +34,28 @@ py_mode = Mode(linker="py", optimizer=None)
def
compare_jax_and_py
(
def
compare_jax_and_py
(
fgraph
:
FunctionGraph
,
graph_inputs
:
Iterable
[
Variable
],
graph_outputs
:
Variable
|
Iterable
[
Variable
],
test_inputs
:
Iterable
,
test_inputs
:
Iterable
,
*
,
assert_fn
:
Callable
|
None
=
None
,
assert_fn
:
Callable
|
None
=
None
,
must_be_device_array
:
bool
=
True
,
must_be_device_array
:
bool
=
True
,
jax_mode
=
jax_mode
,
jax_mode
=
jax_mode
,
py_mode
=
py_mode
,
py_mode
=
py_mode
,
):
):
"""Function to compare python
graph
output and jax compiled output for testing equality
"""Function to compare python
function
output and jax compiled output for testing equality
In the tests below computational graphs are defined in PyTensor. These graphs are then passed to
The inputs and outputs are then passed to this function which then compiles the given function in both
this function which then compiles the graphs in both jax and python, runs the calculation
jax and python, runs the calculation in both and checks if the results are the same
in both and checks if the results are the same
Parameters
Parameters
----------
----------
fgraph: FunctionGraph
graph_inputs:
PyTensor function Graph object
Symbolic inputs to the graph
outputs:
Symbolic outputs of the graph
test_inputs: iter
test_inputs: iter
Numerical inputs for testing the function
graph
Numerical inputs for testing the function
.
assert_fn: func, opt
assert_fn: func, opt
Assert function used to check for equality between python and jax. If not
Assert function used to check for equality between python and jax. If not
provided uses np.testing.assert_allclose
provided uses np.testing.assert_allclose
...
@@ -68,8 +71,10 @@ def compare_jax_and_py(
...
@@ -68,8 +71,10 @@ def compare_jax_and_py(
if
assert_fn
is
None
:
if
assert_fn
is
None
:
assert_fn
=
partial
(
np
.
testing
.
assert_allclose
,
rtol
=
1e-4
)
assert_fn
=
partial
(
np
.
testing
.
assert_allclose
,
rtol
=
1e-4
)
fn_inputs
=
[
i
for
i
in
fgraph
.
inputs
if
not
isinstance
(
i
,
SharedVariable
)]
if
any
(
inp
.
owner
is
not
None
for
inp
in
graph_inputs
):
pytensor_jax_fn
=
function
(
fn_inputs
,
fgraph
.
outputs
,
mode
=
jax_mode
)
raise
ValueError
(
"Inputs must be root variables"
)
pytensor_jax_fn
=
function
(
graph_inputs
,
graph_outputs
,
mode
=
jax_mode
)
jax_res
=
pytensor_jax_fn
(
*
test_inputs
)
jax_res
=
pytensor_jax_fn
(
*
test_inputs
)
if
must_be_device_array
:
if
must_be_device_array
:
...
@@ -78,10 +83,10 @@ def compare_jax_and_py(
...
@@ -78,10 +83,10 @@ def compare_jax_and_py(
else
:
else
:
assert
isinstance
(
jax_res
,
jax
.
Array
)
assert
isinstance
(
jax_res
,
jax
.
Array
)
pytensor_py_fn
=
function
(
fn_inputs
,
fgraph
.
outputs
,
mode
=
py_mode
)
pytensor_py_fn
=
function
(
graph_inputs
,
graph_
outputs
,
mode
=
py_mode
)
py_res
=
pytensor_py_fn
(
*
test_inputs
)
py_res
=
pytensor_py_fn
(
*
test_inputs
)
if
len
(
fgraph
.
outputs
)
>
1
:
if
isinstance
(
graph_outputs
,
list
|
tuple
)
:
for
j
,
p
in
zip
(
jax_res
,
py_res
,
strict
=
True
):
for
j
,
p
in
zip
(
jax_res
,
py_res
,
strict
=
True
):
assert_fn
(
j
,
p
)
assert_fn
(
j
,
p
)
else
:
else
:
...
@@ -187,16 +192,14 @@ def test_jax_ifelse():
...
@@ -187,16 +192,14 @@ def test_jax_ifelse():
false_vals
=
np
.
r_
[
-
1
,
-
2
,
-
3
]
false_vals
=
np
.
r_
[
-
1
,
-
2
,
-
3
]
x
=
ifelse
(
np
.
array
(
True
),
true_vals
,
false_vals
)
x
=
ifelse
(
np
.
array
(
True
),
true_vals
,
false_vals
)
x_fg
=
FunctionGraph
([],
[
x
])
compare_jax_and_py
(
x_fg
,
[])
compare_jax_and_py
(
[],
[
x
]
,
[])
a
=
dscalar
(
"a"
)
a
=
dscalar
(
"a"
)
a
.
tag
.
test_value
=
np
.
array
(
0.2
,
dtype
=
config
.
floatX
)
a
_test
=
np
.
array
(
0.2
,
dtype
=
config
.
floatX
)
x
=
ifelse
(
a
<
0.5
,
true_vals
,
false_vals
)
x
=
ifelse
(
a
<
0.5
,
true_vals
,
false_vals
)
x_fg
=
FunctionGraph
([
a
],
[
x
])
# I.e. False
compare_jax_and_py
(
x_fg
,
[
get_test_value
(
i
)
for
i
in
x_fg
.
inputs
])
compare_jax_and_py
(
[
a
],
[
x
],
[
a_test
])
def
test_jax_checkandraise
():
def
test_jax_checkandraise
():
...
@@ -209,11 +212,6 @@ def test_jax_checkandraise():
...
@@ -209,11 +212,6 @@ def test_jax_checkandraise():
function
((
p
,),
res
,
mode
=
jax_mode
)
function
((
p
,),
res
,
mode
=
jax_mode
)
def
set_test_value
(
x
,
v
):
x
.
tag
.
test_value
=
v
return
x
def
test_OpFromGraph
():
def
test_OpFromGraph
():
x
,
y
,
z
=
matrices
(
"xyz"
)
x
,
y
,
z
=
matrices
(
"xyz"
)
ofg_1
=
OpFromGraph
([
x
,
y
],
[
x
+
y
],
inline
=
False
)
ofg_1
=
OpFromGraph
([
x
,
y
],
[
x
+
y
],
inline
=
False
)
...
@@ -221,10 +219,9 @@ def test_OpFromGraph():
...
@@ -221,10 +219,9 @@ def test_OpFromGraph():
o1
,
o2
=
ofg_2
(
y
,
z
)
o1
,
o2
=
ofg_2
(
y
,
z
)
out
=
ofg_1
(
x
,
o1
)
+
o2
out
=
ofg_1
(
x
,
o1
)
+
o2
out_fg
=
FunctionGraph
([
x
,
y
,
z
],
[
out
])
xv
=
np
.
ones
((
2
,
2
),
dtype
=
config
.
floatX
)
xv
=
np
.
ones
((
2
,
2
),
dtype
=
config
.
floatX
)
yv
=
np
.
ones
((
2
,
2
),
dtype
=
config
.
floatX
)
*
3
yv
=
np
.
ones
((
2
,
2
),
dtype
=
config
.
floatX
)
*
3
zv
=
np
.
ones
((
2
,
2
),
dtype
=
config
.
floatX
)
*
5
zv
=
np
.
ones
((
2
,
2
),
dtype
=
config
.
floatX
)
*
5
compare_jax_and_py
(
out_fg
,
[
xv
,
yv
,
zv
])
compare_jax_and_py
(
[
x
,
y
,
z
],
[
out
]
,
[
xv
,
yv
,
zv
])
tests/link/jax/test_blas.py
浏览文件 @
cc8c4992
...
@@ -4,8 +4,6 @@ import pytest
...
@@ -4,8 +4,6 @@ import pytest
from
pytensor.compile.function
import
function
from
pytensor.compile.function
import
function
from
pytensor.compile.mode
import
Mode
from
pytensor.compile.mode
import
Mode
from
pytensor.configdefaults
import
config
from
pytensor.configdefaults
import
config
from
pytensor.graph.fg
import
FunctionGraph
from
pytensor.graph.op
import
get_test_value
from
pytensor.graph.rewriting.db
import
RewriteDatabaseQuery
from
pytensor.graph.rewriting.db
import
RewriteDatabaseQuery
from
pytensor.link.jax
import
JAXLinker
from
pytensor.link.jax
import
JAXLinker
from
pytensor.tensor
import
blas
as
pt_blas
from
pytensor.tensor
import
blas
as
pt_blas
...
@@ -16,21 +14,20 @@ from tests.link.jax.test_basic import compare_jax_and_py
...
@@ -16,21 +14,20 @@ from tests.link.jax.test_basic import compare_jax_and_py
def
test_jax_BatchedDot
():
def
test_jax_BatchedDot
():
# tensor3 . tensor3
# tensor3 . tensor3
a
=
tensor3
(
"a"
)
a
=
tensor3
(
"a"
)
a
.
tag
.
test_value
=
(
a
_
test_value
=
(
np
.
linspace
(
-
1
,
1
,
10
*
5
*
3
)
.
astype
(
config
.
floatX
)
.
reshape
((
10
,
5
,
3
))
np
.
linspace
(
-
1
,
1
,
10
*
5
*
3
)
.
astype
(
config
.
floatX
)
.
reshape
((
10
,
5
,
3
))
)
)
b
=
tensor3
(
"b"
)
b
=
tensor3
(
"b"
)
b
.
tag
.
test_value
=
(
b
_
test_value
=
(
np
.
linspace
(
1
,
-
1
,
10
*
3
*
2
)
.
astype
(
config
.
floatX
)
.
reshape
((
10
,
3
,
2
))
np
.
linspace
(
1
,
-
1
,
10
*
3
*
2
)
.
astype
(
config
.
floatX
)
.
reshape
((
10
,
3
,
2
))
)
)
out
=
pt_blas
.
BatchedDot
()(
a
,
b
)
out
=
pt_blas
.
BatchedDot
()(
a
,
b
)
fgraph
=
FunctionGraph
([
a
,
b
],
[
out
])
compare_jax_and_py
([
a
,
b
],
[
out
],
[
a_test_value
,
b_test_value
])
compare_jax_and_py
(
fgraph
,
[
get_test_value
(
i
)
for
i
in
fgraph
.
inputs
])
# A dimension mismatch should raise a TypeError for compatibility
# A dimension mismatch should raise a TypeError for compatibility
inputs
=
[
get_test_value
(
a
)[:
-
1
],
get_test_value
(
b
)
]
inputs
=
[
a_test_value
[:
-
1
],
b_test_value
]
opts
=
RewriteDatabaseQuery
(
include
=
[
None
],
exclude
=
[
"cxx_only"
,
"BlasOpt"
])
opts
=
RewriteDatabaseQuery
(
include
=
[
None
],
exclude
=
[
"cxx_only"
,
"BlasOpt"
])
jax_mode
=
Mode
(
JAXLinker
(),
opts
)
jax_mode
=
Mode
(
JAXLinker
(),
opts
)
pytensor_jax_fn
=
function
(
fgraph
.
inputs
,
fgraph
.
outputs
,
mode
=
jax_mode
)
pytensor_jax_fn
=
function
(
[
a
,
b
],
[
out
]
,
mode
=
jax_mode
)
with
pytest
.
raises
(
TypeError
):
with
pytest
.
raises
(
TypeError
):
pytensor_jax_fn
(
*
inputs
)
pytensor_jax_fn
(
*
inputs
)
tests/link/jax/test_blockwise.py
浏览文件 @
cc8c4992
...
@@ -2,7 +2,6 @@ import numpy as np
...
@@ -2,7 +2,6 @@ import numpy as np
import
pytest
import
pytest
from
pytensor
import
config
from
pytensor
import
config
from
pytensor.graph
import
FunctionGraph
from
pytensor.tensor
import
tensor
from
pytensor.tensor
import
tensor
from
pytensor.tensor.blockwise
import
Blockwise
from
pytensor.tensor.blockwise
import
Blockwise
from
pytensor.tensor.math
import
Dot
,
matmul
from
pytensor.tensor.math
import
Dot
,
matmul
...
@@ -32,8 +31,7 @@ def test_matmul(matmul_op):
...
@@ -32,8 +31,7 @@ def test_matmul(matmul_op):
out
=
matmul_op
(
a
,
b
)
out
=
matmul_op
(
a
,
b
)
assert
isinstance
(
out
.
owner
.
op
,
Blockwise
)
assert
isinstance
(
out
.
owner
.
op
,
Blockwise
)
fg
=
FunctionGraph
([
a
,
b
],
[
out
])
fn
,
_
=
compare_jax_and_py
([
a
,
b
],
[
out
],
test_values
)
fn
,
_
=
compare_jax_and_py
(
fg
,
test_values
)
# Check we are not adding any unnecessary stuff
# Check we are not adding any unnecessary stuff
jaxpr
=
str
(
jax
.
make_jaxpr
(
fn
.
vm
.
jit_fn
)(
*
test_values
))
jaxpr
=
str
(
jax
.
make_jaxpr
(
fn
.
vm
.
jit_fn
)(
*
test_values
))
...
...
tests/link/jax/test_einsum.py
浏览文件 @
cc8c4992
...
@@ -2,7 +2,6 @@ import numpy as np
...
@@ -2,7 +2,6 @@ import numpy as np
import
pytest
import
pytest
import
pytensor.tensor
as
pt
import
pytensor.tensor
as
pt
from
pytensor.graph
import
FunctionGraph
from
tests.link.jax.test_basic
import
compare_jax_and_py
from
tests.link.jax.test_basic
import
compare_jax_and_py
...
@@ -22,8 +21,7 @@ def test_jax_einsum():
...
@@ -22,8 +21,7 @@ def test_jax_einsum():
}
}
x_pt
,
y_pt
,
z_pt
=
(
pt
.
tensor
(
name
,
shape
=
shape
)
for
name
,
shape
in
shapes
.
items
())
x_pt
,
y_pt
,
z_pt
=
(
pt
.
tensor
(
name
,
shape
=
shape
)
for
name
,
shape
in
shapes
.
items
())
out
=
pt
.
einsum
(
subscripts
,
x_pt
,
y_pt
,
z_pt
)
out
=
pt
.
einsum
(
subscripts
,
x_pt
,
y_pt
,
z_pt
)
fg
=
FunctionGraph
([
x_pt
,
y_pt
,
z_pt
],
[
out
])
compare_jax_and_py
([
x_pt
,
y_pt
,
z_pt
],
[
out
],
[
x
,
y
,
z
])
compare_jax_and_py
(
fg
,
[
x
,
y
,
z
])
def
test_ellipsis_einsum
():
def
test_ellipsis_einsum
():
...
@@ -34,5 +32,4 @@ def test_ellipsis_einsum():
...
@@ -34,5 +32,4 @@ def test_ellipsis_einsum():
x_pt
=
pt
.
tensor
(
"x"
,
shape
=
x
.
shape
)
x_pt
=
pt
.
tensor
(
"x"
,
shape
=
x
.
shape
)
y_pt
=
pt
.
tensor
(
"y"
,
shape
=
y
.
shape
)
y_pt
=
pt
.
tensor
(
"y"
,
shape
=
y
.
shape
)
out
=
pt
.
einsum
(
subscripts
,
x_pt
,
y_pt
)
out
=
pt
.
einsum
(
subscripts
,
x_pt
,
y_pt
)
fg
=
FunctionGraph
([
x_pt
,
y_pt
],
[
out
])
compare_jax_and_py
([
x_pt
,
y_pt
],
[
out
],
[
x
,
y
])
compare_jax_and_py
(
fg
,
[
x
,
y
])
tests/link/jax/test_elemwise.py
浏览文件 @
cc8c4992
...
@@ -6,8 +6,6 @@ import pytensor
...
@@ -6,8 +6,6 @@ import pytensor
import
pytensor.tensor
as
pt
import
pytensor.tensor
as
pt
from
pytensor.compile
import
get_mode
from
pytensor.compile
import
get_mode
from
pytensor.configdefaults
import
config
from
pytensor.configdefaults
import
config
from
pytensor.graph.fg
import
FunctionGraph
from
pytensor.graph.op
import
get_test_value
from
pytensor.tensor
import
elemwise
as
pt_elemwise
from
pytensor.tensor
import
elemwise
as
pt_elemwise
from
pytensor.tensor.math
import
all
as
pt_all
from
pytensor.tensor.math
import
all
as
pt_all
from
pytensor.tensor.math
import
prod
from
pytensor.tensor.math
import
prod
...
@@ -26,22 +24,22 @@ def test_jax_Dimshuffle():
...
@@ -26,22 +24,22 @@ def test_jax_Dimshuffle():
a_pt
=
matrix
(
"a"
)
a_pt
=
matrix
(
"a"
)
x
=
a_pt
.
T
x
=
a_pt
.
T
x_fg
=
FunctionGraph
([
a_pt
],
[
x
])
compare_jax_and_py
(
compare_jax_and_py
(
x_fg
,
[
np
.
c_
[[
1.0
,
2.0
],
[
3.0
,
4.0
]]
.
astype
(
config
.
floatX
)])
[
a_pt
],
[
x
],
[
np
.
c_
[[
1.0
,
2.0
],
[
3.0
,
4.0
]]
.
astype
(
config
.
floatX
)]
)
x
=
a_pt
.
dimshuffle
([
0
,
1
,
"x"
])
x
=
a_pt
.
dimshuffle
([
0
,
1
,
"x"
])
x_fg
=
FunctionGraph
([
a_pt
],
[
x
])
compare_jax_and_py
(
compare_jax_and_py
(
x_fg
,
[
np
.
c_
[[
1.0
,
2.0
],
[
3.0
,
4.0
]]
.
astype
(
config
.
floatX
)])
[
a_pt
],
[
x
],
[
np
.
c_
[[
1.0
,
2.0
],
[
3.0
,
4.0
]]
.
astype
(
config
.
floatX
)]
)
a_pt
=
tensor
(
dtype
=
config
.
floatX
,
shape
=
(
None
,
1
))
a_pt
=
tensor
(
dtype
=
config
.
floatX
,
shape
=
(
None
,
1
))
x
=
a_pt
.
dimshuffle
((
0
,))
x
=
a_pt
.
dimshuffle
((
0
,))
x_fg
=
FunctionGraph
([
a_pt
],
[
x
])
compare_jax_and_py
([
a_pt
],
[
x
],
[
np
.
c_
[[
1.0
,
2.0
,
3.0
,
4.0
]]
.
astype
(
config
.
floatX
)])
compare_jax_and_py
(
x_fg
,
[
np
.
c_
[[
1.0
,
2.0
,
3.0
,
4.0
]]
.
astype
(
config
.
floatX
)])
a_pt
=
tensor
(
dtype
=
config
.
floatX
,
shape
=
(
None
,
1
))
a_pt
=
tensor
(
dtype
=
config
.
floatX
,
shape
=
(
None
,
1
))
x
=
pt_elemwise
.
DimShuffle
(
input_ndim
=
2
,
new_order
=
(
0
,))(
a_pt
)
x
=
pt_elemwise
.
DimShuffle
(
input_ndim
=
2
,
new_order
=
(
0
,))(
a_pt
)
x_fg
=
FunctionGraph
([
a_pt
],
[
x
])
compare_jax_and_py
([
a_pt
],
[
x
],
[
np
.
c_
[[
1.0
,
2.0
,
3.0
,
4.0
]]
.
astype
(
config
.
floatX
)])
compare_jax_and_py
(
x_fg
,
[
np
.
c_
[[
1.0
,
2.0
,
3.0
,
4.0
]]
.
astype
(
config
.
floatX
)])
def
test_jax_CAReduce
():
def
test_jax_CAReduce
():
...
@@ -49,64 +47,58 @@ def test_jax_CAReduce():
...
@@ -49,64 +47,58 @@ def test_jax_CAReduce():
a_pt
.
tag
.
test_value
=
np
.
r_
[
1
,
2
,
3
]
.
astype
(
config
.
floatX
)
a_pt
.
tag
.
test_value
=
np
.
r_
[
1
,
2
,
3
]
.
astype
(
config
.
floatX
)
x
=
pt_sum
(
a_pt
,
axis
=
None
)
x
=
pt_sum
(
a_pt
,
axis
=
None
)
x_fg
=
FunctionGraph
([
a_pt
],
[
x
])
compare_jax_and_py
(
x_fg
,
[
np
.
r_
[
1
,
2
,
3
]
.
astype
(
config
.
floatX
)])
compare_jax_and_py
(
[
a_pt
],
[
x
]
,
[
np
.
r_
[
1
,
2
,
3
]
.
astype
(
config
.
floatX
)])
a_pt
=
matrix
(
"a"
)
a_pt
=
matrix
(
"a"
)
a_pt
.
tag
.
test_value
=
np
.
c_
[[
1
,
2
,
3
],
[
1
,
2
,
3
]]
.
astype
(
config
.
floatX
)
a_pt
.
tag
.
test_value
=
np
.
c_
[[
1
,
2
,
3
],
[
1
,
2
,
3
]]
.
astype
(
config
.
floatX
)
x
=
pt_sum
(
a_pt
,
axis
=
0
)
x
=
pt_sum
(
a_pt
,
axis
=
0
)
x_fg
=
FunctionGraph
([
a_pt
],
[
x
])
compare_jax_and_py
(
x_fg
,
[
np
.
c_
[[
1
,
2
,
3
],
[
1
,
2
,
3
]]
.
astype
(
config
.
floatX
)])
compare_jax_and_py
(
[
a_pt
],
[
x
]
,
[
np
.
c_
[[
1
,
2
,
3
],
[
1
,
2
,
3
]]
.
astype
(
config
.
floatX
)])
x
=
pt_sum
(
a_pt
,
axis
=
1
)
x
=
pt_sum
(
a_pt
,
axis
=
1
)
x_fg
=
FunctionGraph
([
a_pt
],
[
x
])
compare_jax_and_py
(
x_fg
,
[
np
.
c_
[[
1
,
2
,
3
],
[
1
,
2
,
3
]]
.
astype
(
config
.
floatX
)])
compare_jax_and_py
(
[
a_pt
],
[
x
]
,
[
np
.
c_
[[
1
,
2
,
3
],
[
1
,
2
,
3
]]
.
astype
(
config
.
floatX
)])
a_pt
=
matrix
(
"a"
)
a_pt
=
matrix
(
"a"
)
a_pt
.
tag
.
test_value
=
np
.
c_
[[
1
,
2
,
3
],
[
1
,
2
,
3
]]
.
astype
(
config
.
floatX
)
a_pt
.
tag
.
test_value
=
np
.
c_
[[
1
,
2
,
3
],
[
1
,
2
,
3
]]
.
astype
(
config
.
floatX
)
x
=
prod
(
a_pt
,
axis
=
0
)
x
=
prod
(
a_pt
,
axis
=
0
)
x_fg
=
FunctionGraph
([
a_pt
],
[
x
])
compare_jax_and_py
(
x_fg
,
[
np
.
c_
[[
1
,
2
,
3
],
[
1
,
2
,
3
]]
.
astype
(
config
.
floatX
)])
compare_jax_and_py
(
[
a_pt
],
[
x
]
,
[
np
.
c_
[[
1
,
2
,
3
],
[
1
,
2
,
3
]]
.
astype
(
config
.
floatX
)])
x
=
pt_all
(
a_pt
)
x
=
pt_all
(
a_pt
)
x_fg
=
FunctionGraph
([
a_pt
],
[
x
])
compare_jax_and_py
(
x_fg
,
[
np
.
c_
[[
1
,
2
,
3
],
[
1
,
2
,
3
]]
.
astype
(
config
.
floatX
)])
compare_jax_and_py
(
[
a_pt
],
[
x
]
,
[
np
.
c_
[[
1
,
2
,
3
],
[
1
,
2
,
3
]]
.
astype
(
config
.
floatX
)])
@pytest.mark.parametrize
(
"axis"
,
[
None
,
0
,
1
])
@pytest.mark.parametrize
(
"axis"
,
[
None
,
0
,
1
])
def
test_softmax
(
axis
):
def
test_softmax
(
axis
):
x
=
matrix
(
"x"
)
x
=
matrix
(
"x"
)
x
.
tag
.
test_value
=
np
.
arange
(
6
,
dtype
=
config
.
floatX
)
.
reshape
(
2
,
3
)
x
_
test_value
=
np
.
arange
(
6
,
dtype
=
config
.
floatX
)
.
reshape
(
2
,
3
)
out
=
softmax
(
x
,
axis
=
axis
)
out
=
softmax
(
x
,
axis
=
axis
)
fgraph
=
FunctionGraph
([
x
],
[
out
])
compare_jax_and_py
([
x
],
[
out
],
[
x_test_value
])
compare_jax_and_py
(
fgraph
,
[
get_test_value
(
i
)
for
i
in
fgraph
.
inputs
])
@pytest.mark.parametrize
(
"axis"
,
[
None
,
0
,
1
])
@pytest.mark.parametrize
(
"axis"
,
[
None
,
0
,
1
])
def
test_logsoftmax
(
axis
):
def
test_logsoftmax
(
axis
):
x
=
matrix
(
"x"
)
x
=
matrix
(
"x"
)
x
.
tag
.
test_value
=
np
.
arange
(
6
,
dtype
=
config
.
floatX
)
.
reshape
(
2
,
3
)
x
_
test_value
=
np
.
arange
(
6
,
dtype
=
config
.
floatX
)
.
reshape
(
2
,
3
)
out
=
log_softmax
(
x
,
axis
=
axis
)
out
=
log_softmax
(
x
,
axis
=
axis
)
fgraph
=
FunctionGraph
([
x
],
[
out
])
compare_jax_and_py
(
fgraph
,
[
get_test_value
(
i
)
for
i
in
fgraph
.
inputs
])
compare_jax_and_py
(
[
x
],
[
out
],
[
x_test_value
])
@pytest.mark.parametrize
(
"axis"
,
[
None
,
0
,
1
])
@pytest.mark.parametrize
(
"axis"
,
[
None
,
0
,
1
])
def
test_softmax_grad
(
axis
):
def
test_softmax_grad
(
axis
):
dy
=
matrix
(
"dy"
)
dy
=
matrix
(
"dy"
)
dy
.
tag
.
test_value
=
np
.
array
([[
1
,
1
,
1
],
[
0
,
0
,
0
]],
dtype
=
config
.
floatX
)
dy
_
test_value
=
np
.
array
([[
1
,
1
,
1
],
[
0
,
0
,
0
]],
dtype
=
config
.
floatX
)
sm
=
matrix
(
"sm"
)
sm
=
matrix
(
"sm"
)
sm
.
tag
.
test_value
=
np
.
arange
(
6
,
dtype
=
config
.
floatX
)
.
reshape
(
2
,
3
)
sm
_
test_value
=
np
.
arange
(
6
,
dtype
=
config
.
floatX
)
.
reshape
(
2
,
3
)
out
=
SoftmaxGrad
(
axis
=
axis
)(
dy
,
sm
)
out
=
SoftmaxGrad
(
axis
=
axis
)(
dy
,
sm
)
fgraph
=
FunctionGraph
([
dy
,
sm
],
[
out
])
compare_jax_and_py
(
fgraph
,
[
get_test_value
(
i
)
for
i
in
fgraph
.
inputs
])
compare_jax_and_py
(
[
dy
,
sm
],
[
out
],
[
dy_test_value
,
sm_test_value
])
@pytest.mark.parametrize
(
"size"
,
[(
10
,
10
),
(
1000
,
1000
)])
@pytest.mark.parametrize
(
"size"
,
[(
10
,
10
),
(
1000
,
1000
)])
...
@@ -134,6 +126,4 @@ def test_logsumexp_benchmark(size, axis, benchmark):
...
@@ -134,6 +126,4 @@ def test_logsumexp_benchmark(size, axis, benchmark):
def
test_multiple_input_multiply
():
def
test_multiple_input_multiply
():
x
,
y
,
z
=
vectors
(
"xyz"
)
x
,
y
,
z
=
vectors
(
"xyz"
)
out
=
pt
.
mul
(
x
,
y
,
z
)
out
=
pt
.
mul
(
x
,
y
,
z
)
compare_jax_and_py
([
x
,
y
,
z
],
[
out
],
test_inputs
=
[[
1.5
],
[
2.5
],
[
3.5
]])
fg
=
FunctionGraph
(
outputs
=
[
out
],
clone
=
False
)
compare_jax_and_py
(
fg
,
[[
1.5
],
[
2.5
],
[
3.5
]])
tests/link/jax/test_extra_ops.py
浏览文件 @
cc8c4992
...
@@ -3,8 +3,6 @@ import pytest
...
@@ -3,8 +3,6 @@ import pytest
import
pytensor.tensor.basic
as
ptb
import
pytensor.tensor.basic
as
ptb
from
pytensor.configdefaults
import
config
from
pytensor.configdefaults
import
config
from
pytensor.graph.fg
import
FunctionGraph
from
pytensor.graph.op
import
get_test_value
from
pytensor.tensor
import
extra_ops
as
pt_extra_ops
from
pytensor.tensor
import
extra_ops
as
pt_extra_ops
from
pytensor.tensor.sort
import
argsort
from
pytensor.tensor.sort
import
argsort
from
pytensor.tensor.type
import
matrix
,
tensor
from
pytensor.tensor.type
import
matrix
,
tensor
...
@@ -19,57 +17,45 @@ def test_extra_ops():
...
@@ -19,57 +17,45 @@ def test_extra_ops():
a_test
=
np
.
arange
(
6
,
dtype
=
config
.
floatX
)
.
reshape
((
3
,
2
))
a_test
=
np
.
arange
(
6
,
dtype
=
config
.
floatX
)
.
reshape
((
3
,
2
))
out
=
pt_extra_ops
.
cumsum
(
a
,
axis
=
0
)
out
=
pt_extra_ops
.
cumsum
(
a
,
axis
=
0
)
fgraph
=
FunctionGraph
([
a
],
[
out
])
compare_jax_and_py
([
a
],
[
out
],
[
a_test
])
compare_jax_and_py
(
fgraph
,
[
a_test
])
out
=
pt_extra_ops
.
cumprod
(
a
,
axis
=
1
)
out
=
pt_extra_ops
.
cumprod
(
a
,
axis
=
1
)
fgraph
=
FunctionGraph
([
a
],
[
out
])
compare_jax_and_py
([
a
],
[
out
],
[
a_test
])
compare_jax_and_py
(
fgraph
,
[
a_test
])
out
=
pt_extra_ops
.
diff
(
a
,
n
=
2
,
axis
=
1
)
out
=
pt_extra_ops
.
diff
(
a
,
n
=
2
,
axis
=
1
)
fgraph
=
FunctionGraph
([
a
],
[
out
])
compare_jax_and_py
([
a
],
[
out
],
[
a_test
])
compare_jax_and_py
(
fgraph
,
[
a_test
])
out
=
pt_extra_ops
.
repeat
(
a
,
(
3
,
3
),
axis
=
1
)
out
=
pt_extra_ops
.
repeat
(
a
,
(
3
,
3
),
axis
=
1
)
fgraph
=
FunctionGraph
([
a
],
[
out
])
compare_jax_and_py
([
a
],
[
out
],
[
a_test
])
compare_jax_and_py
(
fgraph
,
[
a_test
])
c
=
ptb
.
as_tensor
(
5
)
c
=
ptb
.
as_tensor
(
5
)
out
=
pt_extra_ops
.
fill_diagonal
(
a
,
c
)
out
=
pt_extra_ops
.
fill_diagonal
(
a
,
c
)
fgraph
=
FunctionGraph
([
a
],
[
out
])
compare_jax_and_py
([
a
],
[
out
],
[
a_test
])
compare_jax_and_py
(
fgraph
,
[
a_test
])
with
pytest
.
raises
(
NotImplementedError
):
with
pytest
.
raises
(
NotImplementedError
):
out
=
pt_extra_ops
.
fill_diagonal_offset
(
a
,
c
,
c
)
out
=
pt_extra_ops
.
fill_diagonal_offset
(
a
,
c
,
c
)
fgraph
=
FunctionGraph
([
a
],
[
out
])
compare_jax_and_py
([
a
],
[
out
],
[
a_test
])
compare_jax_and_py
(
fgraph
,
[
a_test
])
with
pytest
.
raises
(
NotImplementedError
):
with
pytest
.
raises
(
NotImplementedError
):
out
=
pt_extra_ops
.
Unique
(
axis
=
1
)(
a
)
out
=
pt_extra_ops
.
Unique
(
axis
=
1
)(
a
)
fgraph
=
FunctionGraph
([
a
],
[
out
])
compare_jax_and_py
([
a
],
[
out
],
[
a_test
])
compare_jax_and_py
(
fgraph
,
[
a_test
])
indices
=
np
.
arange
(
np
.
prod
((
3
,
4
)))
indices
=
np
.
arange
(
np
.
prod
((
3
,
4
)))
out
=
pt_extra_ops
.
unravel_index
(
indices
,
(
3
,
4
),
order
=
"C"
)
out
=
pt_extra_ops
.
unravel_index
(
indices
,
(
3
,
4
),
order
=
"C"
)
fgraph
=
FunctionGraph
([],
out
)
compare_jax_and_py
([],
out
,
[],
must_be_device_array
=
False
)
compare_jax_and_py
(
fgraph
,
[
get_test_value
(
i
)
for
i
in
fgraph
.
inputs
],
must_be_device_array
=
False
)
v
=
ptb
.
as_tensor_variable
(
6.0
)
v
=
ptb
.
as_tensor_variable
(
6.0
)
sorted_idx
=
argsort
(
a
.
ravel
())
sorted_idx
=
argsort
(
a
.
ravel
())
out
=
pt_extra_ops
.
searchsorted
(
a
.
ravel
()[
sorted_idx
],
v
)
out
=
pt_extra_ops
.
searchsorted
(
a
.
ravel
()[
sorted_idx
],
v
)
fgraph
=
FunctionGraph
([
a
],
[
out
])
compare_jax_and_py
([
a
],
[
out
],
[
a_test
])
compare_jax_and_py
(
fgraph
,
[
a_test
])
@pytest.mark.xfail
(
reason
=
"Jitted JAX does not support dynamic shapes"
)
@pytest.mark.xfail
(
reason
=
"Jitted JAX does not support dynamic shapes"
)
def
test_bartlett_dynamic_shape
():
def
test_bartlett_dynamic_shape
():
c
=
tensor
(
shape
=
(),
dtype
=
int
)
c
=
tensor
(
shape
=
(),
dtype
=
int
)
out
=
pt_extra_ops
.
bartlett
(
c
)
out
=
pt_extra_ops
.
bartlett
(
c
)
fgraph
=
FunctionGraph
([],
[
out
])
compare_jax_and_py
([],
[
out
],
[
np
.
array
(
5
)])
compare_jax_and_py
(
fgraph
,
[
np
.
array
(
5
)])
@pytest.mark.xfail
(
reason
=
"Jitted JAX does not support dynamic shapes"
)
@pytest.mark.xfail
(
reason
=
"Jitted JAX does not support dynamic shapes"
)
...
@@ -79,8 +65,7 @@ def test_ravel_multi_index_dynamic_shape():
...
@@ -79,8 +65,7 @@ def test_ravel_multi_index_dynamic_shape():
x
=
tensor
(
shape
=
(
None
,),
dtype
=
int
)
x
=
tensor
(
shape
=
(
None
,),
dtype
=
int
)
y
=
tensor
(
shape
=
(
None
,),
dtype
=
int
)
y
=
tensor
(
shape
=
(
None
,),
dtype
=
int
)
out
=
pt_extra_ops
.
ravel_multi_index
((
x
,
y
),
(
3
,
4
))
out
=
pt_extra_ops
.
ravel_multi_index
((
x
,
y
),
(
3
,
4
))
fgraph
=
FunctionGraph
([],
[
out
])
compare_jax_and_py
([],
[
out
],
[
x_test
,
y_test
])
compare_jax_and_py
(
fgraph
,
[
x_test
,
y_test
])
@pytest.mark.xfail
(
reason
=
"Jitted JAX does not support dynamic shapes"
)
@pytest.mark.xfail
(
reason
=
"Jitted JAX does not support dynamic shapes"
)
...
@@ -89,5 +74,4 @@ def test_unique_dynamic_shape():
...
@@ -89,5 +74,4 @@ def test_unique_dynamic_shape():
a_test
=
np
.
arange
(
6
,
dtype
=
config
.
floatX
)
.
reshape
((
3
,
2
))
a_test
=
np
.
arange
(
6
,
dtype
=
config
.
floatX
)
.
reshape
((
3
,
2
))
out
=
pt_extra_ops
.
Unique
()(
a
)
out
=
pt_extra_ops
.
Unique
()(
a
)
fgraph
=
FunctionGraph
([
a
],
[
out
])
compare_jax_and_py
([
a
],
[
out
],
[
a_test
])
compare_jax_and_py
(
fgraph
,
[
a_test
])
tests/link/jax/test_math.py
浏览文件 @
cc8c4992
...
@@ -2,8 +2,6 @@ import numpy as np
...
@@ -2,8 +2,6 @@ import numpy as np
import
pytest
import
pytest
from
pytensor.configdefaults
import
config
from
pytensor.configdefaults
import
config
from
pytensor.graph.fg
import
FunctionGraph
from
pytensor.graph.op
import
get_test_value
from
pytensor.tensor.math
import
Argmax
,
Max
,
maximum
from
pytensor.tensor.math
import
Argmax
,
Max
,
maximum
from
pytensor.tensor.math
import
max
as
pt_max
from
pytensor.tensor.math
import
max
as
pt_max
from
pytensor.tensor.type
import
dvector
,
matrix
,
scalar
,
vector
from
pytensor.tensor.type
import
dvector
,
matrix
,
scalar
,
vector
...
@@ -20,33 +18,39 @@ def test_jax_max_and_argmax():
...
@@ -20,33 +18,39 @@ def test_jax_max_and_argmax():
mx
=
Max
([
0
])(
x
)
mx
=
Max
([
0
])(
x
)
amx
=
Argmax
([
0
])(
x
)
amx
=
Argmax
([
0
])(
x
)
out
=
mx
*
amx
out
=
mx
*
amx
out_fg
=
FunctionGraph
([
x
],
[
out
])
compare_jax_and_py
([
x
],
[
out
],
[
np
.
r_
[
1
,
2
]])
compare_jax_and_py
(
out_fg
,
[
np
.
r_
[
1
,
2
]])
def
test_dot
():
def
test_dot
():
y
=
vector
(
"y"
)
y
=
vector
(
"y"
)
y
.
tag
.
test_value
=
np
.
r_
[
1.0
,
2.0
]
.
astype
(
config
.
floatX
)
y
_
test_value
=
np
.
r_
[
1.0
,
2.0
]
.
astype
(
config
.
floatX
)
x
=
vector
(
"x"
)
x
=
vector
(
"x"
)
x
.
tag
.
test_value
=
np
.
r_
[
3.0
,
4.0
]
.
astype
(
config
.
floatX
)
x
_
test_value
=
np
.
r_
[
3.0
,
4.0
]
.
astype
(
config
.
floatX
)
A
=
matrix
(
"A"
)
A
=
matrix
(
"A"
)
A
.
tag
.
test_value
=
np
.
empty
((
2
,
2
),
dtype
=
config
.
floatX
)
A
_
test_value
=
np
.
empty
((
2
,
2
),
dtype
=
config
.
floatX
)
alpha
=
scalar
(
"alpha"
)
alpha
=
scalar
(
"alpha"
)
alpha
.
tag
.
test_value
=
np
.
array
(
3.0
,
dtype
=
config
.
floatX
)
alpha
_
test_value
=
np
.
array
(
3.0
,
dtype
=
config
.
floatX
)
beta
=
scalar
(
"beta"
)
beta
=
scalar
(
"beta"
)
beta
.
tag
.
test_value
=
np
.
array
(
5.0
,
dtype
=
config
.
floatX
)
beta
_
test_value
=
np
.
array
(
5.0
,
dtype
=
config
.
floatX
)
# This should be converted into a `Gemv` `Op` when the non-JAX compatible
# This should be converted into a `Gemv` `Op` when the non-JAX compatible
# optimizations are turned on; however, when using JAX mode, it should
# optimizations are turned on; however, when using JAX mode, it should
# leave the expression alone.
# leave the expression alone.
out
=
y
.
dot
(
alpha
*
A
)
.
dot
(
x
)
+
beta
*
y
out
=
y
.
dot
(
alpha
*
A
)
.
dot
(
x
)
+
beta
*
y
fgraph
=
FunctionGraph
([
y
,
x
,
A
,
alpha
,
beta
],
[
out
])
compare_jax_and_py
(
compare_jax_and_py
(
fgraph
,
[
get_test_value
(
i
)
for
i
in
fgraph
.
inputs
])
[
y
,
x
,
A
,
alpha
,
beta
],
out
,
[
y_test_value
,
x_test_value
,
A_test_value
,
alpha_test_value
,
beta_test_value
,
],
)
out
=
maximum
(
y
,
x
)
out
=
maximum
(
y
,
x
)
fgraph
=
FunctionGraph
([
y
,
x
],
[
out
])
compare_jax_and_py
([
y
,
x
],
[
out
],
[
y_test_value
,
x_test_value
])
compare_jax_and_py
(
fgraph
,
[
get_test_value
(
i
)
for
i
in
fgraph
.
inputs
])
out
=
pt_max
(
y
)
out
=
pt_max
(
y
)
fgraph
=
FunctionGraph
([
y
],
[
out
])
compare_jax_and_py
([
y
],
[
out
],
[
y_test_value
])
compare_jax_and_py
(
fgraph
,
[
get_test_value
(
i
)
for
i
in
fgraph
.
inputs
])
tests/link/jax/test_nlinalg.py
浏览文件 @
cc8c4992
...
@@ -3,7 +3,6 @@ import pytest
...
@@ -3,7 +3,6 @@ import pytest
from
pytensor.compile.function
import
function
from
pytensor.compile.function
import
function
from
pytensor.configdefaults
import
config
from
pytensor.configdefaults
import
config
from
pytensor.graph.fg
import
FunctionGraph
from
pytensor.tensor
import
nlinalg
as
pt_nlinalg
from
pytensor.tensor
import
nlinalg
as
pt_nlinalg
from
pytensor.tensor.type
import
matrix
from
pytensor.tensor.type
import
matrix
from
tests.link.jax.test_basic
import
compare_jax_and_py
from
tests.link.jax.test_basic
import
compare_jax_and_py
...
@@ -21,41 +20,34 @@ def test_jax_basic_multiout():
...
@@ -21,41 +20,34 @@ def test_jax_basic_multiout():
x
=
matrix
(
"x"
)
x
=
matrix
(
"x"
)
outs
=
pt_nlinalg
.
eig
(
x
)
outs
=
pt_nlinalg
.
eig
(
x
)
out_fg
=
FunctionGraph
([
x
],
outs
)
def
assert_fn
(
x
,
y
):
def
assert_fn
(
x
,
y
):
np
.
testing
.
assert_allclose
(
x
.
astype
(
config
.
floatX
),
y
,
rtol
=
1e-3
)
np
.
testing
.
assert_allclose
(
x
.
astype
(
config
.
floatX
),
y
,
rtol
=
1e-3
)
compare_jax_and_py
(
out_fg
,
[
X
.
astype
(
config
.
floatX
)],
assert_fn
=
assert_fn
)
compare_jax_and_py
(
[
x
],
outs
,
[
X
.
astype
(
config
.
floatX
)],
assert_fn
=
assert_fn
)
outs
=
pt_nlinalg
.
eigh
(
x
)
outs
=
pt_nlinalg
.
eigh
(
x
)
out_fg
=
FunctionGraph
([
x
],
outs
)
compare_jax_and_py
([
x
],
outs
,
[
X
.
astype
(
config
.
floatX
)],
assert_fn
=
assert_fn
)
compare_jax_and_py
(
out_fg
,
[
X
.
astype
(
config
.
floatX
)],
assert_fn
=
assert_fn
)
outs
=
pt_nlinalg
.
qr
(
x
,
mode
=
"full"
)
outs
=
pt_nlinalg
.
qr
(
x
,
mode
=
"full"
)
out_fg
=
FunctionGraph
([
x
],
outs
)
compare_jax_and_py
([
x
],
outs
,
[
X
.
astype
(
config
.
floatX
)],
assert_fn
=
assert_fn
)
compare_jax_and_py
(
out_fg
,
[
X
.
astype
(
config
.
floatX
)],
assert_fn
=
assert_fn
)
outs
=
pt_nlinalg
.
qr
(
x
,
mode
=
"reduced"
)
outs
=
pt_nlinalg
.
qr
(
x
,
mode
=
"reduced"
)
out_fg
=
FunctionGraph
([
x
],
outs
)
compare_jax_and_py
([
x
],
outs
,
[
X
.
astype
(
config
.
floatX
)],
assert_fn
=
assert_fn
)
compare_jax_and_py
(
out_fg
,
[
X
.
astype
(
config
.
floatX
)],
assert_fn
=
assert_fn
)
outs
=
pt_nlinalg
.
svd
(
x
)
outs
=
pt_nlinalg
.
svd
(
x
)
out_fg
=
FunctionGraph
([
x
],
outs
)
compare_jax_and_py
([
x
],
outs
,
[
X
.
astype
(
config
.
floatX
)],
assert_fn
=
assert_fn
)
compare_jax_and_py
(
out_fg
,
[
X
.
astype
(
config
.
floatX
)],
assert_fn
=
assert_fn
)
outs
=
pt_nlinalg
.
slogdet
(
x
)
outs
=
pt_nlinalg
.
slogdet
(
x
)
out_fg
=
FunctionGraph
([
x
],
outs
)
compare_jax_and_py
([
x
],
outs
,
[
X
.
astype
(
config
.
floatX
)],
assert_fn
=
assert_fn
)
compare_jax_and_py
(
out_fg
,
[
X
.
astype
(
config
.
floatX
)],
assert_fn
=
assert_fn
)
def
test_pinv
():
def
test_pinv
():
x
=
matrix
(
"x"
)
x
=
matrix
(
"x"
)
x_inv
=
pt_nlinalg
.
pinv
(
x
)
x_inv
=
pt_nlinalg
.
pinv
(
x
)
fgraph
=
FunctionGraph
([
x
],
[
x_inv
])
x_np
=
np
.
array
([[
1.0
,
2.0
],
[
3.0
,
4.0
]],
dtype
=
config
.
floatX
)
x_np
=
np
.
array
([[
1.0
,
2.0
],
[
3.0
,
4.0
]],
dtype
=
config
.
floatX
)
compare_jax_and_py
(
fgraph
,
[
x_np
])
compare_jax_and_py
(
[
x
],
[
x_inv
]
,
[
x_np
])
def
test_pinv_hermitian
():
def
test_pinv_hermitian
():
...
@@ -94,8 +86,7 @@ def test_kron():
...
@@ -94,8 +86,7 @@ def test_kron():
y
=
matrix
(
"y"
)
y
=
matrix
(
"y"
)
z
=
pt_nlinalg
.
kron
(
x
,
y
)
z
=
pt_nlinalg
.
kron
(
x
,
y
)
fgraph
=
FunctionGraph
([
x
,
y
],
[
z
])
x_np
=
np
.
array
([[
1.0
,
2.0
],
[
3.0
,
4.0
]],
dtype
=
config
.
floatX
)
x_np
=
np
.
array
([[
1.0
,
2.0
],
[
3.0
,
4.0
]],
dtype
=
config
.
floatX
)
y_np
=
np
.
array
([[
1.0
,
2.0
],
[
3.0
,
4.0
]],
dtype
=
config
.
floatX
)
y_np
=
np
.
array
([[
1.0
,
2.0
],
[
3.0
,
4.0
]],
dtype
=
config
.
floatX
)
compare_jax_and_py
(
fgraph
,
[
x_np
,
y_np
])
compare_jax_and_py
(
[
x
,
y
],
[
z
]
,
[
x_np
,
y_np
])
tests/link/jax/test_pad.py
浏览文件 @
cc8c4992
...
@@ -3,7 +3,6 @@ import pytest
...
@@ -3,7 +3,6 @@ import pytest
import
pytensor.tensor
as
pt
import
pytensor.tensor
as
pt
from
pytensor
import
config
from
pytensor
import
config
from
pytensor.graph
import
FunctionGraph
from
pytensor.tensor.pad
import
PadMode
from
pytensor.tensor.pad
import
PadMode
from
tests.link.jax.test_basic
import
compare_jax_and_py
from
tests.link.jax.test_basic
import
compare_jax_and_py
...
@@ -53,10 +52,10 @@ def test_jax_pad(mode: PadMode, kwargs):
...
@@ -53,10 +52,10 @@ def test_jax_pad(mode: PadMode, kwargs):
x
=
np
.
random
.
normal
(
size
=
(
3
,
3
))
x
=
np
.
random
.
normal
(
size
=
(
3
,
3
))
res
=
pt
.
pad
(
x_pt
,
mode
=
mode
,
pad_width
=
3
,
**
kwargs
)
res
=
pt
.
pad
(
x_pt
,
mode
=
mode
,
pad_width
=
3
,
**
kwargs
)
res_fg
=
FunctionGraph
([
x_pt
],
[
res
])
compare_jax_and_py
(
compare_jax_and_py
(
res_fg
,
[
x_pt
],
[
res
],
[
x
],
[
x
],
assert_fn
=
lambda
x
,
y
:
np
.
testing
.
assert_allclose
(
x
,
y
,
rtol
=
RTOL
,
atol
=
ATOL
),
assert_fn
=
lambda
x
,
y
:
np
.
testing
.
assert_allclose
(
x
,
y
,
rtol
=
RTOL
,
atol
=
ATOL
),
py_mode
=
"FAST_RUN"
,
py_mode
=
"FAST_RUN"
,
...
...
tests/link/jax/test_random.py
浏览文件 @
cc8c4992
...
@@ -7,13 +7,11 @@ import pytensor.tensor as pt
...
@@ -7,13 +7,11 @@ import pytensor.tensor as pt
import
pytensor.tensor.random.basic
as
ptr
import
pytensor.tensor.random.basic
as
ptr
from
pytensor
import
clone_replace
from
pytensor
import
clone_replace
from
pytensor.compile.function
import
function
from
pytensor.compile.function
import
function
from
pytensor.compile.sharedvalue
import
SharedVariable
,
shared
from
pytensor.compile.sharedvalue
import
shared
from
pytensor.graph.basic
import
Constant
from
pytensor.graph.fg
import
FunctionGraph
from
pytensor.tensor.random.basic
import
RandomVariable
from
pytensor.tensor.random.basic
import
RandomVariable
from
pytensor.tensor.random.type
import
RandomType
from
pytensor.tensor.random.type
import
RandomType
from
pytensor.tensor.random.utils
import
RandomStream
from
pytensor.tensor.random.utils
import
RandomStream
from
tests.link.jax.test_basic
import
compare_jax_and_py
,
jax_mode
,
set_test_value
from
tests.link.jax.test_basic
import
compare_jax_and_py
,
jax_mode
from
tests.tensor.random.test_basic
import
(
from
tests.tensor.random.test_basic
import
(
batched_permutation_tester
,
batched_permutation_tester
,
batched_unweighted_choice_without_replacement_tester
,
batched_unweighted_choice_without_replacement_tester
,
...
@@ -147,11 +145,11 @@ def test_replaced_shared_rng_storage_ordering_equality():
...
@@ -147,11 +145,11 @@ def test_replaced_shared_rng_storage_ordering_equality():
(
(
ptr
.
beta
,
ptr
.
beta
,
[
[
set_test_value
(
(
pt
.
dvector
(),
pt
.
dvector
(),
np
.
array
([
1.0
,
2.0
],
dtype
=
np
.
float64
),
np
.
array
([
1.0
,
2.0
],
dtype
=
np
.
float64
),
),
),
set_test_value
(
(
pt
.
dscalar
(),
pt
.
dscalar
(),
np
.
array
(
1.0
,
dtype
=
np
.
float64
),
np
.
array
(
1.0
,
dtype
=
np
.
float64
),
),
),
...
@@ -163,11 +161,11 @@ def test_replaced_shared_rng_storage_ordering_equality():
...
@@ -163,11 +161,11 @@ def test_replaced_shared_rng_storage_ordering_equality():
(
(
ptr
.
cauchy
,
ptr
.
cauchy
,
[
[
set_test_value
(
(
pt
.
dvector
(),
pt
.
dvector
(),
np
.
array
([
1.0
,
2.0
],
dtype
=
np
.
float64
),
np
.
array
([
1.0
,
2.0
],
dtype
=
np
.
float64
),
),
),
set_test_value
(
(
pt
.
dscalar
(),
pt
.
dscalar
(),
np
.
array
(
1.0
,
dtype
=
np
.
float64
),
np
.
array
(
1.0
,
dtype
=
np
.
float64
),
),
),
...
@@ -179,7 +177,7 @@ def test_replaced_shared_rng_storage_ordering_equality():
...
@@ -179,7 +177,7 @@ def test_replaced_shared_rng_storage_ordering_equality():
(
(
ptr
.
exponential
,
ptr
.
exponential
,
[
[
set_test_value
(
(
pt
.
dvector
(),
pt
.
dvector
(),
np
.
array
([
1.0
,
2.0
],
dtype
=
np
.
float64
),
np
.
array
([
1.0
,
2.0
],
dtype
=
np
.
float64
),
),
),
...
@@ -191,11 +189,11 @@ def test_replaced_shared_rng_storage_ordering_equality():
...
@@ -191,11 +189,11 @@ def test_replaced_shared_rng_storage_ordering_equality():
(
(
ptr
.
_gamma
,
ptr
.
_gamma
,
[
[
set_test_value
(
(
pt
.
dvector
(),
pt
.
dvector
(),
np
.
array
([
1.0
,
2.0
],
dtype
=
np
.
float64
),
np
.
array
([
1.0
,
2.0
],
dtype
=
np
.
float64
),
),
),
set_test_value
(
(
pt
.
dvector
(),
pt
.
dvector
(),
np
.
array
([
0.5
,
3.0
],
dtype
=
np
.
float64
),
np
.
array
([
0.5
,
3.0
],
dtype
=
np
.
float64
),
),
),
...
@@ -207,11 +205,11 @@ def test_replaced_shared_rng_storage_ordering_equality():
...
@@ -207,11 +205,11 @@ def test_replaced_shared_rng_storage_ordering_equality():
(
(
ptr
.
gumbel
,
ptr
.
gumbel
,
[
[
set_test_value
(
(
pt
.
lvector
(),
pt
.
lvector
(),
np
.
array
([
1
,
2
],
dtype
=
np
.
int64
),
np
.
array
([
1
,
2
],
dtype
=
np
.
int64
),
),
),
set_test_value
(
(
pt
.
dscalar
(),
pt
.
dscalar
(),
np
.
array
(
1.0
,
dtype
=
np
.
float64
),
np
.
array
(
1.0
,
dtype
=
np
.
float64
),
),
),
...
@@ -223,8 +221,8 @@ def test_replaced_shared_rng_storage_ordering_equality():
...
@@ -223,8 +221,8 @@ def test_replaced_shared_rng_storage_ordering_equality():
(
(
ptr
.
laplace
,
ptr
.
laplace
,
[
[
set_test_value
(
pt
.
dvector
(),
np
.
array
([
1.0
,
2.0
],
dtype
=
np
.
float64
)),
(
pt
.
dvector
(),
np
.
array
([
1.0
,
2.0
],
dtype
=
np
.
float64
)),
set_test_value
(
pt
.
dscalar
(),
np
.
array
(
1.0
,
dtype
=
np
.
float64
)),
(
pt
.
dscalar
(),
np
.
array
(
1.0
,
dtype
=
np
.
float64
)),
],
],
(
2
,),
(
2
,),
"laplace"
,
"laplace"
,
...
@@ -233,11 +231,11 @@ def test_replaced_shared_rng_storage_ordering_equality():
...
@@ -233,11 +231,11 @@ def test_replaced_shared_rng_storage_ordering_equality():
(
(
ptr
.
logistic
,
ptr
.
logistic
,
[
[
set_test_value
(
(
pt
.
dvector
(),
pt
.
dvector
(),
np
.
array
([
1.0
,
2.0
],
dtype
=
np
.
float64
),
np
.
array
([
1.0
,
2.0
],
dtype
=
np
.
float64
),
),
),
set_test_value
(
(
pt
.
dscalar
(),
pt
.
dscalar
(),
np
.
array
(
1.0
,
dtype
=
np
.
float64
),
np
.
array
(
1.0
,
dtype
=
np
.
float64
),
),
),
...
@@ -249,11 +247,11 @@ def test_replaced_shared_rng_storage_ordering_equality():
...
@@ -249,11 +247,11 @@ def test_replaced_shared_rng_storage_ordering_equality():
(
(
ptr
.
lognormal
,
ptr
.
lognormal
,
[
[
set_test_value
(
(
pt
.
lvector
(),
pt
.
lvector
(),
np
.
array
([
0
,
0
],
dtype
=
np
.
int64
),
np
.
array
([
0
,
0
],
dtype
=
np
.
int64
),
),
),
set_test_value
(
(
pt
.
dscalar
(),
pt
.
dscalar
(),
np
.
array
(
1.0
,
dtype
=
np
.
float64
),
np
.
array
(
1.0
,
dtype
=
np
.
float64
),
),
),
...
@@ -265,11 +263,11 @@ def test_replaced_shared_rng_storage_ordering_equality():
...
@@ -265,11 +263,11 @@ def test_replaced_shared_rng_storage_ordering_equality():
(
(
ptr
.
normal
,
ptr
.
normal
,
[
[
set_test_value
(
(
pt
.
lvector
(),
pt
.
lvector
(),
np
.
array
([
1
,
2
],
dtype
=
np
.
int64
),
np
.
array
([
1
,
2
],
dtype
=
np
.
int64
),
),
),
set_test_value
(
(
pt
.
dscalar
(),
pt
.
dscalar
(),
np
.
array
(
1.0
,
dtype
=
np
.
float64
),
np
.
array
(
1.0
,
dtype
=
np
.
float64
),
),
),
...
@@ -281,11 +279,11 @@ def test_replaced_shared_rng_storage_ordering_equality():
...
@@ -281,11 +279,11 @@ def test_replaced_shared_rng_storage_ordering_equality():
(
(
ptr
.
pareto
,
ptr
.
pareto
,
[
[
set_test_value
(
(
pt
.
dvector
(),
pt
.
dvector
(),
np
.
array
([
1.0
,
2.0
],
dtype
=
np
.
float64
),
np
.
array
([
1.0
,
2.0
],
dtype
=
np
.
float64
),
),
),
set_test_value
(
(
pt
.
dvector
(),
pt
.
dvector
(),
np
.
array
([
2.0
,
10.0
],
dtype
=
np
.
float64
),
np
.
array
([
2.0
,
10.0
],
dtype
=
np
.
float64
),
),
),
...
@@ -297,7 +295,7 @@ def test_replaced_shared_rng_storage_ordering_equality():
...
@@ -297,7 +295,7 @@ def test_replaced_shared_rng_storage_ordering_equality():
(
(
ptr
.
poisson
,
ptr
.
poisson
,
[
[
set_test_value
(
(
pt
.
dvector
(),
pt
.
dvector
(),
np
.
array
([
100000.0
,
200000.0
],
dtype
=
np
.
float64
),
np
.
array
([
100000.0
,
200000.0
],
dtype
=
np
.
float64
),
),
),
...
@@ -309,11 +307,11 @@ def test_replaced_shared_rng_storage_ordering_equality():
...
@@ -309,11 +307,11 @@ def test_replaced_shared_rng_storage_ordering_equality():
(
(
ptr
.
integers
,
ptr
.
integers
,
[
[
set_test_value
(
(
pt
.
lscalar
(),
pt
.
lscalar
(),
np
.
array
(
0
,
dtype
=
np
.
int64
),
np
.
array
(
0
,
dtype
=
np
.
int64
),
),
),
set_test_value
(
# high-value necessary since test on cdf
(
# high-value necessary since test on cdf
pt
.
lscalar
(),
pt
.
lscalar
(),
np
.
array
(
1000
,
dtype
=
np
.
int64
),
np
.
array
(
1000
,
dtype
=
np
.
int64
),
),
),
...
@@ -332,15 +330,15 @@ def test_replaced_shared_rng_storage_ordering_equality():
...
@@ -332,15 +330,15 @@ def test_replaced_shared_rng_storage_ordering_equality():
(
(
ptr
.
t
,
ptr
.
t
,
[
[
set_test_value
(
(
pt
.
dscalar
(),
pt
.
dscalar
(),
np
.
array
(
2.0
,
dtype
=
np
.
float64
),
np
.
array
(
2.0
,
dtype
=
np
.
float64
),
),
),
set_test_value
(
(
pt
.
dvector
(),
pt
.
dvector
(),
np
.
array
([
1.0
,
2.0
],
dtype
=
np
.
float64
),
np
.
array
([
1.0
,
2.0
],
dtype
=
np
.
float64
),
),
),
set_test_value
(
(
pt
.
dscalar
(),
pt
.
dscalar
(),
np
.
array
(
1.0
,
dtype
=
np
.
float64
),
np
.
array
(
1.0
,
dtype
=
np
.
float64
),
),
),
...
@@ -352,11 +350,11 @@ def test_replaced_shared_rng_storage_ordering_equality():
...
@@ -352,11 +350,11 @@ def test_replaced_shared_rng_storage_ordering_equality():
(
(
ptr
.
uniform
,
ptr
.
uniform
,
[
[
set_test_value
(
(
pt
.
dvector
(),
pt
.
dvector
(),
np
.
array
([
1.0
,
2.0
],
dtype
=
np
.
float64
),
np
.
array
([
1.0
,
2.0
],
dtype
=
np
.
float64
),
),
),
set_test_value
(
(
pt
.
dscalar
(),
pt
.
dscalar
(),
np
.
array
(
1000.0
,
dtype
=
np
.
float64
),
np
.
array
(
1000.0
,
dtype
=
np
.
float64
),
),
),
...
@@ -368,11 +366,11 @@ def test_replaced_shared_rng_storage_ordering_equality():
...
@@ -368,11 +366,11 @@ def test_replaced_shared_rng_storage_ordering_equality():
(
(
ptr
.
halfnormal
,
ptr
.
halfnormal
,
[
[
set_test_value
(
(
pt
.
dvector
(),
pt
.
dvector
(),
np
.
array
([
-
1.0
,
200.0
],
dtype
=
np
.
float64
),
np
.
array
([
-
1.0
,
200.0
],
dtype
=
np
.
float64
),
),
),
set_test_value
(
(
pt
.
dscalar
(),
pt
.
dscalar
(),
np
.
array
(
1000.0
,
dtype
=
np
.
float64
),
np
.
array
(
1000.0
,
dtype
=
np
.
float64
),
),
),
...
@@ -384,11 +382,11 @@ def test_replaced_shared_rng_storage_ordering_equality():
...
@@ -384,11 +382,11 @@ def test_replaced_shared_rng_storage_ordering_equality():
(
(
ptr
.
invgamma
,
ptr
.
invgamma
,
[
[
set_test_value
(
(
pt
.
dvector
(),
pt
.
dvector
(),
np
.
array
([
10.4
,
2.8
],
dtype
=
np
.
float64
),
np
.
array
([
10.4
,
2.8
],
dtype
=
np
.
float64
),
),
),
set_test_value
(
(
pt
.
dvector
(),
pt
.
dvector
(),
np
.
array
([
3.4
,
7.3
],
dtype
=
np
.
float64
),
np
.
array
([
3.4
,
7.3
],
dtype
=
np
.
float64
),
),
),
...
@@ -400,7 +398,7 @@ def test_replaced_shared_rng_storage_ordering_equality():
...
@@ -400,7 +398,7 @@ def test_replaced_shared_rng_storage_ordering_equality():
(
(
ptr
.
chisquare
,
ptr
.
chisquare
,
[
[
set_test_value
(
(
pt
.
dvector
(),
pt
.
dvector
(),
np
.
array
([
2.4
,
4.9
],
dtype
=
np
.
float64
),
np
.
array
([
2.4
,
4.9
],
dtype
=
np
.
float64
),
),
),
...
@@ -412,15 +410,15 @@ def test_replaced_shared_rng_storage_ordering_equality():
...
@@ -412,15 +410,15 @@ def test_replaced_shared_rng_storage_ordering_equality():
(
(
ptr
.
gengamma
,
ptr
.
gengamma
,
[
[
set_test_value
(
(
pt
.
dvector
(),
pt
.
dvector
(),
np
.
array
([
10.4
,
2.8
],
dtype
=
np
.
float64
),
np
.
array
([
10.4
,
2.8
],
dtype
=
np
.
float64
),
),
),
set_test_value
(
(
pt
.
dvector
(),
pt
.
dvector
(),
np
.
array
([
3.4
,
7.3
],
dtype
=
np
.
float64
),
np
.
array
([
3.4
,
7.3
],
dtype
=
np
.
float64
),
),
),
set_test_value
(
(
pt
.
dvector
(),
pt
.
dvector
(),
np
.
array
([
0.9
,
2.0
],
dtype
=
np
.
float64
),
np
.
array
([
0.9
,
2.0
],
dtype
=
np
.
float64
),
),
),
...
@@ -432,11 +430,11 @@ def test_replaced_shared_rng_storage_ordering_equality():
...
@@ -432,11 +430,11 @@ def test_replaced_shared_rng_storage_ordering_equality():
(
(
ptr
.
wald
,
ptr
.
wald
,
[
[
set_test_value
(
(
pt
.
dvector
(),
pt
.
dvector
(),
np
.
array
([
10.4
,
2.8
],
dtype
=
np
.
float64
),
np
.
array
([
10.4
,
2.8
],
dtype
=
np
.
float64
),
),
),
set_test_value
(
(
pt
.
dvector
(),
pt
.
dvector
(),
np
.
array
([
4.5
,
2.0
],
dtype
=
np
.
float64
),
np
.
array
([
4.5
,
2.0
],
dtype
=
np
.
float64
),
),
),
...
@@ -449,11 +447,11 @@ def test_replaced_shared_rng_storage_ordering_equality():
...
@@ -449,11 +447,11 @@ def test_replaced_shared_rng_storage_ordering_equality():
pytest
.
param
(
pytest
.
param
(
ptr
.
vonmises
,
ptr
.
vonmises
,
[
[
set_test_value
(
(
pt
.
dvector
(),
pt
.
dvector
(),
np
.
array
([
-
0.5
,
1.3
],
dtype
=
np
.
float64
),
np
.
array
([
-
0.5
,
1.3
],
dtype
=
np
.
float64
),
),
),
set_test_value
(
(
pt
.
dvector
(),
pt
.
dvector
(),
np
.
array
([
5.5
,
13.0
],
dtype
=
np
.
float64
),
np
.
array
([
5.5
,
13.0
],
dtype
=
np
.
float64
),
),
),
...
@@ -478,20 +476,16 @@ def test_random_RandomVariable(rv_op, dist_params, base_size, cdf_name, params_c
...
@@ -478,20 +476,16 @@ def test_random_RandomVariable(rv_op, dist_params, base_size, cdf_name, params_c
The transpiled `RandomVariable` `Op`.
The transpiled `RandomVariable` `Op`.
dist_params
dist_params
The parameters passed to the op.
The parameters passed to the op.
"""
"""
dist_params
,
test_values
=
(
zip
(
*
dist_params
,
strict
=
True
)
if
dist_params
else
([],
[])
)
rng
=
shared
(
np
.
random
.
default_rng
(
29403
))
rng
=
shared
(
np
.
random
.
default_rng
(
29403
))
g
=
rv_op
(
*
dist_params
,
size
=
(
10000
,
*
base_size
),
rng
=
rng
)
g
=
rv_op
(
*
dist_params
,
size
=
(
10000
,
*
base_size
),
rng
=
rng
)
g_fn
=
compile_random_function
(
dist_params
,
g
,
mode
=
jax_mode
)
g_fn
=
compile_random_function
(
dist_params
,
g
,
mode
=
jax_mode
)
samples
=
g_fn
(
samples
=
g_fn
(
*
test_values
)
*
[
i
.
tag
.
test_value
for
i
in
g_fn
.
maker
.
fgraph
.
inputs
if
not
isinstance
(
i
,
SharedVariable
|
Constant
)
]
)
bcast_dist_args
=
np
.
broadcast_arrays
(
*
[
i
.
tag
.
test_value
for
i
in
dist_params
]
)
bcast_dist_args
=
np
.
broadcast_arrays
(
*
test_values
)
for
idx
in
np
.
ndindex
(
*
base_size
):
for
idx
in
np
.
ndindex
(
*
base_size
):
cdf_params
=
params_conv
(
*
(
arg
[
idx
]
for
arg
in
bcast_dist_args
))
cdf_params
=
params_conv
(
*
(
arg
[
idx
]
for
arg
in
bcast_dist_args
))
...
@@ -775,13 +769,12 @@ def test_random_unimplemented():
...
@@ -775,13 +769,12 @@ def test_random_unimplemented():
nonexistentrv
=
NonExistentRV
()
nonexistentrv
=
NonExistentRV
()
rng
=
shared
(
np
.
random
.
default_rng
(
123
))
rng
=
shared
(
np
.
random
.
default_rng
(
123
))
out
=
nonexistentrv
(
rng
=
rng
)
out
=
nonexistentrv
(
rng
=
rng
)
fgraph
=
FunctionGraph
([
out
.
owner
.
inputs
[
0
]],
[
out
],
clone
=
False
)
with
pytest
.
raises
(
NotImplementedError
):
with
pytest
.
raises
(
NotImplementedError
):
with
pytest
.
warns
(
with
pytest
.
warns
(
UserWarning
,
match
=
r"The RandomType SharedVariables \[.+\] will not be used"
UserWarning
,
match
=
r"The RandomType SharedVariables \[.+\] will not be used"
):
):
compare_jax_and_py
(
fgraph
,
[])
compare_jax_and_py
(
[],
[
out
]
,
[])
def
test_random_custom_implementation
():
def
test_random_custom_implementation
():
...
@@ -810,11 +803,10 @@ def test_random_custom_implementation():
...
@@ -810,11 +803,10 @@ def test_random_custom_implementation():
nonexistentrv
=
CustomRV
()
nonexistentrv
=
CustomRV
()
rng
=
shared
(
np
.
random
.
default_rng
(
123
))
rng
=
shared
(
np
.
random
.
default_rng
(
123
))
out
=
nonexistentrv
(
rng
=
rng
)
out
=
nonexistentrv
(
rng
=
rng
)
fgraph
=
FunctionGraph
([
out
.
owner
.
inputs
[
0
]],
[
out
],
clone
=
False
)
with
pytest
.
warns
(
with
pytest
.
warns
(
UserWarning
,
match
=
r"The RandomType SharedVariables \[.+\] will not be used"
UserWarning
,
match
=
r"The RandomType SharedVariables \[.+\] will not be used"
):
):
compare_jax_and_py
(
fgraph
,
[])
compare_jax_and_py
(
[],
[
out
]
,
[])
def
test_random_concrete_shape
():
def
test_random_concrete_shape
():
...
...
tests/link/jax/test_scalar.py
浏览文件 @
cc8c4992
...
@@ -5,7 +5,6 @@ import pytensor.scalar.basic as ps
...
@@ -5,7 +5,6 @@ import pytensor.scalar.basic as ps
import
pytensor.tensor
as
pt
import
pytensor.tensor
as
pt
from
pytensor.configdefaults
import
config
from
pytensor.configdefaults
import
config
from
pytensor.graph.fg
import
FunctionGraph
from
pytensor.graph.fg
import
FunctionGraph
from
pytensor.graph.op
import
get_test_value
from
pytensor.scalar.basic
import
Composite
from
pytensor.scalar.basic
import
Composite
from
pytensor.tensor
import
as_tensor
from
pytensor.tensor
import
as_tensor
from
pytensor.tensor.elemwise
import
Elemwise
from
pytensor.tensor.elemwise
import
Elemwise
...
@@ -51,20 +50,19 @@ def test_second():
...
@@ -51,20 +50,19 @@ def test_second():
b
=
scalar
(
"b"
)
b
=
scalar
(
"b"
)
out
=
ps
.
second
(
a0
,
b
)
out
=
ps
.
second
(
a0
,
b
)
fgraph
=
FunctionGraph
([
a0
,
b
],
[
out
])
compare_jax_and_py
([
a0
,
b
],
[
out
],
[
10.0
,
5.0
])
compare_jax_and_py
(
fgraph
,
[
10.0
,
5.0
])
a1
=
vector
(
"a1"
)
a1
=
vector
(
"a1"
)
out
=
pt
.
second
(
a1
,
b
)
out
=
pt
.
second
(
a1
,
b
)
fgraph
=
FunctionGraph
([
a1
,
b
],
[
out
])
compare_jax_and_py
([
a1
,
b
],
[
out
],
[
np
.
zeros
([
5
],
dtype
=
config
.
floatX
),
5.0
])
compare_jax_and_py
(
fgraph
,
[
np
.
zeros
([
5
],
dtype
=
config
.
floatX
),
5.0
])
a2
=
matrix
(
"a2"
,
shape
=
(
1
,
None
),
dtype
=
"float64"
)
a2
=
matrix
(
"a2"
,
shape
=
(
1
,
None
),
dtype
=
"float64"
)
b2
=
matrix
(
"b2"
,
shape
=
(
None
,
1
),
dtype
=
"int32"
)
b2
=
matrix
(
"b2"
,
shape
=
(
None
,
1
),
dtype
=
"int32"
)
out
=
pt
.
second
(
a2
,
b2
)
out
=
pt
.
second
(
a2
,
b2
)
fgraph
=
FunctionGraph
([
a2
,
b2
],
[
out
])
compare_jax_and_py
(
compare_jax_and_py
(
fgraph
,
[
np
.
zeros
((
1
,
3
),
dtype
=
"float64"
),
np
.
ones
((
5
,
1
),
dtype
=
"int32"
)]
[
a2
,
b2
],
[
out
],
[
np
.
zeros
((
1
,
3
),
dtype
=
"float64"
),
np
.
ones
((
5
,
1
),
dtype
=
"int32"
)],
)
)
...
@@ -81,11 +79,10 @@ def test_second_constant_scalar():
...
@@ -81,11 +79,10 @@ def test_second_constant_scalar():
def
test_identity
():
def
test_identity
():
a
=
scalar
(
"a"
)
a
=
scalar
(
"a"
)
a
.
tag
.
test_value
=
10
a
_
test_value
=
10
out
=
ps
.
identity
(
a
)
out
=
ps
.
identity
(
a
)
fgraph
=
FunctionGraph
([
a
],
[
out
])
compare_jax_and_py
([
a
],
[
out
],
[
a_test_value
])
compare_jax_and_py
(
fgraph
,
[
get_test_value
(
i
)
for
i
in
fgraph
.
inputs
])
@pytest.mark.parametrize
(
@pytest.mark.parametrize
(
...
@@ -109,13 +106,11 @@ def test_jax_Composite_singe_output(x, y, x_val, y_val):
...
@@ -109,13 +106,11 @@ def test_jax_Composite_singe_output(x, y, x_val, y_val):
out
=
comp_op
(
x
,
y
)
out
=
comp_op
(
x
,
y
)
out_fg
=
FunctionGraph
([
x
,
y
],
[
out
])
test_input_vals
=
[
test_input_vals
=
[
x_val
.
astype
(
config
.
floatX
),
x_val
.
astype
(
config
.
floatX
),
y_val
.
astype
(
config
.
floatX
),
y_val
.
astype
(
config
.
floatX
),
]
]
_
=
compare_jax_and_py
(
out_fg
,
test_input_vals
)
_
=
compare_jax_and_py
(
[
x
,
y
],
[
out
]
,
test_input_vals
)
def
test_jax_Composite_multi_output
():
def
test_jax_Composite_multi_output
():
...
@@ -124,32 +119,28 @@ def test_jax_Composite_multi_output():
...
@@ -124,32 +119,28 @@ def test_jax_Composite_multi_output():
x_s
=
ps
.
float64
(
"xs"
)
x_s
=
ps
.
float64
(
"xs"
)
outs
=
Elemwise
(
Composite
(
inputs
=
[
x_s
],
outputs
=
[
x_s
+
1
,
x_s
-
1
]))(
x
)
outs
=
Elemwise
(
Composite
(
inputs
=
[
x_s
],
outputs
=
[
x_s
+
1
,
x_s
-
1
]))(
x
)
fgraph
=
FunctionGraph
([
x
],
outs
)
compare_jax_and_py
([
x
],
outs
,
[
np
.
arange
(
10
,
dtype
=
config
.
floatX
)])
compare_jax_and_py
(
fgraph
,
[
np
.
arange
(
10
,
dtype
=
config
.
floatX
)])
def
test_erf
():
def
test_erf
():
x
=
scalar
(
"x"
)
x
=
scalar
(
"x"
)
out
=
erf
(
x
)
out
=
erf
(
x
)
fg
=
FunctionGraph
([
x
],
[
out
])
compare_jax_and_py
(
fg
,
[
1.0
])
compare_jax_and_py
(
[
x
],
[
out
]
,
[
1.0
])
def
test_erfc
():
def
test_erfc
():
x
=
scalar
(
"x"
)
x
=
scalar
(
"x"
)
out
=
erfc
(
x
)
out
=
erfc
(
x
)
fg
=
FunctionGraph
([
x
],
[
out
])
compare_jax_and_py
(
fg
,
[
1.0
])
compare_jax_and_py
(
[
x
],
[
out
]
,
[
1.0
])
def
test_erfinv
():
def
test_erfinv
():
x
=
scalar
(
"x"
)
x
=
scalar
(
"x"
)
out
=
erfinv
(
x
)
out
=
erfinv
(
x
)
fg
=
FunctionGraph
([
x
],
[
out
])
compare_jax_and_py
(
fg
,
[
0.95
])
compare_jax_and_py
(
[
x
],
[
out
]
,
[
0.95
])
@pytest.mark.parametrize
(
@pytest.mark.parametrize
(
...
@@ -166,8 +157,7 @@ def test_tfp_ops(op, test_values):
...
@@ -166,8 +157,7 @@ def test_tfp_ops(op, test_values):
inputs
=
[
as_tensor
(
test_value
)
.
type
()
for
test_value
in
test_values
]
inputs
=
[
as_tensor
(
test_value
)
.
type
()
for
test_value
in
test_values
]
output
=
op
(
*
inputs
)
output
=
op
(
*
inputs
)
fg
=
FunctionGraph
(
inputs
,
[
output
])
compare_jax_and_py
(
inputs
,
[
output
],
test_values
)
compare_jax_and_py
(
fg
,
test_values
)
def
test_betaincinv
():
def
test_betaincinv
():
...
@@ -175,9 +165,10 @@ def test_betaincinv():
...
@@ -175,9 +165,10 @@ def test_betaincinv():
b
=
vector
(
"b"
,
dtype
=
"float64"
)
b
=
vector
(
"b"
,
dtype
=
"float64"
)
x
=
vector
(
"x"
,
dtype
=
"float64"
)
x
=
vector
(
"x"
,
dtype
=
"float64"
)
out
=
betaincinv
(
a
,
b
,
x
)
out
=
betaincinv
(
a
,
b
,
x
)
fg
=
FunctionGraph
([
a
,
b
,
x
],
[
out
])
compare_jax_and_py
(
compare_jax_and_py
(
fg
,
[
a
,
b
,
x
],
[
out
],
[
[
np
.
array
([
5.5
,
7.0
]),
np
.
array
([
5.5
,
7.0
]),
np
.
array
([
5.5
,
7.0
]),
np
.
array
([
5.5
,
7.0
]),
...
@@ -190,39 +181,40 @@ def test_gammaincinv():
...
@@ -190,39 +181,40 @@ def test_gammaincinv():
k
=
vector
(
"k"
,
dtype
=
"float64"
)
k
=
vector
(
"k"
,
dtype
=
"float64"
)
x
=
vector
(
"x"
,
dtype
=
"float64"
)
x
=
vector
(
"x"
,
dtype
=
"float64"
)
out
=
gammaincinv
(
k
,
x
)
out
=
gammaincinv
(
k
,
x
)
fg
=
FunctionGraph
([
k
,
x
],
[
out
])
compare_jax_and_py
(
fg
,
[
np
.
array
([
5.5
,
7.0
]),
np
.
array
([
0.25
,
0.7
])])
compare_jax_and_py
(
[
k
,
x
],
[
out
]
,
[
np
.
array
([
5.5
,
7.0
]),
np
.
array
([
0.25
,
0.7
])])
def
test_gammainccinv
():
def
test_gammainccinv
():
k
=
vector
(
"k"
,
dtype
=
"float64"
)
k
=
vector
(
"k"
,
dtype
=
"float64"
)
x
=
vector
(
"x"
,
dtype
=
"float64"
)
x
=
vector
(
"x"
,
dtype
=
"float64"
)
out
=
gammainccinv
(
k
,
x
)
out
=
gammainccinv
(
k
,
x
)
fg
=
FunctionGraph
([
k
,
x
],
[
out
])
compare_jax_and_py
(
fg
,
[
np
.
array
([
5.5
,
7.0
]),
np
.
array
([
0.25
,
0.7
])])
compare_jax_and_py
(
[
k
,
x
],
[
out
]
,
[
np
.
array
([
5.5
,
7.0
]),
np
.
array
([
0.25
,
0.7
])])
def
test_psi
():
def
test_psi
():
x
=
scalar
(
"x"
)
x
=
scalar
(
"x"
)
out
=
psi
(
x
)
out
=
psi
(
x
)
fg
=
FunctionGraph
([
x
],
[
out
])
compare_jax_and_py
(
fg
,
[
3.0
])
compare_jax_and_py
(
[
x
],
[
out
]
,
[
3.0
])
def
test_tri_gamma
():
def
test_tri_gamma
():
x
=
vector
(
"x"
,
dtype
=
"float64"
)
x
=
vector
(
"x"
,
dtype
=
"float64"
)
out
=
tri_gamma
(
x
)
out
=
tri_gamma
(
x
)
fg
=
FunctionGraph
([
x
],
[
out
])
compare_jax_and_py
(
fg
,
[
np
.
array
([
3.0
,
5.0
])])
compare_jax_and_py
(
[
x
],
[
out
]
,
[
np
.
array
([
3.0
,
5.0
])])
def
test_polygamma
():
def
test_polygamma
():
n
=
vector
(
"n"
,
dtype
=
"int32"
)
n
=
vector
(
"n"
,
dtype
=
"int32"
)
x
=
vector
(
"x"
,
dtype
=
"float32"
)
x
=
vector
(
"x"
,
dtype
=
"float32"
)
out
=
polygamma
(
n
,
x
)
out
=
polygamma
(
n
,
x
)
fg
=
FunctionGraph
([
n
,
x
],
[
out
])
compare_jax_and_py
(
compare_jax_and_py
(
fg
,
[
n
,
x
],
[
out
],
[
[
np
.
array
([
0
,
1
,
2
])
.
astype
(
"int32"
),
np
.
array
([
0
,
1
,
2
])
.
astype
(
"int32"
),
np
.
array
([
0.5
,
0.9
,
2.5
])
.
astype
(
"float32"
),
np
.
array
([
0.5
,
0.9
,
2.5
])
.
astype
(
"float32"
),
...
@@ -233,41 +225,34 @@ def test_polygamma():
...
@@ -233,41 +225,34 @@ def test_polygamma():
def
test_log1mexp
():
def
test_log1mexp
():
x
=
vector
(
"x"
)
x
=
vector
(
"x"
)
out
=
log1mexp
(
x
)
out
=
log1mexp
(
x
)
fg
=
FunctionGraph
([
x
],
[
out
])
compare_jax_and_py
(
fg
,
[[
-
1.0
,
-
0.75
,
-
0.5
,
-
0.25
]])
compare_jax_and_py
(
[
x
],
[
out
]
,
[[
-
1.0
,
-
0.75
,
-
0.5
,
-
0.25
]])
def
test_nnet
():
def
test_nnet
():
x
=
vector
(
"x"
)
x
=
vector
(
"x"
)
x
.
tag
.
test_value
=
np
.
r_
[
1.0
,
2.0
]
.
astype
(
config
.
floatX
)
x
_
test_value
=
np
.
r_
[
1.0
,
2.0
]
.
astype
(
config
.
floatX
)
out
=
sigmoid
(
x
)
out
=
sigmoid
(
x
)
fgraph
=
FunctionGraph
([
x
],
[
out
])
compare_jax_and_py
([
x
],
[
out
],
[
x_test_value
])
compare_jax_and_py
(
fgraph
,
[
get_test_value
(
i
)
for
i
in
fgraph
.
inputs
])
out
=
softplus
(
x
)
out
=
softplus
(
x
)
fgraph
=
FunctionGraph
([
x
],
[
out
])
compare_jax_and_py
([
x
],
[
out
],
[
x_test_value
])
compare_jax_and_py
(
fgraph
,
[
get_test_value
(
i
)
for
i
in
fgraph
.
inputs
])
def
test_jax_variadic_Scalar
():
def
test_jax_variadic_Scalar
():
mu
=
vector
(
"mu"
,
dtype
=
config
.
floatX
)
mu
=
vector
(
"mu"
,
dtype
=
config
.
floatX
)
mu
.
tag
.
test_value
=
np
.
r_
[
0.1
,
1.1
]
.
astype
(
config
.
floatX
)
mu
_
test_value
=
np
.
r_
[
0.1
,
1.1
]
.
astype
(
config
.
floatX
)
tau
=
vector
(
"tau"
,
dtype
=
config
.
floatX
)
tau
=
vector
(
"tau"
,
dtype
=
config
.
floatX
)
tau
.
tag
.
test_value
=
np
.
r_
[
1.0
,
2.0
]
.
astype
(
config
.
floatX
)
tau
_
test_value
=
np
.
r_
[
1.0
,
2.0
]
.
astype
(
config
.
floatX
)
res
=
-
tau
*
mu
res
=
-
tau
*
mu
fgraph
=
FunctionGraph
([
mu
,
tau
],
[
res
])
compare_jax_and_py
([
mu
,
tau
],
[
res
],
[
mu_test_value
,
tau_test_value
])
compare_jax_and_py
(
fgraph
,
[
get_test_value
(
i
)
for
i
in
fgraph
.
inputs
])
res
=
-
tau
*
(
tau
-
mu
)
**
2
res
=
-
tau
*
(
tau
-
mu
)
**
2
fgraph
=
FunctionGraph
([
mu
,
tau
],
[
res
])
compare_jax_and_py
([
mu
,
tau
],
[
res
],
[
mu_test_value
,
tau_test_value
])
compare_jax_and_py
(
fgraph
,
[
get_test_value
(
i
)
for
i
in
fgraph
.
inputs
])
def
test_add_scalars
():
def
test_add_scalars
():
...
@@ -275,8 +260,7 @@ def test_add_scalars():
...
@@ -275,8 +260,7 @@ def test_add_scalars():
size
=
x
.
shape
[
0
]
+
x
.
shape
[
0
]
+
x
.
shape
[
1
]
size
=
x
.
shape
[
0
]
+
x
.
shape
[
0
]
+
x
.
shape
[
1
]
out
=
pt
.
ones
(
size
)
.
astype
(
config
.
floatX
)
out
=
pt
.
ones
(
size
)
.
astype
(
config
.
floatX
)
out_fg
=
FunctionGraph
([
x
],
[
out
])
compare_jax_and_py
([
x
],
[
out
],
[
np
.
ones
((
2
,
3
))
.
astype
(
config
.
floatX
)])
compare_jax_and_py
(
out_fg
,
[
np
.
ones
((
2
,
3
))
.
astype
(
config
.
floatX
)])
def
test_mul_scalars
():
def
test_mul_scalars
():
...
@@ -284,8 +268,7 @@ def test_mul_scalars():
...
@@ -284,8 +268,7 @@ def test_mul_scalars():
size
=
x
.
shape
[
0
]
*
x
.
shape
[
0
]
*
x
.
shape
[
1
]
size
=
x
.
shape
[
0
]
*
x
.
shape
[
0
]
*
x
.
shape
[
1
]
out
=
pt
.
ones
(
size
)
.
astype
(
config
.
floatX
)
out
=
pt
.
ones
(
size
)
.
astype
(
config
.
floatX
)
out_fg
=
FunctionGraph
([
x
],
[
out
])
compare_jax_and_py
([
x
],
[
out
],
[
np
.
ones
((
2
,
3
))
.
astype
(
config
.
floatX
)])
compare_jax_and_py
(
out_fg
,
[
np
.
ones
((
2
,
3
))
.
astype
(
config
.
floatX
)])
def
test_div_scalars
():
def
test_div_scalars
():
...
@@ -293,8 +276,7 @@ def test_div_scalars():
...
@@ -293,8 +276,7 @@ def test_div_scalars():
size
=
x
.
shape
[
0
]
//
x
.
shape
[
1
]
size
=
x
.
shape
[
0
]
//
x
.
shape
[
1
]
out
=
pt
.
ones
(
size
)
.
astype
(
config
.
floatX
)
out
=
pt
.
ones
(
size
)
.
astype
(
config
.
floatX
)
out_fg
=
FunctionGraph
([
x
],
[
out
])
compare_jax_and_py
([
x
],
[
out
],
[
np
.
ones
((
12
,
3
))
.
astype
(
config
.
floatX
)])
compare_jax_and_py
(
out_fg
,
[
np
.
ones
((
12
,
3
))
.
astype
(
config
.
floatX
)])
def
test_mod_scalars
():
def
test_mod_scalars
():
...
@@ -302,39 +284,43 @@ def test_mod_scalars():
...
@@ -302,39 +284,43 @@ def test_mod_scalars():
size
=
x
.
shape
[
0
]
%
x
.
shape
[
1
]
size
=
x
.
shape
[
0
]
%
x
.
shape
[
1
]
out
=
pt
.
ones
(
size
)
.
astype
(
config
.
floatX
)
out
=
pt
.
ones
(
size
)
.
astype
(
config
.
floatX
)
out_fg
=
FunctionGraph
([
x
],
[
out
])
compare_jax_and_py
([
x
],
[
out
],
[
np
.
ones
((
12
,
3
))
.
astype
(
config
.
floatX
)])
compare_jax_and_py
(
out_fg
,
[
np
.
ones
((
12
,
3
))
.
astype
(
config
.
floatX
)])
def
test_jax_multioutput
():
def
test_jax_multioutput
():
x
=
vector
(
"x"
)
x
=
vector
(
"x"
)
x
.
tag
.
test_value
=
np
.
r_
[
1.0
,
2.0
]
.
astype
(
config
.
floatX
)
x
_
test_value
=
np
.
r_
[
1.0
,
2.0
]
.
astype
(
config
.
floatX
)
y
=
vector
(
"y"
)
y
=
vector
(
"y"
)
y
.
tag
.
test_value
=
np
.
r_
[
3.0
,
4.0
]
.
astype
(
config
.
floatX
)
y
_
test_value
=
np
.
r_
[
3.0
,
4.0
]
.
astype
(
config
.
floatX
)
w
=
cosh
(
x
**
2
+
y
/
3.0
)
w
=
cosh
(
x
**
2
+
y
/
3.0
)
v
=
cosh
(
x
/
3.0
+
y
**
2
)
v
=
cosh
(
x
/
3.0
+
y
**
2
)
fgraph
=
FunctionGraph
([
x
,
y
],
[
w
,
v
])
compare_jax_and_py
([
x
,
y
],
[
w
,
v
],
[
x_test_value
,
y_test_value
])
compare_jax_and_py
(
fgraph
,
[
get_test_value
(
i
)
for
i
in
fgraph
.
inputs
])
def
test_jax_logp
():
def
test_jax_logp
():
mu
=
vector
(
"mu"
)
mu
=
vector
(
"mu"
)
mu
.
tag
.
test_value
=
np
.
r_
[
0.0
,
0.0
]
.
astype
(
config
.
floatX
)
mu
_
test_value
=
np
.
r_
[
0.0
,
0.0
]
.
astype
(
config
.
floatX
)
tau
=
vector
(
"tau"
)
tau
=
vector
(
"tau"
)
tau
.
tag
.
test_value
=
np
.
r_
[
1.0
,
1.0
]
.
astype
(
config
.
floatX
)
tau
_
test_value
=
np
.
r_
[
1.0
,
1.0
]
.
astype
(
config
.
floatX
)
sigma
=
vector
(
"sigma"
)
sigma
=
vector
(
"sigma"
)
sigma
.
tag
.
test_value
=
(
1.0
/
get_test_value
(
tau
)
)
.
astype
(
config
.
floatX
)
sigma
_test_value
=
(
1.0
/
tau_test_value
)
.
astype
(
config
.
floatX
)
value
=
vector
(
"value"
)
value
=
vector
(
"value"
)
value
.
tag
.
test_value
=
np
.
r_
[
0.1
,
-
10
]
.
astype
(
config
.
floatX
)
value
_
test_value
=
np
.
r_
[
0.1
,
-
10
]
.
astype
(
config
.
floatX
)
logp
=
(
-
tau
*
(
value
-
mu
)
**
2
+
log
(
tau
/
np
.
pi
/
2.0
))
/
2.0
logp
=
(
-
tau
*
(
value
-
mu
)
**
2
+
log
(
tau
/
np
.
pi
/
2.0
))
/
2.0
conditions
=
[
sigma
>
0
]
conditions
=
[
sigma
>
0
]
alltrue
=
pt_all
([
pt_all
(
1
*
val
)
for
val
in
conditions
])
alltrue
=
pt_all
([
pt_all
(
1
*
val
)
for
val
in
conditions
])
normal_logp
=
pt
.
switch
(
alltrue
,
logp
,
-
np
.
inf
)
normal_logp
=
pt
.
switch
(
alltrue
,
logp
,
-
np
.
inf
)
fgraph
=
FunctionGraph
([
mu
,
tau
,
sigma
,
value
],
[
normal_logp
])
compare_jax_and_py
(
[
mu
,
tau
,
sigma
,
value
],
compare_jax_and_py
(
fgraph
,
[
get_test_value
(
i
)
for
i
in
fgraph
.
inputs
])
[
normal_logp
],
[
mu_test_value
,
tau_test_value
,
sigma_test_value
,
value_test_value
,
],
)
tests/link/jax/test_scan.py
浏览文件 @
cc8c4992
...
@@ -7,7 +7,6 @@ import pytensor.tensor as pt
...
@@ -7,7 +7,6 @@ import pytensor.tensor as pt
from
pytensor
import
function
,
shared
from
pytensor
import
function
,
shared
from
pytensor.compile
import
get_mode
from
pytensor.compile
import
get_mode
from
pytensor.configdefaults
import
config
from
pytensor.configdefaults
import
config
from
pytensor.graph.fg
import
FunctionGraph
from
pytensor.scan
import
until
from
pytensor.scan
import
until
from
pytensor.scan.basic
import
scan
from
pytensor.scan.basic
import
scan
from
pytensor.scan.op
import
Scan
from
pytensor.scan.op
import
Scan
...
@@ -30,9 +29,8 @@ def test_scan_sit_sot(view):
...
@@ -30,9 +29,8 @@ def test_scan_sit_sot(view):
)
)
if
view
:
if
view
:
xs
=
xs
[
view
]
xs
=
xs
[
view
]
fg
=
FunctionGraph
([
x0
],
[
xs
])
test_input_vals
=
[
np
.
e
]
test_input_vals
=
[
np
.
e
]
compare_jax_and_py
(
fg
,
test_input_vals
,
jax_mode
=
"JAX"
)
compare_jax_and_py
(
[
x0
],
[
xs
]
,
test_input_vals
,
jax_mode
=
"JAX"
)
@pytest.mark.parametrize
(
"view"
,
[
None
,
(
-
1
,),
slice
(
-
4
,
-
1
,
None
)])
@pytest.mark.parametrize
(
"view"
,
[
None
,
(
-
1
,),
slice
(
-
4
,
-
1
,
None
)])
...
@@ -45,9 +43,8 @@ def test_scan_mit_sot(view):
...
@@ -45,9 +43,8 @@ def test_scan_mit_sot(view):
)
)
if
view
:
if
view
:
xs
=
xs
[
view
]
xs
=
xs
[
view
]
fg
=
FunctionGraph
([
x0
],
[
xs
])
test_input_vals
=
[
np
.
full
((
3
,),
np
.
e
)]
test_input_vals
=
[
np
.
full
((
3
,),
np
.
e
)]
compare_jax_and_py
(
fg
,
test_input_vals
,
jax_mode
=
"JAX"
)
compare_jax_and_py
(
[
x0
],
[
xs
]
,
test_input_vals
,
jax_mode
=
"JAX"
)
@pytest.mark.parametrize
(
"view_x"
,
[
None
,
(
-
1
,),
slice
(
-
4
,
-
1
,
None
)])
@pytest.mark.parametrize
(
"view_x"
,
[
None
,
(
-
1
,),
slice
(
-
4
,
-
1
,
None
)])
...
@@ -72,9 +69,8 @@ def test_scan_multiple_mit_sot(view_x, view_y):
...
@@ -72,9 +69,8 @@ def test_scan_multiple_mit_sot(view_x, view_y):
if
view_y
:
if
view_y
:
ys
=
ys
[
view_y
]
ys
=
ys
[
view_y
]
fg
=
FunctionGraph
([
x0
,
y0
],
[
xs
,
ys
])
test_input_vals
=
[
np
.
full
((
3
,),
np
.
e
),
np
.
full
((
4
,),
np
.
pi
)]
test_input_vals
=
[
np
.
full
((
3
,),
np
.
e
),
np
.
full
((
4
,),
np
.
pi
)]
compare_jax_and_py
(
fg
,
test_input_vals
,
jax_mode
=
"JAX"
)
compare_jax_and_py
(
[
x0
,
y0
],
[
xs
,
ys
]
,
test_input_vals
,
jax_mode
=
"JAX"
)
@pytest.mark.parametrize
(
"view"
,
[
None
,
(
-
2
,),
slice
(
None
,
None
,
2
)])
@pytest.mark.parametrize
(
"view"
,
[
None
,
(
-
2
,),
slice
(
None
,
None
,
2
)])
...
@@ -90,12 +86,11 @@ def test_scan_nit_sot(view):
...
@@ -90,12 +86,11 @@ def test_scan_nit_sot(view):
)
)
if
view
:
if
view
:
ys
=
ys
[
view
]
ys
=
ys
[
view
]
fg
=
FunctionGraph
([
xs
],
[
ys
])
test_input_vals
=
[
rng
.
normal
(
size
=
10
)]
test_input_vals
=
[
rng
.
normal
(
size
=
10
)]
# We need to remove pushout rewrites, or the whole scan would just be
# We need to remove pushout rewrites, or the whole scan would just be
# converted to an Elemwise on xs
# converted to an Elemwise on xs
jax_fn
,
_
=
compare_jax_and_py
(
jax_fn
,
_
=
compare_jax_and_py
(
fg
,
test_input_vals
,
jax_mode
=
get_mode
(
"JAX"
)
.
excluding
(
"scan_pushout"
)
[
xs
],
[
ys
]
,
test_input_vals
,
jax_mode
=
get_mode
(
"JAX"
)
.
excluding
(
"scan_pushout"
)
)
)
scan_nodes
=
[
scan_nodes
=
[
node
for
node
in
jax_fn
.
maker
.
fgraph
.
apply_nodes
if
isinstance
(
node
.
op
,
Scan
)
node
for
node
in
jax_fn
.
maker
.
fgraph
.
apply_nodes
if
isinstance
(
node
.
op
,
Scan
)
...
@@ -112,8 +107,7 @@ def test_scan_mit_mot():
...
@@ -112,8 +107,7 @@ def test_scan_mit_mot():
n_steps
=
10
,
n_steps
=
10
,
)
)
grads_wrt_xs
=
pt
.
grad
(
ys
.
sum
(),
wrt
=
xs
)
grads_wrt_xs
=
pt
.
grad
(
ys
.
sum
(),
wrt
=
xs
)
fg
=
FunctionGraph
([
xs
],
[
grads_wrt_xs
])
compare_jax_and_py
([
xs
],
[
grads_wrt_xs
],
[
np
.
arange
(
10
)])
compare_jax_and_py
(
fg
,
[
np
.
arange
(
10
)])
def
test_scan_update
():
def
test_scan_update
():
...
@@ -192,8 +186,7 @@ def test_scan_while():
...
@@ -192,8 +186,7 @@ def test_scan_while():
n_steps
=
100
,
n_steps
=
100
,
)
)
fg
=
FunctionGraph
([],
[
xs
])
compare_jax_and_py
([],
[
xs
],
[])
compare_jax_and_py
(
fg
,
[])
def
test_scan_SEIR
():
def
test_scan_SEIR
():
...
@@ -257,11 +250,6 @@ def test_scan_SEIR():
...
@@ -257,11 +250,6 @@ def test_scan_SEIR():
logp_c_all
.
name
=
"C_t_logp"
logp_c_all
.
name
=
"C_t_logp"
logp_d_all
.
name
=
"D_t_logp"
logp_d_all
.
name
=
"D_t_logp"
out_fg
=
FunctionGraph
(
[
at_C
,
at_D
,
st0
,
et0
,
it0
,
logp_c
,
logp_d
,
beta
,
gamma
,
delta
],
[
st
,
et
,
it
,
logp_c_all
,
logp_d_all
],
)
s0
,
e0
,
i0
=
100
,
50
,
25
s0
,
e0
,
i0
=
100
,
50
,
25
logp_c0
=
np
.
array
(
0.0
,
dtype
=
config
.
floatX
)
logp_c0
=
np
.
array
(
0.0
,
dtype
=
config
.
floatX
)
logp_d0
=
np
.
array
(
0.0
,
dtype
=
config
.
floatX
)
logp_d0
=
np
.
array
(
0.0
,
dtype
=
config
.
floatX
)
...
@@ -283,7 +271,12 @@ def test_scan_SEIR():
...
@@ -283,7 +271,12 @@ def test_scan_SEIR():
gamma_val
,
gamma_val
,
delta_val
,
delta_val
,
]
]
compare_jax_and_py
(
out_fg
,
test_input_vals
,
jax_mode
=
"JAX"
)
compare_jax_and_py
(
[
at_C
,
at_D
,
st0
,
et0
,
it0
,
logp_c
,
logp_d
,
beta
,
gamma
,
delta
],
[
st
,
et
,
it
,
logp_c_all
,
logp_d_all
],
test_input_vals
,
jax_mode
=
"JAX"
,
)
def
test_scan_mitsot_with_nonseq
():
def
test_scan_mitsot_with_nonseq
():
...
@@ -313,10 +306,8 @@ def test_scan_mitsot_with_nonseq():
...
@@ -313,10 +306,8 @@ def test_scan_mitsot_with_nonseq():
y_scan_pt
.
name
=
"y"
y_scan_pt
.
name
=
"y"
y_scan_pt
.
owner
.
inputs
[
0
]
.
name
=
"y_all"
y_scan_pt
.
owner
.
inputs
[
0
]
.
name
=
"y_all"
out_fg
=
FunctionGraph
([
a_pt
],
[
y_scan_pt
])
test_input_vals
=
[
np
.
array
(
10.0
)
.
astype
(
config
.
floatX
)]
test_input_vals
=
[
np
.
array
(
10.0
)
.
astype
(
config
.
floatX
)]
compare_jax_and_py
(
out_fg
,
test_input_vals
,
jax_mode
=
"JAX"
)
compare_jax_and_py
(
[
a_pt
],
[
y_scan_pt
]
,
test_input_vals
,
jax_mode
=
"JAX"
)
@pytest.mark.parametrize
(
"x0_func"
,
[
dvector
,
dmatrix
])
@pytest.mark.parametrize
(
"x0_func"
,
[
dvector
,
dmatrix
])
...
@@ -343,9 +334,8 @@ def test_nd_scan_sit_sot(x0_func, A_func):
...
@@ -343,9 +334,8 @@ def test_nd_scan_sit_sot(x0_func, A_func):
)
)
A_val
=
np
.
eye
(
k
,
dtype
=
config
.
floatX
)
A_val
=
np
.
eye
(
k
,
dtype
=
config
.
floatX
)
fg
=
FunctionGraph
([
x0
,
A
],
[
xs
])
test_input_vals
=
[
x0_val
,
A_val
]
test_input_vals
=
[
x0_val
,
A_val
]
compare_jax_and_py
(
fg
,
test_input_vals
,
jax_mode
=
"JAX"
)
compare_jax_and_py
(
[
x0
,
A
],
[
xs
]
,
test_input_vals
,
jax_mode
=
"JAX"
)
def
test_nd_scan_sit_sot_with_seq
():
def
test_nd_scan_sit_sot_with_seq
():
...
@@ -366,9 +356,8 @@ def test_nd_scan_sit_sot_with_seq():
...
@@ -366,9 +356,8 @@ def test_nd_scan_sit_sot_with_seq():
x_val
=
np
.
arange
(
n_steps
*
k
,
dtype
=
config
.
floatX
)
.
reshape
(
n_steps
,
k
)
x_val
=
np
.
arange
(
n_steps
*
k
,
dtype
=
config
.
floatX
)
.
reshape
(
n_steps
,
k
)
A_val
=
np
.
eye
(
k
,
dtype
=
config
.
floatX
)
A_val
=
np
.
eye
(
k
,
dtype
=
config
.
floatX
)
fg
=
FunctionGraph
([
x
,
A
],
[
xs
])
test_input_vals
=
[
x_val
,
A_val
]
test_input_vals
=
[
x_val
,
A_val
]
compare_jax_and_py
(
fg
,
test_input_vals
,
jax_mode
=
"JAX"
)
compare_jax_and_py
(
[
x
,
A
],
[
xs
]
,
test_input_vals
,
jax_mode
=
"JAX"
)
def
test_nd_scan_mit_sot
():
def
test_nd_scan_mit_sot
():
...
@@ -384,13 +373,12 @@ def test_nd_scan_mit_sot():
...
@@ -384,13 +373,12 @@ def test_nd_scan_mit_sot():
n_steps
=
10
,
n_steps
=
10
,
)
)
fg
=
FunctionGraph
([
x0
,
A
,
B
],
[
xs
])
x0_val
=
np
.
arange
(
9
,
dtype
=
config
.
floatX
)
.
reshape
(
3
,
3
)
x0_val
=
np
.
arange
(
9
,
dtype
=
config
.
floatX
)
.
reshape
(
3
,
3
)
A_val
=
np
.
eye
(
3
,
dtype
=
config
.
floatX
)
A_val
=
np
.
eye
(
3
,
dtype
=
config
.
floatX
)
B_val
=
np
.
eye
(
3
,
dtype
=
config
.
floatX
)
B_val
=
np
.
eye
(
3
,
dtype
=
config
.
floatX
)
test_input_vals
=
[
x0_val
,
A_val
,
B_val
]
test_input_vals
=
[
x0_val
,
A_val
,
B_val
]
compare_jax_and_py
(
fg
,
test_input_vals
,
jax_mode
=
"JAX"
)
compare_jax_and_py
(
[
x0
,
A
,
B
],
[
xs
]
,
test_input_vals
,
jax_mode
=
"JAX"
)
def
test_nd_scan_sit_sot_with_carry
():
def
test_nd_scan_sit_sot_with_carry
():
...
@@ -409,12 +397,11 @@ def test_nd_scan_sit_sot_with_carry():
...
@@ -409,12 +397,11 @@ def test_nd_scan_sit_sot_with_carry():
mode
=
get_mode
(
"JAX"
),
mode
=
get_mode
(
"JAX"
),
)
)
fg
=
FunctionGraph
([
x0
,
A
],
xs
)
x0_val
=
np
.
arange
(
3
,
dtype
=
config
.
floatX
)
x0_val
=
np
.
arange
(
3
,
dtype
=
config
.
floatX
)
A_val
=
np
.
eye
(
3
,
dtype
=
config
.
floatX
)
A_val
=
np
.
eye
(
3
,
dtype
=
config
.
floatX
)
test_input_vals
=
[
x0_val
,
A_val
]
test_input_vals
=
[
x0_val
,
A_val
]
compare_jax_and_py
(
fg
,
test_input_vals
,
jax_mode
=
"JAX"
)
compare_jax_and_py
(
[
x0
,
A
],
xs
,
test_input_vals
,
jax_mode
=
"JAX"
)
def
test_default_mode_excludes_incompatible_rewrites
():
def
test_default_mode_excludes_incompatible_rewrites
():
...
@@ -422,8 +409,7 @@ def test_default_mode_excludes_incompatible_rewrites():
...
@@ -422,8 +409,7 @@ def test_default_mode_excludes_incompatible_rewrites():
A
=
matrix
(
"A"
)
A
=
matrix
(
"A"
)
B
=
matrix
(
"B"
)
B
=
matrix
(
"B"
)
out
,
_
=
scan
(
lambda
a
,
b
:
a
@
b
,
outputs_info
=
[
A
],
non_sequences
=
[
B
],
n_steps
=
2
)
out
,
_
=
scan
(
lambda
a
,
b
:
a
@
b
,
outputs_info
=
[
A
],
non_sequences
=
[
B
],
n_steps
=
2
)
fg
=
FunctionGraph
([
A
,
B
],
[
out
])
compare_jax_and_py
([
A
,
B
],
[
out
],
[
np
.
eye
(
3
),
np
.
eye
(
3
)],
jax_mode
=
"JAX"
)
compare_jax_and_py
(
fg
,
[
np
.
eye
(
3
),
np
.
eye
(
3
)],
jax_mode
=
"JAX"
)
def
test_dynamic_sequence_length
():
def
test_dynamic_sequence_length
():
...
...
tests/link/jax/test_shape.py
浏览文件 @
cc8c4992
...
@@ -4,7 +4,6 @@ import pytest
...
@@ -4,7 +4,6 @@ import pytest
import
pytensor.tensor
as
pt
import
pytensor.tensor
as
pt
from
pytensor.compile.ops
import
DeepCopyOp
,
ViewOp
from
pytensor.compile.ops
import
DeepCopyOp
,
ViewOp
from
pytensor.configdefaults
import
config
from
pytensor.configdefaults
import
config
from
pytensor.graph.fg
import
FunctionGraph
from
pytensor.tensor.shape
import
Shape
,
Shape_i
,
Unbroadcast
,
reshape
from
pytensor.tensor.shape
import
Shape
,
Shape_i
,
Unbroadcast
,
reshape
from
pytensor.tensor.type
import
iscalar
,
vector
from
pytensor.tensor.type
import
iscalar
,
vector
from
tests.link.jax.test_basic
import
compare_jax_and_py
from
tests.link.jax.test_basic
import
compare_jax_and_py
...
@@ -13,29 +12,27 @@ from tests.link.jax.test_basic import compare_jax_and_py
...
@@ -13,29 +12,27 @@ from tests.link.jax.test_basic import compare_jax_and_py
def
test_jax_shape_ops
():
def
test_jax_shape_ops
():
x_np
=
np
.
zeros
((
20
,
3
))
x_np
=
np
.
zeros
((
20
,
3
))
x
=
Shape
()(
pt
.
as_tensor_variable
(
x_np
))
x
=
Shape
()(
pt
.
as_tensor_variable
(
x_np
))
x_fg
=
FunctionGraph
([],
[
x
])
compare_jax_and_py
(
x_fg
,
[],
must_be_device_array
=
False
)
compare_jax_and_py
(
[],
[
x
]
,
[],
must_be_device_array
=
False
)
x
=
Shape_i
(
1
)(
pt
.
as_tensor_variable
(
x_np
))
x
=
Shape_i
(
1
)(
pt
.
as_tensor_variable
(
x_np
))
x_fg
=
FunctionGraph
([],
[
x
])
compare_jax_and_py
(
x_fg
,
[],
must_be_device_array
=
False
)
compare_jax_and_py
(
[],
[
x
]
,
[],
must_be_device_array
=
False
)
def
test_jax_specify_shape
():
def
test_jax_specify_shape
():
in_pt
=
pt
.
matrix
(
"in"
)
in_pt
=
pt
.
matrix
(
"in"
)
x
=
pt
.
specify_shape
(
in_pt
,
(
4
,
None
))
x
=
pt
.
specify_shape
(
in_pt
,
(
4
,
None
))
x_fg
=
FunctionGraph
([
in_pt
],
[
x
])
compare_jax_and_py
([
in_pt
],
[
x
],
[
np
.
ones
((
4
,
5
))
.
astype
(
config
.
floatX
)])
compare_jax_and_py
(
x_fg
,
[
np
.
ones
((
4
,
5
))
.
astype
(
config
.
floatX
)])
# When used to assert two arrays have similar shapes
# When used to assert two arrays have similar shapes
in_pt
=
pt
.
matrix
(
"in"
)
in_pt
=
pt
.
matrix
(
"in"
)
shape_pt
=
pt
.
matrix
(
"shape"
)
shape_pt
=
pt
.
matrix
(
"shape"
)
x
=
pt
.
specify_shape
(
in_pt
,
shape_pt
.
shape
)
x
=
pt
.
specify_shape
(
in_pt
,
shape_pt
.
shape
)
x_fg
=
FunctionGraph
([
in_pt
,
shape_pt
],
[
x
])
compare_jax_and_py
(
compare_jax_and_py
(
x_fg
,
[
in_pt
,
shape_pt
],
[
x
],
[
np
.
ones
((
4
,
5
))
.
astype
(
config
.
floatX
),
np
.
ones
((
4
,
5
))
.
astype
(
config
.
floatX
)],
[
np
.
ones
((
4
,
5
))
.
astype
(
config
.
floatX
),
np
.
ones
((
4
,
5
))
.
astype
(
config
.
floatX
)],
)
)
...
@@ -43,20 +40,17 @@ def test_jax_specify_shape():
...
@@ -43,20 +40,17 @@ def test_jax_specify_shape():
def
test_jax_Reshape_constant
():
def
test_jax_Reshape_constant
():
a
=
vector
(
"a"
)
a
=
vector
(
"a"
)
x
=
reshape
(
a
,
(
2
,
2
))
x
=
reshape
(
a
,
(
2
,
2
))
x_fg
=
FunctionGraph
([
a
],
[
x
])
compare_jax_and_py
([
a
],
[
x
],
[
np
.
r_
[
1.0
,
2.0
,
3.0
,
4.0
]
.
astype
(
config
.
floatX
)])
compare_jax_and_py
(
x_fg
,
[
np
.
r_
[
1.0
,
2.0
,
3.0
,
4.0
]
.
astype
(
config
.
floatX
)])
def
test_jax_Reshape_concrete_shape
():
def
test_jax_Reshape_concrete_shape
():
"""JAX should compile when a concrete value is passed for the `shape` parameter."""
"""JAX should compile when a concrete value is passed for the `shape` parameter."""
a
=
vector
(
"a"
)
a
=
vector
(
"a"
)
x
=
reshape
(
a
,
a
.
shape
)
x
=
reshape
(
a
,
a
.
shape
)
x_fg
=
FunctionGraph
([
a
],
[
x
])
compare_jax_and_py
([
a
],
[
x
],
[
np
.
r_
[
1.0
,
2.0
,
3.0
,
4.0
]
.
astype
(
config
.
floatX
)])
compare_jax_and_py
(
x_fg
,
[
np
.
r_
[
1.0
,
2.0
,
3.0
,
4.0
]
.
astype
(
config
.
floatX
)])
x
=
reshape
(
a
,
(
a
.
shape
[
0
]
//
2
,
a
.
shape
[
0
]
//
2
))
x
=
reshape
(
a
,
(
a
.
shape
[
0
]
//
2
,
a
.
shape
[
0
]
//
2
))
x_fg
=
FunctionGraph
([
a
],
[
x
])
compare_jax_and_py
([
a
],
[
x
],
[
np
.
r_
[
1.0
,
2.0
,
3.0
,
4.0
]
.
astype
(
config
.
floatX
)])
compare_jax_and_py
(
x_fg
,
[
np
.
r_
[
1.0
,
2.0
,
3.0
,
4.0
]
.
astype
(
config
.
floatX
)])
@pytest.mark.xfail
(
@pytest.mark.xfail
(
...
@@ -66,23 +60,20 @@ def test_jax_Reshape_shape_graph_input():
...
@@ -66,23 +60,20 @@ def test_jax_Reshape_shape_graph_input():
a
=
vector
(
"a"
)
a
=
vector
(
"a"
)
shape_pt
=
iscalar
(
"b"
)
shape_pt
=
iscalar
(
"b"
)
x
=
reshape
(
a
,
(
shape_pt
,
shape_pt
))
x
=
reshape
(
a
,
(
shape_pt
,
shape_pt
))
x_fg
=
FunctionGraph
([
a
,
shape_pt
],
[
x
])
compare_jax_and_py
(
compare_jax_and_py
(
x_fg
,
[
np
.
r_
[
1.0
,
2.0
,
3.0
,
4.0
]
.
astype
(
config
.
floatX
),
2
])
[
a
,
shape_pt
],
[
x
],
[
np
.
r_
[
1.0
,
2.0
,
3.0
,
4.0
]
.
astype
(
config
.
floatX
),
2
]
)
def
test_jax_compile_ops
():
def
test_jax_compile_ops
():
x
=
DeepCopyOp
()(
pt
.
as_tensor_variable
(
1.1
))
x
=
DeepCopyOp
()(
pt
.
as_tensor_variable
(
1.1
))
x_fg
=
FunctionGraph
([],
[
x
])
compare_jax_and_py
([],
[
x
],
[])
compare_jax_and_py
(
x_fg
,
[])
x_np
=
np
.
zeros
((
20
,
1
,
1
))
x_np
=
np
.
zeros
((
20
,
1
,
1
))
x
=
Unbroadcast
(
0
,
2
)(
pt
.
as_tensor_variable
(
x_np
))
x
=
Unbroadcast
(
0
,
2
)(
pt
.
as_tensor_variable
(
x_np
))
x_fg
=
FunctionGraph
([],
[
x
])
compare_jax_and_py
(
x_fg
,
[])
compare_jax_and_py
(
[],
[
x
]
,
[])
x
=
ViewOp
()(
pt
.
as_tensor_variable
(
x_np
))
x
=
ViewOp
()(
pt
.
as_tensor_variable
(
x_np
))
x_fg
=
FunctionGraph
([],
[
x
])
compare_jax_and_py
(
x_fg
,
[])
compare_jax_and_py
(
[],
[
x
]
,
[])
tests/link/jax/test_slinalg.py
浏览文件 @
cc8c4992
...
@@ -6,7 +6,6 @@ import pytest
...
@@ -6,7 +6,6 @@ import pytest
import
pytensor.tensor
as
pt
import
pytensor.tensor
as
pt
from
pytensor.configdefaults
import
config
from
pytensor.configdefaults
import
config
from
pytensor.graph.fg
import
FunctionGraph
from
pytensor.tensor
import
nlinalg
as
pt_nlinalg
from
pytensor.tensor
import
nlinalg
as
pt_nlinalg
from
pytensor.tensor
import
slinalg
as
pt_slinalg
from
pytensor.tensor
import
slinalg
as
pt_slinalg
from
pytensor.tensor
import
subtensor
as
pt_subtensor
from
pytensor.tensor
import
subtensor
as
pt_subtensor
...
@@ -30,13 +29,11 @@ def test_jax_basic():
...
@@ -30,13 +29,11 @@ def test_jax_basic():
out
=
pt_subtensor
.
inc_subtensor
(
out
[
0
,
1
],
2.0
)
out
=
pt_subtensor
.
inc_subtensor
(
out
[
0
,
1
],
2.0
)
out
=
out
[:
5
,
:
3
]
out
=
out
[:
5
,
:
3
]
out_fg
=
FunctionGraph
([
x
,
y
],
[
out
])
test_input_vals
=
[
test_input_vals
=
[
np
.
tile
(
np
.
arange
(
10
),
(
10
,
1
))
.
astype
(
config
.
floatX
),
np
.
tile
(
np
.
arange
(
10
),
(
10
,
1
))
.
astype
(
config
.
floatX
),
np
.
tile
(
np
.
arange
(
10
,
20
),
(
10
,
1
))
.
astype
(
config
.
floatX
),
np
.
tile
(
np
.
arange
(
10
,
20
),
(
10
,
1
))
.
astype
(
config
.
floatX
),
]
]
_
,
[
jax_res
]
=
compare_jax_and_py
(
out_fg
,
test_input_vals
)
_
,
[
jax_res
]
=
compare_jax_and_py
(
[
x
,
y
],
[
out
]
,
test_input_vals
)
# Confirm that the `Subtensor` slice operations are correct
# Confirm that the `Subtensor` slice operations are correct
assert
jax_res
.
shape
==
(
5
,
3
)
assert
jax_res
.
shape
==
(
5
,
3
)
...
@@ -46,19 +43,17 @@ def test_jax_basic():
...
@@ -46,19 +43,17 @@ def test_jax_basic():
assert
jax_res
[
0
,
1
]
==
-
8.0
assert
jax_res
[
0
,
1
]
==
-
8.0
out
=
clip
(
x
,
y
,
5
)
out
=
clip
(
x
,
y
,
5
)
out_fg
=
FunctionGraph
([
x
,
y
],
[
out
])
compare_jax_and_py
([
x
,
y
],
[
out
],
test_input_vals
)
compare_jax_and_py
(
out_fg
,
test_input_vals
)
out
=
pt
.
diagonal
(
x
,
0
)
out
=
pt
.
diagonal
(
x
,
0
)
out_fg
=
FunctionGraph
([
x
],
[
out
])
compare_jax_and_py
(
compare_jax_and_py
(
out_fg
,
[
np
.
arange
(
10
*
10
)
.
reshape
((
10
,
10
))
.
astype
(
config
.
floatX
)]
[
x
],
[
out
]
,
[
np
.
arange
(
10
*
10
)
.
reshape
((
10
,
10
))
.
astype
(
config
.
floatX
)]
)
)
out
=
pt_slinalg
.
cholesky
(
x
)
out
=
pt_slinalg
.
cholesky
(
x
)
out_fg
=
FunctionGraph
([
x
],
[
out
])
compare_jax_and_py
(
compare_jax_and_py
(
out_fg
,
[
x
],
[
out
],
[
[
(
np
.
eye
(
10
)
+
rng
.
standard_normal
(
size
=
(
10
,
10
))
*
0.01
)
.
astype
(
(
np
.
eye
(
10
)
+
rng
.
standard_normal
(
size
=
(
10
,
10
))
*
0.01
)
.
astype
(
config
.
floatX
config
.
floatX
...
@@ -68,9 +63,9 @@ def test_jax_basic():
...
@@ -68,9 +63,9 @@ def test_jax_basic():
# not sure why this isn't working yet with lower=False
# not sure why this isn't working yet with lower=False
out
=
pt_slinalg
.
Cholesky
(
lower
=
False
)(
x
)
out
=
pt_slinalg
.
Cholesky
(
lower
=
False
)(
x
)
out_fg
=
FunctionGraph
([
x
],
[
out
])
compare_jax_and_py
(
compare_jax_and_py
(
out_fg
,
[
x
],
[
out
],
[
[
(
np
.
eye
(
10
)
+
rng
.
standard_normal
(
size
=
(
10
,
10
))
*
0.01
)
.
astype
(
(
np
.
eye
(
10
)
+
rng
.
standard_normal
(
size
=
(
10
,
10
))
*
0.01
)
.
astype
(
config
.
floatX
config
.
floatX
...
@@ -79,9 +74,9 @@ def test_jax_basic():
...
@@ -79,9 +74,9 @@ def test_jax_basic():
)
)
out
=
pt_slinalg
.
solve
(
x
,
b
)
out
=
pt_slinalg
.
solve
(
x
,
b
)
out_fg
=
FunctionGraph
([
x
,
b
],
[
out
])
compare_jax_and_py
(
compare_jax_and_py
(
out_fg
,
[
x
,
b
],
[
out
],
[
[
np
.
eye
(
10
)
.
astype
(
config
.
floatX
),
np
.
eye
(
10
)
.
astype
(
config
.
floatX
),
np
.
arange
(
10
)
.
astype
(
config
.
floatX
),
np
.
arange
(
10
)
.
astype
(
config
.
floatX
),
...
@@ -89,19 +84,17 @@ def test_jax_basic():
...
@@ -89,19 +84,17 @@ def test_jax_basic():
)
)
out
=
pt
.
diag
(
b
)
out
=
pt
.
diag
(
b
)
out_fg
=
FunctionGraph
([
b
],
[
out
])
compare_jax_and_py
([
b
],
[
out
],
[
np
.
arange
(
10
)
.
astype
(
config
.
floatX
)])
compare_jax_and_py
(
out_fg
,
[
np
.
arange
(
10
)
.
astype
(
config
.
floatX
)])
out
=
pt_nlinalg
.
det
(
x
)
out
=
pt_nlinalg
.
det
(
x
)
out_fg
=
FunctionGraph
([
x
],
[
out
])
compare_jax_and_py
(
compare_jax_and_py
(
out_fg
,
[
np
.
arange
(
10
*
10
)
.
reshape
((
10
,
10
))
.
astype
(
config
.
floatX
)]
[
x
],
[
out
]
,
[
np
.
arange
(
10
*
10
)
.
reshape
((
10
,
10
))
.
astype
(
config
.
floatX
)]
)
)
out
=
pt_nlinalg
.
matrix_inverse
(
x
)
out
=
pt_nlinalg
.
matrix_inverse
(
x
)
out_fg
=
FunctionGraph
([
x
],
[
out
])
compare_jax_and_py
(
compare_jax_and_py
(
out_fg
,
[
x
],
[
out
],
[
[
(
np
.
eye
(
10
)
+
rng
.
standard_normal
(
size
=
(
10
,
10
))
*
0.01
)
.
astype
(
(
np
.
eye
(
10
)
+
rng
.
standard_normal
(
size
=
(
10
,
10
))
*
0.01
)
.
astype
(
config
.
floatX
config
.
floatX
...
@@ -124,9 +117,9 @@ def test_jax_SolveTriangular(trans, lower, check_finite):
...
@@ -124,9 +117,9 @@ def test_jax_SolveTriangular(trans, lower, check_finite):
lower
=
lower
,
lower
=
lower
,
check_finite
=
check_finite
,
check_finite
=
check_finite
,
)
)
out_fg
=
FunctionGraph
([
x
,
b
],
[
out
])
compare_jax_and_py
(
compare_jax_and_py
(
out_fg
,
[
x
,
b
],
[
out
],
[
[
np
.
eye
(
10
)
.
astype
(
config
.
floatX
),
np
.
eye
(
10
)
.
astype
(
config
.
floatX
),
np
.
arange
(
10
)
.
astype
(
config
.
floatX
),
np
.
arange
(
10
)
.
astype
(
config
.
floatX
),
...
@@ -141,10 +134,10 @@ def test_jax_block_diag():
...
@@ -141,10 +134,10 @@ def test_jax_block_diag():
D
=
matrix
(
"D"
)
D
=
matrix
(
"D"
)
out
=
pt_slinalg
.
block_diag
(
A
,
B
,
C
,
D
)
out
=
pt_slinalg
.
block_diag
(
A
,
B
,
C
,
D
)
out_fg
=
FunctionGraph
([
A
,
B
,
C
,
D
],
[
out
])
compare_jax_and_py
(
compare_jax_and_py
(
out_fg
,
[
A
,
B
,
C
,
D
],
[
out
],
[
[
np
.
random
.
normal
(
size
=
(
5
,
5
))
.
astype
(
config
.
floatX
),
np
.
random
.
normal
(
size
=
(
5
,
5
))
.
astype
(
config
.
floatX
),
np
.
random
.
normal
(
size
=
(
3
,
3
))
.
astype
(
config
.
floatX
),
np
.
random
.
normal
(
size
=
(
3
,
3
))
.
astype
(
config
.
floatX
),
...
@@ -158,9 +151,10 @@ def test_jax_block_diag_blockwise():
...
@@ -158,9 +151,10 @@ def test_jax_block_diag_blockwise():
A
=
pt
.
tensor3
(
"A"
)
A
=
pt
.
tensor3
(
"A"
)
B
=
pt
.
tensor3
(
"B"
)
B
=
pt
.
tensor3
(
"B"
)
out
=
pt_slinalg
.
block_diag
(
A
,
B
)
out
=
pt_slinalg
.
block_diag
(
A
,
B
)
out_fg
=
FunctionGraph
([
A
,
B
],
[
out
])
compare_jax_and_py
(
compare_jax_and_py
(
out_fg
,
[
A
,
B
],
[
out
],
[
[
np
.
random
.
normal
(
size
=
(
5
,
5
,
5
))
.
astype
(
config
.
floatX
),
np
.
random
.
normal
(
size
=
(
5
,
5
,
5
))
.
astype
(
config
.
floatX
),
np
.
random
.
normal
(
size
=
(
5
,
3
,
3
))
.
astype
(
config
.
floatX
),
np
.
random
.
normal
(
size
=
(
5
,
3
,
3
))
.
astype
(
config
.
floatX
),
...
@@ -174,11 +168,11 @@ def test_jax_eigvalsh(lower):
...
@@ -174,11 +168,11 @@ def test_jax_eigvalsh(lower):
B
=
matrix
(
"B"
)
B
=
matrix
(
"B"
)
out
=
pt_slinalg
.
eigvalsh
(
A
,
B
,
lower
=
lower
)
out
=
pt_slinalg
.
eigvalsh
(
A
,
B
,
lower
=
lower
)
out_fg
=
FunctionGraph
([
A
,
B
],
[
out
])
with
pytest
.
raises
(
NotImplementedError
):
with
pytest
.
raises
(
NotImplementedError
):
compare_jax_and_py
(
compare_jax_and_py
(
out_fg
,
[
A
,
B
],
[
out
],
[
[
np
.
array
(
np
.
array
(
[[
6
,
3
,
1
,
5
],
[
3
,
0
,
5
,
1
],
[
1
,
5
,
6
,
2
],
[
5
,
1
,
2
,
2
]]
[[
6
,
3
,
1
,
5
],
[
3
,
0
,
5
,
1
],
[
1
,
5
,
6
,
2
],
[
5
,
1
,
2
,
2
]]
...
@@ -189,7 +183,8 @@ def test_jax_eigvalsh(lower):
...
@@ -189,7 +183,8 @@ def test_jax_eigvalsh(lower):
],
],
)
)
compare_jax_and_py
(
compare_jax_and_py
(
out_fg
,
[
A
,
B
],
[
out
],
[
[
np
.
array
([[
6
,
3
,
1
,
5
],
[
3
,
0
,
5
,
1
],
[
1
,
5
,
6
,
2
],
[
5
,
1
,
2
,
2
]])
.
astype
(
np
.
array
([[
6
,
3
,
1
,
5
],
[
3
,
0
,
5
,
1
],
[
1
,
5
,
6
,
2
],
[
5
,
1
,
2
,
2
]])
.
astype
(
config
.
floatX
config
.
floatX
...
@@ -207,11 +202,11 @@ def test_jax_solve_discrete_lyapunov(
...
@@ -207,11 +202,11 @@ def test_jax_solve_discrete_lyapunov(
A
=
pt
.
tensor
(
name
=
"A"
,
shape
=
shape
)
A
=
pt
.
tensor
(
name
=
"A"
,
shape
=
shape
)
B
=
pt
.
tensor
(
name
=
"B"
,
shape
=
shape
)
B
=
pt
.
tensor
(
name
=
"B"
,
shape
=
shape
)
out
=
pt_slinalg
.
solve_discrete_lyapunov
(
A
,
B
,
method
=
method
)
out
=
pt_slinalg
.
solve_discrete_lyapunov
(
A
,
B
,
method
=
method
)
out_fg
=
FunctionGraph
([
A
,
B
],
[
out
])
atol
=
rtol
=
1e-8
if
config
.
floatX
==
"float64"
else
1e-3
atol
=
rtol
=
1e-8
if
config
.
floatX
==
"float64"
else
1e-3
compare_jax_and_py
(
compare_jax_and_py
(
out_fg
,
[
A
,
B
],
[
out
],
[
[
np
.
random
.
normal
(
size
=
shape
)
.
astype
(
config
.
floatX
),
np
.
random
.
normal
(
size
=
shape
)
.
astype
(
config
.
floatX
),
np
.
random
.
normal
(
size
=
shape
)
.
astype
(
config
.
floatX
),
np
.
random
.
normal
(
size
=
shape
)
.
astype
(
config
.
floatX
),
...
...
tests/link/jax/test_sort.py
浏览文件 @
cc8c4992
import
numpy
as
np
import
numpy
as
np
import
pytest
import
pytest
from
pytensor.graph
import
FunctionGraph
from
pytensor.tensor
import
matrix
from
pytensor.tensor
import
matrix
from
pytensor.tensor.sort
import
argsort
,
sort
from
pytensor.tensor.sort
import
argsort
,
sort
from
tests.link.jax.test_basic
import
compare_jax_and_py
from
tests.link.jax.test_basic
import
compare_jax_and_py
...
@@ -12,6 +11,5 @@ from tests.link.jax.test_basic import compare_jax_and_py
...
@@ -12,6 +11,5 @@ from tests.link.jax.test_basic import compare_jax_and_py
def
test_sort
(
func
,
axis
):
def
test_sort
(
func
,
axis
):
x
=
matrix
(
"x"
,
shape
=
(
2
,
2
),
dtype
=
"float64"
)
x
=
matrix
(
"x"
,
shape
=
(
2
,
2
),
dtype
=
"float64"
)
out
=
func
(
x
,
axis
=
axis
)
out
=
func
(
x
,
axis
=
axis
)
fgraph
=
FunctionGraph
([
x
],
[
out
])
arr
=
np
.
array
([[
1.0
,
4.0
],
[
5.0
,
2.0
]])
arr
=
np
.
array
([[
1.0
,
4.0
],
[
5.0
,
2.0
]])
compare_jax_and_py
(
fgraph
,
[
arr
])
compare_jax_and_py
(
[
x
],
[
out
]
,
[
arr
])
tests/link/jax/test_sparse.py
浏览文件 @
cc8c4992
...
@@ -5,7 +5,6 @@ import scipy.sparse
...
@@ -5,7 +5,6 @@ import scipy.sparse
import
pytensor.sparse
as
ps
import
pytensor.sparse
as
ps
import
pytensor.tensor
as
pt
import
pytensor.tensor
as
pt
from
pytensor
import
function
from
pytensor
import
function
from
pytensor.graph
import
FunctionGraph
from
tests.link.jax.test_basic
import
compare_jax_and_py
from
tests.link.jax.test_basic
import
compare_jax_and_py
...
@@ -50,8 +49,7 @@ def test_sparse_dot_constant_sparse(x_type, y_type, op):
...
@@ -50,8 +49,7 @@ def test_sparse_dot_constant_sparse(x_type, y_type, op):
test_values
.
append
(
y_test
)
test_values
.
append
(
y_test
)
dot_pt
=
op
(
x_pt
,
y_pt
)
dot_pt
=
op
(
x_pt
,
y_pt
)
fgraph
=
FunctionGraph
(
inputs
,
[
dot_pt
])
compare_jax_and_py
(
inputs
,
[
dot_pt
],
test_values
,
jax_mode
=
"JAX"
)
compare_jax_and_py
(
fgraph
,
test_values
,
jax_mode
=
"JAX"
)
def
test_sparse_dot_non_const_raises
():
def
test_sparse_dot_non_const_raises
():
...
...
tests/link/jax/test_subtensor.py
浏览文件 @
cc8c4992
...
@@ -21,55 +21,55 @@ def test_jax_Subtensor_constant():
...
@@ -21,55 +21,55 @@ def test_jax_Subtensor_constant():
# Basic indices
# Basic indices
out_pt
=
x_pt
[
1
,
2
,
0
]
out_pt
=
x_pt
[
1
,
2
,
0
]
assert
isinstance
(
out_pt
.
owner
.
op
,
pt_subtensor
.
Subtensor
)
assert
isinstance
(
out_pt
.
owner
.
op
,
pt_subtensor
.
Subtensor
)
out_fg
=
FunctionGraph
([
x_pt
],
[
out_pt
])
compare_jax_and_py
(
out_fg
,
[
x_np
])
compare_jax_and_py
(
[
x_pt
],
[
out_pt
]
,
[
x_np
])
out_pt
=
x_pt
[
1
:,
1
,
:]
out_pt
=
x_pt
[
1
:,
1
,
:]
assert
isinstance
(
out_pt
.
owner
.
op
,
pt_subtensor
.
Subtensor
)
assert
isinstance
(
out_pt
.
owner
.
op
,
pt_subtensor
.
Subtensor
)
out_fg
=
FunctionGraph
([
x_pt
],
[
out_pt
])
compare_jax_and_py
(
out_fg
,
[
x_np
])
compare_jax_and_py
(
[
x_pt
],
[
out_pt
]
,
[
x_np
])
out_pt
=
x_pt
[:
2
,
1
,
:]
out_pt
=
x_pt
[:
2
,
1
,
:]
assert
isinstance
(
out_pt
.
owner
.
op
,
pt_subtensor
.
Subtensor
)
assert
isinstance
(
out_pt
.
owner
.
op
,
pt_subtensor
.
Subtensor
)
out_fg
=
FunctionGraph
([
x_pt
],
[
out_pt
])
compare_jax_and_py
(
out_fg
,
[
x_np
])
compare_jax_and_py
(
[
x_pt
],
[
out_pt
]
,
[
x_np
])
out_pt
=
x_pt
[
1
:
2
,
1
,
:]
out_pt
=
x_pt
[
1
:
2
,
1
,
:]
assert
isinstance
(
out_pt
.
owner
.
op
,
pt_subtensor
.
Subtensor
)
assert
isinstance
(
out_pt
.
owner
.
op
,
pt_subtensor
.
Subtensor
)
out_fg
=
FunctionGraph
([
x_pt
],
[
out_pt
])
compare_jax_and_py
(
out_fg
,
[
x_np
])
compare_jax_and_py
(
[
x_pt
],
[
out_pt
]
,
[
x_np
])
# Advanced indexing
# Advanced indexing
out_pt
=
pt_subtensor
.
advanced_subtensor1
(
x_pt
,
[
1
,
2
])
out_pt
=
pt_subtensor
.
advanced_subtensor1
(
x_pt
,
[
1
,
2
])
assert
isinstance
(
out_pt
.
owner
.
op
,
pt_subtensor
.
AdvancedSubtensor1
)
assert
isinstance
(
out_pt
.
owner
.
op
,
pt_subtensor
.
AdvancedSubtensor1
)
out_fg
=
FunctionGraph
([
x_pt
],
[
out_pt
])
compare_jax_and_py
(
out_fg
,
[
x_np
])
compare_jax_and_py
(
[
x_pt
],
[
out_pt
]
,
[
x_np
])
out_pt
=
x_pt
[[
1
,
2
],
[
2
,
3
]]
out_pt
=
x_pt
[[
1
,
2
],
[
2
,
3
]]
assert
isinstance
(
out_pt
.
owner
.
op
,
pt_subtensor
.
AdvancedSubtensor
)
assert
isinstance
(
out_pt
.
owner
.
op
,
pt_subtensor
.
AdvancedSubtensor
)
out_fg
=
FunctionGraph
([
x_pt
],
[
out_pt
])
compare_jax_and_py
(
out_fg
,
[
x_np
])
compare_jax_and_py
(
[
x_pt
],
[
out_pt
]
,
[
x_np
])
# Advanced and basic indexing
# Advanced and basic indexing
out_pt
=
x_pt
[[
1
,
2
],
:]
out_pt
=
x_pt
[[
1
,
2
],
:]
assert
isinstance
(
out_pt
.
owner
.
op
,
pt_subtensor
.
AdvancedSubtensor
)
assert
isinstance
(
out_pt
.
owner
.
op
,
pt_subtensor
.
AdvancedSubtensor
)
out_fg
=
FunctionGraph
([
x_pt
],
[
out_pt
])
compare_jax_and_py
(
out_fg
,
[
x_np
])
compare_jax_and_py
(
[
x_pt
],
[
out_pt
]
,
[
x_np
])
out_pt
=
x_pt
[[
1
,
2
],
:,
[
3
,
4
]]
out_pt
=
x_pt
[[
1
,
2
],
:,
[
3
,
4
]]
assert
isinstance
(
out_pt
.
owner
.
op
,
pt_subtensor
.
AdvancedSubtensor
)
assert
isinstance
(
out_pt
.
owner
.
op
,
pt_subtensor
.
AdvancedSubtensor
)
out_fg
=
FunctionGraph
([
x_pt
],
[
out_pt
])
compare_jax_and_py
(
out_fg
,
[
x_np
])
compare_jax_and_py
(
[
x_pt
],
[
out_pt
]
,
[
x_np
])
# Flipping
# Flipping
out_pt
=
x_pt
[::
-
1
]
out_pt
=
x_pt
[::
-
1
]
out_fg
=
FunctionGraph
([
x_pt
],
[
out_pt
])
compare_jax_and_py
(
out_fg
,
[
x_np
])
compare_jax_and_py
(
[
x_pt
],
[
out_pt
]
,
[
x_np
])
# Boolean indexing should work if indexes are constant
# Boolean indexing should work if indexes are constant
out_pt
=
x_pt
[
np
.
random
.
binomial
(
1
,
0.5
,
size
=
(
3
,
4
,
5
))
.
astype
(
bool
)]
out_pt
=
x_pt
[
np
.
random
.
binomial
(
1
,
0.5
,
size
=
(
3
,
4
,
5
))
.
astype
(
bool
)]
out_fg
=
FunctionGraph
([
x_pt
],
[
out_pt
])
compare_jax_and_py
(
out_fg
,
[
x_np
])
compare_jax_and_py
(
[
x_pt
],
[
out_pt
]
,
[
x_np
])
@pytest.mark.xfail
(
reason
=
"`a` should be specified as static when JIT-compiling"
)
@pytest.mark.xfail
(
reason
=
"`a` should be specified as static when JIT-compiling"
)
...
@@ -78,8 +78,8 @@ def test_jax_Subtensor_dynamic():
...
@@ -78,8 +78,8 @@ def test_jax_Subtensor_dynamic():
x
=
pt
.
arange
(
3
)
x
=
pt
.
arange
(
3
)
out_pt
=
x
[:
a
]
out_pt
=
x
[:
a
]
assert
isinstance
(
out_pt
.
owner
.
op
,
pt_subtensor
.
Subtensor
)
assert
isinstance
(
out_pt
.
owner
.
op
,
pt_subtensor
.
Subtensor
)
out_fg
=
FunctionGraph
([
a
],
[
out_pt
])
compare_jax_and_py
(
out_fg
,
[
1
])
compare_jax_and_py
(
[
a
],
[
out_pt
]
,
[
1
])
def
test_jax_Subtensor_dynamic_boolean_mask
():
def
test_jax_Subtensor_dynamic_boolean_mask
():
...
@@ -90,11 +90,9 @@ def test_jax_Subtensor_dynamic_boolean_mask():
...
@@ -90,11 +90,9 @@ def test_jax_Subtensor_dynamic_boolean_mask():
out_pt
=
x_pt
[
x_pt
<
0
]
out_pt
=
x_pt
[
x_pt
<
0
]
assert
isinstance
(
out_pt
.
owner
.
op
,
pt_subtensor
.
AdvancedSubtensor
)
assert
isinstance
(
out_pt
.
owner
.
op
,
pt_subtensor
.
AdvancedSubtensor
)
out_fg
=
FunctionGraph
([
x_pt
],
[
out_pt
])
x_pt_test
=
np
.
arange
(
-
5
,
5
)
x_pt_test
=
np
.
arange
(
-
5
,
5
)
with
pytest
.
raises
(
NonConcreteBooleanIndexError
):
with
pytest
.
raises
(
NonConcreteBooleanIndexError
):
compare_jax_and_py
(
out_fg
,
[
x_pt_test
])
compare_jax_and_py
(
[
x_pt
],
[
out_pt
]
,
[
x_pt_test
])
def
test_jax_Subtensor_boolean_mask_reexpressible
():
def
test_jax_Subtensor_boolean_mask_reexpressible
():
...
@@ -110,8 +108,10 @@ def test_jax_Subtensor_boolean_mask_reexpressible():
...
@@ -110,8 +108,10 @@ def test_jax_Subtensor_boolean_mask_reexpressible():
"""
"""
x_pt
=
pt
.
matrix
(
"x"
)
x_pt
=
pt
.
matrix
(
"x"
)
out_pt
=
x_pt
[
x_pt
<
0
]
.
sum
()
out_pt
=
x_pt
[
x_pt
<
0
]
.
sum
()
out_fg
=
FunctionGraph
([
x_pt
],
[
out_pt
])
compare_jax_and_py
(
out_fg
,
[
np
.
arange
(
25
)
.
reshape
(
5
,
5
)
.
astype
(
config
.
floatX
)])
compare_jax_and_py
(
[
x_pt
],
[
out_pt
],
[
np
.
arange
(
25
)
.
reshape
(
5
,
5
)
.
astype
(
config
.
floatX
)]
)
def
test_boolean_indexing_sum_not_applicable
():
def
test_boolean_indexing_sum_not_applicable
():
...
@@ -136,19 +136,19 @@ def test_jax_IncSubtensor():
...
@@ -136,19 +136,19 @@ def test_jax_IncSubtensor():
st_pt
=
pt
.
as_tensor_variable
(
np
.
array
(
-
10.0
,
dtype
=
config
.
floatX
))
st_pt
=
pt
.
as_tensor_variable
(
np
.
array
(
-
10.0
,
dtype
=
config
.
floatX
))
out_pt
=
pt_subtensor
.
set_subtensor
(
x_pt
[
1
,
2
,
3
],
st_pt
)
out_pt
=
pt_subtensor
.
set_subtensor
(
x_pt
[
1
,
2
,
3
],
st_pt
)
assert
isinstance
(
out_pt
.
owner
.
op
,
pt_subtensor
.
IncSubtensor
)
assert
isinstance
(
out_pt
.
owner
.
op
,
pt_subtensor
.
IncSubtensor
)
out_fg
=
FunctionGraph
([],
[
out_pt
])
compare_jax_and_py
(
out_fg
,
[])
compare_jax_and_py
(
[],
[
out_pt
]
,
[])
st_pt
=
pt
.
as_tensor_variable
(
np
.
r_
[
-
1.0
,
0.0
]
.
astype
(
config
.
floatX
))
st_pt
=
pt
.
as_tensor_variable
(
np
.
r_
[
-
1.0
,
0.0
]
.
astype
(
config
.
floatX
))
out_pt
=
pt_subtensor
.
set_subtensor
(
x_pt
[:
2
,
0
,
0
],
st_pt
)
out_pt
=
pt_subtensor
.
set_subtensor
(
x_pt
[:
2
,
0
,
0
],
st_pt
)
assert
isinstance
(
out_pt
.
owner
.
op
,
pt_subtensor
.
IncSubtensor
)
assert
isinstance
(
out_pt
.
owner
.
op
,
pt_subtensor
.
IncSubtensor
)
out_fg
=
FunctionGraph
([],
[
out_pt
])
compare_jax_and_py
(
out_fg
,
[])
compare_jax_and_py
(
[],
[
out_pt
]
,
[])
out_pt
=
pt_subtensor
.
set_subtensor
(
x_pt
[
0
,
1
:
3
,
0
],
st_pt
)
out_pt
=
pt_subtensor
.
set_subtensor
(
x_pt
[
0
,
1
:
3
,
0
],
st_pt
)
assert
isinstance
(
out_pt
.
owner
.
op
,
pt_subtensor
.
IncSubtensor
)
assert
isinstance
(
out_pt
.
owner
.
op
,
pt_subtensor
.
IncSubtensor
)
out_fg
=
FunctionGraph
([],
[
out_pt
])
compare_jax_and_py
(
out_fg
,
[])
compare_jax_and_py
(
[],
[
out_pt
]
,
[])
# "Set" advanced indices
# "Set" advanced indices
st_pt
=
pt
.
as_tensor_variable
(
st_pt
=
pt
.
as_tensor_variable
(
...
@@ -156,39 +156,39 @@ def test_jax_IncSubtensor():
...
@@ -156,39 +156,39 @@ def test_jax_IncSubtensor():
)
)
out_pt
=
pt_subtensor
.
set_subtensor
(
x_pt
[
np
.
r_
[
0
,
2
]],
st_pt
)
out_pt
=
pt_subtensor
.
set_subtensor
(
x_pt
[
np
.
r_
[
0
,
2
]],
st_pt
)
assert
isinstance
(
out_pt
.
owner
.
op
,
pt_subtensor
.
AdvancedIncSubtensor
)
assert
isinstance
(
out_pt
.
owner
.
op
,
pt_subtensor
.
AdvancedIncSubtensor
)
out_fg
=
FunctionGraph
([],
[
out_pt
])
compare_jax_and_py
(
out_fg
,
[])
compare_jax_and_py
(
[],
[
out_pt
]
,
[])
st_pt
=
pt
.
as_tensor_variable
(
np
.
r_
[
-
1.0
,
0.0
]
.
astype
(
config
.
floatX
))
st_pt
=
pt
.
as_tensor_variable
(
np
.
r_
[
-
1.0
,
0.0
]
.
astype
(
config
.
floatX
))
out_pt
=
pt_subtensor
.
set_subtensor
(
x_pt
[[
0
,
2
],
0
,
0
],
st_pt
)
out_pt
=
pt_subtensor
.
set_subtensor
(
x_pt
[[
0
,
2
],
0
,
0
],
st_pt
)
assert
isinstance
(
out_pt
.
owner
.
op
,
pt_subtensor
.
AdvancedIncSubtensor
)
assert
isinstance
(
out_pt
.
owner
.
op
,
pt_subtensor
.
AdvancedIncSubtensor
)
out_fg
=
FunctionGraph
([],
[
out_pt
])
compare_jax_and_py
(
out_fg
,
[])
compare_jax_and_py
(
[],
[
out_pt
]
,
[])
# "Set" boolean indices
# "Set" boolean indices
mask_pt
=
pt
.
constant
(
x_np
>
0
)
mask_pt
=
pt
.
constant
(
x_np
>
0
)
out_pt
=
pt_subtensor
.
set_subtensor
(
x_pt
[
mask_pt
],
0.0
)
out_pt
=
pt_subtensor
.
set_subtensor
(
x_pt
[
mask_pt
],
0.0
)
assert
isinstance
(
out_pt
.
owner
.
op
,
pt_subtensor
.
AdvancedIncSubtensor
)
assert
isinstance
(
out_pt
.
owner
.
op
,
pt_subtensor
.
AdvancedIncSubtensor
)
out_fg
=
FunctionGraph
([],
[
out_pt
])
compare_jax_and_py
(
out_fg
,
[])
compare_jax_and_py
(
[],
[
out_pt
]
,
[])
# "Increment" basic indices
# "Increment" basic indices
st_pt
=
pt
.
as_tensor_variable
(
np
.
array
(
-
10.0
,
dtype
=
config
.
floatX
))
st_pt
=
pt
.
as_tensor_variable
(
np
.
array
(
-
10.0
,
dtype
=
config
.
floatX
))
out_pt
=
pt_subtensor
.
inc_subtensor
(
x_pt
[
1
,
2
,
3
],
st_pt
)
out_pt
=
pt_subtensor
.
inc_subtensor
(
x_pt
[
1
,
2
,
3
],
st_pt
)
assert
isinstance
(
out_pt
.
owner
.
op
,
pt_subtensor
.
IncSubtensor
)
assert
isinstance
(
out_pt
.
owner
.
op
,
pt_subtensor
.
IncSubtensor
)
out_fg
=
FunctionGraph
([],
[
out_pt
])
compare_jax_and_py
(
out_fg
,
[])
compare_jax_and_py
(
[],
[
out_pt
]
,
[])
st_pt
=
pt
.
as_tensor_variable
(
np
.
r_
[
-
1.0
,
0.0
]
.
astype
(
config
.
floatX
))
st_pt
=
pt
.
as_tensor_variable
(
np
.
r_
[
-
1.0
,
0.0
]
.
astype
(
config
.
floatX
))
out_pt
=
pt_subtensor
.
inc_subtensor
(
x_pt
[:
2
,
0
,
0
],
st_pt
)
out_pt
=
pt_subtensor
.
inc_subtensor
(
x_pt
[:
2
,
0
,
0
],
st_pt
)
assert
isinstance
(
out_pt
.
owner
.
op
,
pt_subtensor
.
IncSubtensor
)
assert
isinstance
(
out_pt
.
owner
.
op
,
pt_subtensor
.
IncSubtensor
)
out_fg
=
FunctionGraph
([],
[
out_pt
])
compare_jax_and_py
(
out_fg
,
[])
compare_jax_and_py
(
[],
[
out_pt
]
,
[])
out_pt
=
pt_subtensor
.
set_subtensor
(
x_pt
[
0
,
1
:
3
,
0
],
st_pt
)
out_pt
=
pt_subtensor
.
set_subtensor
(
x_pt
[
0
,
1
:
3
,
0
],
st_pt
)
assert
isinstance
(
out_pt
.
owner
.
op
,
pt_subtensor
.
IncSubtensor
)
assert
isinstance
(
out_pt
.
owner
.
op
,
pt_subtensor
.
IncSubtensor
)
out_fg
=
FunctionGraph
([],
[
out_pt
])
compare_jax_and_py
(
out_fg
,
[])
compare_jax_and_py
(
[],
[
out_pt
]
,
[])
# "Increment" advanced indices
# "Increment" advanced indices
st_pt
=
pt
.
as_tensor_variable
(
st_pt
=
pt
.
as_tensor_variable
(
...
@@ -196,33 +196,33 @@ def test_jax_IncSubtensor():
...
@@ -196,33 +196,33 @@ def test_jax_IncSubtensor():
)
)
out_pt
=
pt_subtensor
.
inc_subtensor
(
x_pt
[
np
.
r_
[
0
,
2
]],
st_pt
)
out_pt
=
pt_subtensor
.
inc_subtensor
(
x_pt
[
np
.
r_
[
0
,
2
]],
st_pt
)
assert
isinstance
(
out_pt
.
owner
.
op
,
pt_subtensor
.
AdvancedIncSubtensor
)
assert
isinstance
(
out_pt
.
owner
.
op
,
pt_subtensor
.
AdvancedIncSubtensor
)
out_fg
=
FunctionGraph
([],
[
out_pt
])
compare_jax_and_py
(
out_fg
,
[])
compare_jax_and_py
(
[],
[
out_pt
]
,
[])
st_pt
=
pt
.
as_tensor_variable
(
np
.
r_
[
-
1.0
,
0.0
]
.
astype
(
config
.
floatX
))
st_pt
=
pt
.
as_tensor_variable
(
np
.
r_
[
-
1.0
,
0.0
]
.
astype
(
config
.
floatX
))
out_pt
=
pt_subtensor
.
inc_subtensor
(
x_pt
[[
0
,
2
],
0
,
0
],
st_pt
)
out_pt
=
pt_subtensor
.
inc_subtensor
(
x_pt
[[
0
,
2
],
0
,
0
],
st_pt
)
assert
isinstance
(
out_pt
.
owner
.
op
,
pt_subtensor
.
AdvancedIncSubtensor
)
assert
isinstance
(
out_pt
.
owner
.
op
,
pt_subtensor
.
AdvancedIncSubtensor
)
out_fg
=
FunctionGraph
([],
[
out_pt
])
compare_jax_and_py
(
out_fg
,
[])
compare_jax_and_py
(
[],
[
out_pt
]
,
[])
# "Increment" boolean indices
# "Increment" boolean indices
mask_pt
=
pt
.
constant
(
x_np
>
0
)
mask_pt
=
pt
.
constant
(
x_np
>
0
)
out_pt
=
pt_subtensor
.
set_subtensor
(
x_pt
[
mask_pt
],
1.0
)
out_pt
=
pt_subtensor
.
set_subtensor
(
x_pt
[
mask_pt
],
1.0
)
assert
isinstance
(
out_pt
.
owner
.
op
,
pt_subtensor
.
AdvancedIncSubtensor
)
assert
isinstance
(
out_pt
.
owner
.
op
,
pt_subtensor
.
AdvancedIncSubtensor
)
out_fg
=
FunctionGraph
([],
[
out_pt
])
compare_jax_and_py
(
out_fg
,
[])
compare_jax_and_py
(
[],
[
out_pt
]
,
[])
st_pt
=
pt
.
as_tensor_variable
(
x_np
[[
0
,
2
],
0
,
:
3
])
st_pt
=
pt
.
as_tensor_variable
(
x_np
[[
0
,
2
],
0
,
:
3
])
out_pt
=
pt_subtensor
.
set_subtensor
(
x_pt
[[
0
,
2
],
0
,
:
3
],
st_pt
)
out_pt
=
pt_subtensor
.
set_subtensor
(
x_pt
[[
0
,
2
],
0
,
:
3
],
st_pt
)
assert
isinstance
(
out_pt
.
owner
.
op
,
pt_subtensor
.
AdvancedIncSubtensor
)
assert
isinstance
(
out_pt
.
owner
.
op
,
pt_subtensor
.
AdvancedIncSubtensor
)
out_fg
=
FunctionGraph
([],
[
out_pt
])
compare_jax_and_py
(
out_fg
,
[])
compare_jax_and_py
(
[],
[
out_pt
]
,
[])
st_pt
=
pt
.
as_tensor_variable
(
x_np
[[
0
,
2
],
0
,
:
3
])
st_pt
=
pt
.
as_tensor_variable
(
x_np
[[
0
,
2
],
0
,
:
3
])
out_pt
=
pt_subtensor
.
inc_subtensor
(
x_pt
[[
0
,
2
],
0
,
:
3
],
st_pt
)
out_pt
=
pt_subtensor
.
inc_subtensor
(
x_pt
[[
0
,
2
],
0
,
:
3
],
st_pt
)
assert
isinstance
(
out_pt
.
owner
.
op
,
pt_subtensor
.
AdvancedIncSubtensor
)
assert
isinstance
(
out_pt
.
owner
.
op
,
pt_subtensor
.
AdvancedIncSubtensor
)
out_fg
=
FunctionGraph
([],
[
out_pt
])
compare_jax_and_py
(
out_fg
,
[])
compare_jax_and_py
(
[],
[
out_pt
]
,
[])
def
test_jax_IncSubtensor_boolean_indexing_reexpressible
():
def
test_jax_IncSubtensor_boolean_indexing_reexpressible
():
...
@@ -243,14 +243,14 @@ def test_jax_IncSubtensor_boolean_indexing_reexpressible():
...
@@ -243,14 +243,14 @@ def test_jax_IncSubtensor_boolean_indexing_reexpressible():
mask_pt
=
pt
.
as_tensor
(
x_pt
)
>
0
mask_pt
=
pt
.
as_tensor
(
x_pt
)
>
0
out_pt
=
pt_subtensor
.
set_subtensor
(
x_pt
[
mask_pt
],
0.0
)
out_pt
=
pt_subtensor
.
set_subtensor
(
x_pt
[
mask_pt
],
0.0
)
assert
isinstance
(
out_pt
.
owner
.
op
,
pt_subtensor
.
AdvancedIncSubtensor
)
assert
isinstance
(
out_pt
.
owner
.
op
,
pt_subtensor
.
AdvancedIncSubtensor
)
out_fg
=
FunctionGraph
([
x_pt
],
[
out_pt
])
compare_jax_and_py
(
out_fg
,
[
x_np
])
compare_jax_and_py
(
[
x_pt
],
[
out_pt
]
,
[
x_np
])
mask_pt
=
pt
.
as_tensor
(
x_pt
)
>
0
mask_pt
=
pt
.
as_tensor
(
x_pt
)
>
0
out_pt
=
pt_subtensor
.
inc_subtensor
(
x_pt
[
mask_pt
],
1.0
)
out_pt
=
pt_subtensor
.
inc_subtensor
(
x_pt
[
mask_pt
],
1.0
)
assert
isinstance
(
out_pt
.
owner
.
op
,
pt_subtensor
.
AdvancedIncSubtensor
)
assert
isinstance
(
out_pt
.
owner
.
op
,
pt_subtensor
.
AdvancedIncSubtensor
)
out_fg
=
FunctionGraph
([
x_pt
],
[
out_pt
])
compare_jax_and_py
(
out_fg
,
[
x_np
])
compare_jax_and_py
(
[
x_pt
],
[
out_pt
]
,
[
x_np
])
def
test_boolean_indexing_set_or_inc_not_applicable
():
def
test_boolean_indexing_set_or_inc_not_applicable
():
...
...
tests/link/jax/test_tensor_basic.py
浏览文件 @
cc8c4992
...
@@ -10,8 +10,6 @@ from jax import errors
...
@@ -10,8 +10,6 @@ from jax import errors
import
pytensor
import
pytensor
import
pytensor.tensor.basic
as
ptb
import
pytensor.tensor.basic
as
ptb
from
pytensor.configdefaults
import
config
from
pytensor.configdefaults
import
config
from
pytensor.graph.fg
import
FunctionGraph
from
pytensor.graph.op
import
get_test_value
from
pytensor.tensor.type
import
iscalar
,
matrix
,
scalar
,
vector
from
pytensor.tensor.type
import
iscalar
,
matrix
,
scalar
,
vector
from
tests.link.jax.test_basic
import
compare_jax_and_py
from
tests.link.jax.test_basic
import
compare_jax_and_py
from
tests.tensor.test_basic
import
check_alloc_runtime_broadcast
from
tests.tensor.test_basic
import
check_alloc_runtime_broadcast
...
@@ -19,38 +17,31 @@ from tests.tensor.test_basic import check_alloc_runtime_broadcast
...
@@ -19,38 +17,31 @@ from tests.tensor.test_basic import check_alloc_runtime_broadcast
def
test_jax_Alloc
():
def
test_jax_Alloc
():
x
=
ptb
.
alloc
(
0.0
,
2
,
3
)
x
=
ptb
.
alloc
(
0.0
,
2
,
3
)
x_fg
=
FunctionGraph
([],
[
x
])
_
,
[
jax_res
]
=
compare_jax_and_py
(
x_fg
,
[])
_
,
[
jax_res
]
=
compare_jax_and_py
(
[],
[
x
]
,
[])
assert
jax_res
.
shape
==
(
2
,
3
)
assert
jax_res
.
shape
==
(
2
,
3
)
x
=
ptb
.
alloc
(
1.1
,
2
,
3
)
x
=
ptb
.
alloc
(
1.1
,
2
,
3
)
x_fg
=
FunctionGraph
([],
[
x
])
compare_jax_and_py
(
x_fg
,
[])
compare_jax_and_py
(
[],
[
x
]
,
[])
x
=
ptb
.
AllocEmpty
(
"float32"
)(
2
,
3
)
x
=
ptb
.
AllocEmpty
(
"float32"
)(
2
,
3
)
x_fg
=
FunctionGraph
([],
[
x
])
def
compare_shape_dtype
(
x
,
y
):
def
compare_shape_dtype
(
x
,
y
):
(
x
,)
=
x
np
.
testing
.
assert_array_equal
(
x
,
y
,
strict
=
True
)
(
y
,)
=
y
return
x
.
shape
==
y
.
shape
and
x
.
dtype
==
y
.
dtype
compare_jax_and_py
(
x_fg
,
[],
assert_fn
=
compare_shape_dtype
)
compare_jax_and_py
(
[],
[
x
]
,
[],
assert_fn
=
compare_shape_dtype
)
a
=
scalar
(
"a"
)
a
=
scalar
(
"a"
)
x
=
ptb
.
alloc
(
a
,
20
)
x
=
ptb
.
alloc
(
a
,
20
)
x_fg
=
FunctionGraph
([
a
],
[
x
])
compare_jax_and_py
(
x_fg
,
[
10.0
])
compare_jax_and_py
(
[
a
],
[
x
]
,
[
10.0
])
a
=
vector
(
"a"
)
a
=
vector
(
"a"
)
x
=
ptb
.
alloc
(
a
,
20
,
10
)
x
=
ptb
.
alloc
(
a
,
20
,
10
)
x_fg
=
FunctionGraph
([
a
],
[
x
])
compare_jax_and_py
(
x_fg
,
[
np
.
ones
(
10
,
dtype
=
config
.
floatX
)])
compare_jax_and_py
(
[
a
],
[
x
]
,
[
np
.
ones
(
10
,
dtype
=
config
.
floatX
)])
def
test_alloc_runtime_broadcast
():
def
test_alloc_runtime_broadcast
():
...
@@ -59,34 +50,31 @@ def test_alloc_runtime_broadcast():
...
@@ -59,34 +50,31 @@ def test_alloc_runtime_broadcast():
def
test_jax_MakeVector
():
def
test_jax_MakeVector
():
x
=
ptb
.
make_vector
(
1
,
2
,
3
)
x
=
ptb
.
make_vector
(
1
,
2
,
3
)
x_fg
=
FunctionGraph
([],
[
x
])
compare_jax_and_py
(
x_fg
,
[])
compare_jax_and_py
(
[],
[
x
]
,
[])
def
test_arange
():
def
test_arange
():
out
=
ptb
.
arange
(
1
,
10
,
2
)
out
=
ptb
.
arange
(
1
,
10
,
2
)
fgraph
=
FunctionGraph
([],
[
out
])
compare_jax_and_py
(
fgraph
,
[])
compare_jax_and_py
(
[],
[
out
]
,
[])
def
test_arange_of_shape
():
def
test_arange_of_shape
():
x
=
vector
(
"x"
)
x
=
vector
(
"x"
)
out
=
ptb
.
arange
(
1
,
x
.
shape
[
-
1
],
2
)
out
=
ptb
.
arange
(
1
,
x
.
shape
[
-
1
],
2
)
fgraph
=
FunctionGraph
([
x
],
[
out
])
compare_jax_and_py
([
x
],
[
out
],
[
np
.
zeros
((
5
,))],
jax_mode
=
"JAX"
)
compare_jax_and_py
(
fgraph
,
[
np
.
zeros
((
5
,))],
jax_mode
=
"JAX"
)
def
test_arange_nonconcrete
():
def
test_arange_nonconcrete
():
"""JAX cannot JIT-compile `jax.numpy.arange` when arguments are not concrete values."""
"""JAX cannot JIT-compile `jax.numpy.arange` when arguments are not concrete values."""
a
=
scalar
(
"a"
)
a
=
scalar
(
"a"
)
a
.
tag
.
test_value
=
10
a
_
test_value
=
10
out
=
ptb
.
arange
(
a
)
out
=
ptb
.
arange
(
a
)
with
pytest
.
raises
(
NotImplementedError
):
with
pytest
.
raises
(
NotImplementedError
):
fgraph
=
FunctionGraph
([
a
],
[
out
])
compare_jax_and_py
([
a
],
[
out
],
[
a_test_value
])
compare_jax_and_py
(
fgraph
,
[
get_test_value
(
i
)
for
i
in
fgraph
.
inputs
])
def
test_jax_Join
():
def
test_jax_Join
():
...
@@ -94,16 +82,17 @@ def test_jax_Join():
...
@@ -94,16 +82,17 @@ def test_jax_Join():
b
=
matrix
(
"b"
)
b
=
matrix
(
"b"
)
x
=
ptb
.
join
(
0
,
a
,
b
)
x
=
ptb
.
join
(
0
,
a
,
b
)
x_fg
=
FunctionGraph
([
a
,
b
],
[
x
])
compare_jax_and_py
(
compare_jax_and_py
(
x_fg
,
[
a
,
b
],
[
x
],
[
[
np
.
c_
[[
1.0
,
2.0
,
3.0
]]
.
astype
(
config
.
floatX
),
np
.
c_
[[
1.0
,
2.0
,
3.0
]]
.
astype
(
config
.
floatX
),
np
.
c_
[[
4.0
,
5.0
,
6.0
]]
.
astype
(
config
.
floatX
),
np
.
c_
[[
4.0
,
5.0
,
6.0
]]
.
astype
(
config
.
floatX
),
],
],
)
)
compare_jax_and_py
(
compare_jax_and_py
(
x_fg
,
[
a
,
b
],
[
x
],
[
[
np
.
c_
[[
1.0
,
2.0
,
3.0
]]
.
astype
(
config
.
floatX
),
np
.
c_
[[
1.0
,
2.0
,
3.0
]]
.
astype
(
config
.
floatX
),
np
.
c_
[[
4.0
,
5.0
]]
.
astype
(
config
.
floatX
),
np
.
c_
[[
4.0
,
5.0
]]
.
astype
(
config
.
floatX
),
...
@@ -111,16 +100,17 @@ def test_jax_Join():
...
@@ -111,16 +100,17 @@ def test_jax_Join():
)
)
x
=
ptb
.
join
(
1
,
a
,
b
)
x
=
ptb
.
join
(
1
,
a
,
b
)
x_fg
=
FunctionGraph
([
a
,
b
],
[
x
])
compare_jax_and_py
(
compare_jax_and_py
(
x_fg
,
[
a
,
b
],
[
x
],
[
[
np
.
c_
[[
1.0
,
2.0
,
3.0
]]
.
astype
(
config
.
floatX
),
np
.
c_
[[
1.0
,
2.0
,
3.0
]]
.
astype
(
config
.
floatX
),
np
.
c_
[[
4.0
,
5.0
,
6.0
]]
.
astype
(
config
.
floatX
),
np
.
c_
[[
4.0
,
5.0
,
6.0
]]
.
astype
(
config
.
floatX
),
],
],
)
)
compare_jax_and_py
(
compare_jax_and_py
(
x_fg
,
[
a
,
b
],
[
x
],
[
[
np
.
c_
[[
1.0
,
2.0
],
[
3.0
,
4.0
]]
.
astype
(
config
.
floatX
),
np
.
c_
[[
1.0
,
2.0
],
[
3.0
,
4.0
]]
.
astype
(
config
.
floatX
),
np
.
c_
[[
5.0
,
6.0
]]
.
astype
(
config
.
floatX
),
np
.
c_
[[
5.0
,
6.0
]]
.
astype
(
config
.
floatX
),
...
@@ -132,9 +122,9 @@ class TestJaxSplit:
...
@@ -132,9 +122,9 @@ class TestJaxSplit:
def
test_basic
(
self
):
def
test_basic
(
self
):
a
=
matrix
(
"a"
)
a
=
matrix
(
"a"
)
a_splits
=
ptb
.
split
(
a
,
splits_size
=
[
1
,
2
,
3
],
n_splits
=
3
,
axis
=
0
)
a_splits
=
ptb
.
split
(
a
,
splits_size
=
[
1
,
2
,
3
],
n_splits
=
3
,
axis
=
0
)
fg
=
FunctionGraph
([
a
],
a_splits
)
compare_jax_and_py
(
compare_jax_and_py
(
fg
,
[
a
],
a_splits
,
[
[
np
.
zeros
((
6
,
4
))
.
astype
(
config
.
floatX
),
np
.
zeros
((
6
,
4
))
.
astype
(
config
.
floatX
),
],
],
...
@@ -142,9 +132,9 @@ class TestJaxSplit:
...
@@ -142,9 +132,9 @@ class TestJaxSplit:
a
=
matrix
(
"a"
,
shape
=
(
6
,
None
))
a
=
matrix
(
"a"
,
shape
=
(
6
,
None
))
a_splits
=
ptb
.
split
(
a
,
splits_size
=
[
2
,
a
.
shape
[
0
]
-
2
],
n_splits
=
2
,
axis
=
0
)
a_splits
=
ptb
.
split
(
a
,
splits_size
=
[
2
,
a
.
shape
[
0
]
-
2
],
n_splits
=
2
,
axis
=
0
)
fg
=
FunctionGraph
([
a
],
a_splits
)
compare_jax_and_py
(
compare_jax_and_py
(
fg
,
[
a
],
a_splits
,
[
[
np
.
zeros
((
6
,
4
))
.
astype
(
config
.
floatX
),
np
.
zeros
((
6
,
4
))
.
astype
(
config
.
floatX
),
],
],
...
@@ -207,15 +197,14 @@ class TestJaxSplit:
...
@@ -207,15 +197,14 @@ class TestJaxSplit:
def
test_jax_eye
():
def
test_jax_eye
():
"""Tests jaxification of the Eye operator"""
"""Tests jaxification of the Eye operator"""
out
=
ptb
.
eye
(
3
)
out
=
ptb
.
eye
(
3
)
out_fg
=
FunctionGraph
([],
[
out
])
compare_jax_and_py
(
out_fg
,
[])
compare_jax_and_py
(
[],
[
out
]
,
[])
def
test_tri
():
def
test_tri
():
out
=
ptb
.
tri
(
10
,
10
,
0
)
out
=
ptb
.
tri
(
10
,
10
,
0
)
fgraph
=
FunctionGraph
([],
[
out
])
compare_jax_and_py
(
fgraph
,
[])
compare_jax_and_py
(
[],
[
out
]
,
[])
@pytest.mark.skipif
(
@pytest.mark.skipif
(
...
@@ -230,14 +219,13 @@ def test_tri_nonconcrete():
...
@@ -230,14 +219,13 @@ def test_tri_nonconcrete():
scalar
(
"n"
,
dtype
=
"int64"
),
scalar
(
"n"
,
dtype
=
"int64"
),
scalar
(
"k"
,
dtype
=
"int64"
),
scalar
(
"k"
,
dtype
=
"int64"
),
)
)
m
.
tag
.
test_value
=
10
m
_
test_value
=
10
n
.
tag
.
test_value
=
10
n
_
test_value
=
10
k
.
tag
.
test_value
=
0
k
_
test_value
=
0
out
=
ptb
.
tri
(
m
,
n
,
k
)
out
=
ptb
.
tri
(
m
,
n
,
k
)
# The actual error the user will see should be jax.errors.ConcretizationTypeError, but
# The actual error the user will see should be jax.errors.ConcretizationTypeError, but
# the error handler raises an Attribute error first, so that's what this test needs to pass
# the error handler raises an Attribute error first, so that's what this test needs to pass
with
pytest
.
raises
(
AttributeError
):
with
pytest
.
raises
(
AttributeError
):
fgraph
=
FunctionGraph
([
m
,
n
,
k
],
[
out
])
compare_jax_and_py
([
m
,
n
,
k
],
[
out
],
[
m_test_value
,
n_test_value
,
k_test_value
])
compare_jax_and_py
(
fgraph
,
[
get_test_value
(
i
)
for
i
in
fgraph
.
inputs
])
tests/link/numba/test_basic.py
浏览文件 @
cc8c4992
import
contextlib
import
contextlib
import
inspect
import
inspect
from
collections.abc
import
Callable
,
Sequenc
e
from
collections.abc
import
Callable
,
Iterabl
e
from
typing
import
TYPE_CHECKING
,
Any
from
typing
import
TYPE_CHECKING
,
Any
from
unittest
import
mock
from
unittest
import
mock
...
@@ -21,10 +21,8 @@ from pytensor.compile.builders import OpFromGraph
...
@@ -21,10 +21,8 @@ from pytensor.compile.builders import OpFromGraph
from
pytensor.compile.function
import
function
from
pytensor.compile.function
import
function
from
pytensor.compile.mode
import
Mode
from
pytensor.compile.mode
import
Mode
from
pytensor.compile.ops
import
ViewOp
from
pytensor.compile.ops
import
ViewOp
from
pytensor.compile.sharedvalue
import
SharedVariable
from
pytensor.graph.basic
import
Apply
,
Variable
from
pytensor.graph.basic
import
Apply
,
Constant
from
pytensor.graph.op
import
Op
from
pytensor.graph.fg
import
FunctionGraph
from
pytensor.graph.op
import
Op
,
get_test_value
from
pytensor.graph.rewriting.db
import
RewriteDatabaseQuery
from
pytensor.graph.rewriting.db
import
RewriteDatabaseQuery
from
pytensor.graph.type
import
Type
from
pytensor.graph.type
import
Type
from
pytensor.ifelse
import
ifelse
from
pytensor.ifelse
import
ifelse
...
@@ -39,7 +37,6 @@ from pytensor.tensor.shape import Reshape, Shape, Shape_i, SpecifyShape
...
@@ -39,7 +37,6 @@ from pytensor.tensor.shape import Reshape, Shape, Shape_i, SpecifyShape
if
TYPE_CHECKING
:
if
TYPE_CHECKING
:
from
pytensor.graph.basic
import
Variable
from
pytensor.graph.basic
import
Variable
from
pytensor.tensor
import
TensorLike
class
MyType
(
Type
):
class
MyType
(
Type
):
...
@@ -128,11 +125,6 @@ py_mode = Mode("py", opts)
...
@@ -128,11 +125,6 @@ py_mode = Mode("py", opts)
rng
=
np
.
random
.
default_rng
(
42849
)
rng
=
np
.
random
.
default_rng
(
42849
)
def
set_test_value
(
x
,
v
):
x
.
tag
.
test_value
=
v
return
x
def
compare_shape_dtype
(
x
,
y
):
def
compare_shape_dtype
(
x
,
y
):
return
x
.
shape
==
y
.
shape
and
x
.
dtype
==
y
.
dtype
return
x
.
shape
==
y
.
shape
and
x
.
dtype
==
y
.
dtype
...
@@ -225,28 +217,30 @@ def eval_python_only(fn_inputs, fn_outputs, inputs, mode=numba_mode):
...
@@ -225,28 +217,30 @@ def eval_python_only(fn_inputs, fn_outputs, inputs, mode=numba_mode):
def
compare_numba_and_py
(
def
compare_numba_and_py
(
fgraph
:
FunctionGraph
|
tuple
[
Sequence
[
"Variable"
],
Sequence
[
"Variable"
]
],
graph_inputs
:
Iterable
[
Variable
],
inputs
:
Sequence
[
"TensorLike"
],
graph_outputs
:
Variable
|
Iterable
[
Variable
],
assert_fn
:
Callable
|
None
=
Non
e
,
test_inputs
:
Iterabl
e
,
*
,
*
,
assert_fn
:
Callable
|
None
=
None
,
numba_mode
=
numba_mode
,
numba_mode
=
numba_mode
,
py_mode
=
py_mode
,
py_mode
=
py_mode
,
updates
=
None
,
updates
=
None
,
inplace
:
bool
=
False
,
inplace
:
bool
=
False
,
eval_obj_mode
:
bool
=
True
,
eval_obj_mode
:
bool
=
True
,
)
->
tuple
[
Callable
,
Any
]:
)
->
tuple
[
Callable
,
Any
]:
"""Function to compare python
graph
output and Numba compiled output for testing equality
"""Function to compare python
function
output and Numba compiled output for testing equality
In the tests below computational graphs are defined in PyTensor. These graphs are then passed to
The inputs and outputs are then passed to this function which then compiles the given function in both
this function which then compiles the graphs in both Numba and python, runs the calculation
numba and python, runs the calculation in both and checks if the results are the same
in both and checks if the results are the same
Parameters
Parameters
----------
----------
fgraph
graph_inputs:
`FunctionGraph` or tuple(inputs, outputs) to compare.
Symbolic inputs to the graph
inputs
graph_outputs:
Numeric inputs to be passed to the compiled graphs.
Symbolic outputs of the graph
test_inputs
Numerical inputs with which to evaluate the graph.
assert_fn
assert_fn
Assert function used to check for equality between python and Numba. If not
Assert function used to check for equality between python and Numba. If not
provided uses `np.testing.assert_allclose`.
provided uses `np.testing.assert_allclose`.
...
@@ -267,42 +261,38 @@ def compare_numba_and_py(
...
@@ -267,42 +261,38 @@ def compare_numba_and_py(
x
,
y
x
,
y
)
)
if
isinstance
(
fgraph
,
FunctionGraph
):
if
any
(
inp
.
owner
is
not
None
for
inp
in
graph_inputs
):
fn_inputs
=
fgraph
.
inputs
raise
ValueError
(
"Inputs must be root variables"
)
fn_outputs
=
fgraph
.
outputs
else
:
fn_inputs
,
fn_outputs
=
fgraph
fn_inputs
=
[
i
for
i
in
fn_inputs
if
not
isinstance
(
i
,
SharedVariable
)]
pytensor_py_fn
=
function
(
pytensor_py_fn
=
function
(
fn_inputs
,
fn
_outputs
,
mode
=
py_mode
,
accept_inplace
=
True
,
updates
=
updates
graph_inputs
,
graph
_outputs
,
mode
=
py_mode
,
accept_inplace
=
True
,
updates
=
updates
)
)
test_inputs
=
(
inp
.
copy
()
for
inp
in
inputs
)
if
inplace
else
inputs
test_inputs
_copy
=
(
inp
.
copy
()
for
inp
in
test_inputs
)
if
inplace
else
test_
inputs
py_res
=
pytensor_py_fn
(
*
test_inputs
)
py_res
=
pytensor_py_fn
(
*
test_inputs
_copy
)
# Get some coverage (and catch errors in python mode before unreadable numba ones)
# Get some coverage (and catch errors in python mode before unreadable numba ones)
if
eval_obj_mode
:
if
eval_obj_mode
:
test_inputs
=
(
inp
.
copy
()
for
inp
in
inputs
)
if
inplace
else
inputs
test_inputs_copy
=
(
eval_python_only
(
fn_inputs
,
fn_outputs
,
test_inputs
,
mode
=
numba_mode
)
(
inp
.
copy
()
for
inp
in
test_inputs
)
if
inplace
else
test_inputs
)
eval_python_only
(
graph_inputs
,
graph_outputs
,
test_inputs_copy
,
mode
=
numba_mode
)
pytensor_numba_fn
=
function
(
pytensor_numba_fn
=
function
(
fn
_inputs
,
graph
_inputs
,
fn
_outputs
,
graph
_outputs
,
mode
=
numba_mode
,
mode
=
numba_mode
,
accept_inplace
=
True
,
accept_inplace
=
True
,
updates
=
updates
,
updates
=
updates
,
)
)
test_inputs_copy
=
(
inp
.
copy
()
for
inp
in
test_inputs
)
if
inplace
else
test_inputs
numba_res
=
pytensor_numba_fn
(
*
test_inputs_copy
)
test_inputs
=
(
inp
.
copy
()
for
inp
in
inputs
)
if
inplace
else
inputs
if
isinstance
(
graph_outputs
,
tuple
|
list
):
numba_res
=
pytensor_numba_fn
(
*
test_inputs
)
if
len
(
fn_outputs
)
>
1
:
for
j
,
p
in
zip
(
numba_res
,
py_res
,
strict
=
True
):
for
j
,
p
in
zip
(
numba_res
,
py_res
,
strict
=
True
):
assert_fn
(
j
,
p
)
assert_fn
(
j
,
p
)
else
:
else
:
assert_fn
(
numba_res
[
0
],
py_res
[
0
]
)
assert_fn
(
numba_res
,
py_res
)
return
pytensor_numba_fn
,
numba_res
return
pytensor_numba_fn
,
numba_res
...
@@ -380,53 +370,53 @@ def test_create_numba_signature(v, expected, force_scalar):
...
@@ -380,53 +370,53 @@ def test_create_numba_signature(v, expected, force_scalar):
)
)
def
test_Shape
(
x
,
i
):
def
test_Shape
(
x
,
i
):
g
=
Shape
()(
pt
.
as_tensor_variable
(
x
))
g
=
Shape
()(
pt
.
as_tensor_variable
(
x
))
g_fg
=
FunctionGraph
([],
[
g
])
compare_numba_and_py
(
g_fg
,
[])
compare_numba_and_py
(
[],
[
g
]
,
[])
g
=
Shape_i
(
i
)(
pt
.
as_tensor_variable
(
x
))
g
=
Shape_i
(
i
)(
pt
.
as_tensor_variable
(
x
))
g_fg
=
FunctionGraph
([],
[
g
])
compare_numba_and_py
(
g_fg
,
[])
compare_numba_and_py
(
[],
[
g
]
,
[])
@pytest.mark.parametrize
(
@pytest.mark.parametrize
(
"v, shape, ndim"
,
"v, shape, ndim"
,
[
[
(
set_test_value
(
pt
.
vector
(),
np
.
array
([
4
],
dtype
=
config
.
floatX
)),
(
),
0
),
(
(
pt
.
vector
(),
np
.
array
([
4
],
dtype
=
config
.
floatX
)),
((),
None
),
0
),
(
set_test_value
(
pt
.
vector
(),
np
.
arange
(
4
,
dtype
=
config
.
floatX
)),
(
2
,
2
),
2
),
(
(
pt
.
vector
(),
np
.
arange
(
4
,
dtype
=
config
.
floatX
)),
((
2
,
2
),
None
),
2
),
(
(
set_test_value
(
pt
.
vector
(),
np
.
arange
(
4
,
dtype
=
config
.
floatX
)),
(
pt
.
vector
(),
np
.
arange
(
4
,
dtype
=
config
.
floatX
)),
set_test_value
(
pt
.
lvector
(),
np
.
array
([
2
,
2
],
dtype
=
"int64"
)),
(
pt
.
lvector
(),
np
.
array
([
2
,
2
],
dtype
=
"int64"
)),
2
,
2
,
),
),
],
],
)
)
def
test_Reshape
(
v
,
shape
,
ndim
):
def
test_Reshape
(
v
,
shape
,
ndim
):
v
,
v_test_value
=
v
shape
,
shape_test_value
=
shape
g
=
Reshape
(
ndim
)(
v
,
shape
)
g
=
Reshape
(
ndim
)(
v
,
shape
)
g_fg
=
FunctionGraph
(
outputs
=
[
g
])
inputs
=
[
v
]
if
not
isinstance
(
shape
,
Variable
)
else
[
v
,
shape
]
test_values
=
(
[
v_test_value
]
if
not
isinstance
(
shape
,
Variable
)
else
[
v_test_value
,
shape_test_value
]
)
compare_numba_and_py
(
compare_numba_and_py
(
g_fg
,
inputs
,
[
[
g
],
i
.
tag
.
test_value
test_values
,
for
i
in
g_fg
.
inputs
if
not
isinstance
(
i
,
SharedVariable
|
Constant
)
],
)
)
def
test_Reshape_scalar
():
def
test_Reshape_scalar
():
v
=
pt
.
vector
()
v
=
pt
.
vector
()
v
.
tag
.
test_value
=
np
.
array
([
1.0
],
dtype
=
config
.
floatX
)
v
_
test_value
=
np
.
array
([
1.0
],
dtype
=
config
.
floatX
)
g
=
Reshape
(
1
)(
v
[
0
],
(
1
,))
g
=
Reshape
(
1
)(
v
[
0
],
(
1
,))
g_fg
=
FunctionGraph
(
outputs
=
[
g
])
compare_numba_and_py
(
compare_numba_and_py
(
g_fg
,
[
v
],
[
g
,
i
.
tag
.
test_value
[
v_test_value
],
for
i
in
g_fg
.
inputs
if
not
isinstance
(
i
,
SharedVariable
|
Constant
)
],
)
)
...
@@ -434,53 +424,44 @@ def test_Reshape_scalar():
...
@@ -434,53 +424,44 @@ def test_Reshape_scalar():
"v, shape, fails"
,
"v, shape, fails"
,
[
[
(
(
set_test_value
(
pt
.
matrix
(),
np
.
array
([[
1.0
]],
dtype
=
config
.
floatX
)),
(
pt
.
matrix
(),
np
.
array
([[
1.0
]],
dtype
=
config
.
floatX
)),
(
1
,
1
),
(
1
,
1
),
False
,
False
,
),
),
(
(
set_test_value
(
pt
.
matrix
(),
np
.
array
([[
1.0
,
2.0
]],
dtype
=
config
.
floatX
)),
(
pt
.
matrix
(),
np
.
array
([[
1.0
,
2.0
]],
dtype
=
config
.
floatX
)),
(
1
,
1
),
(
1
,
1
),
True
,
True
,
),
),
(
(
set_test_value
(
pt
.
matrix
(),
np
.
array
([[
1.0
,
2.0
]],
dtype
=
config
.
floatX
)),
(
pt
.
matrix
(),
np
.
array
([[
1.0
,
2.0
]],
dtype
=
config
.
floatX
)),
(
1
,
None
),
(
1
,
None
),
False
,
False
,
),
),
],
],
)
)
def
test_SpecifyShape
(
v
,
shape
,
fails
):
def
test_SpecifyShape
(
v
,
shape
,
fails
):
v
,
v_test_value
=
v
g
=
SpecifyShape
()(
v
,
*
shape
)
g
=
SpecifyShape
()(
v
,
*
shape
)
g_fg
=
FunctionGraph
(
outputs
=
[
g
])
cm
=
contextlib
.
suppress
()
if
not
fails
else
pytest
.
raises
(
AssertionError
)
cm
=
contextlib
.
suppress
()
if
not
fails
else
pytest
.
raises
(
AssertionError
)
with
cm
:
with
cm
:
compare_numba_and_py
(
compare_numba_and_py
(
g_fg
,
[
v
],
[
[
g
],
i
.
tag
.
test_value
[
v_test_value
],
for
i
in
g_fg
.
inputs
if
not
isinstance
(
i
,
SharedVariable
|
Constant
)
],
)
)
@pytest.mark.parametrize
(
def
test_ViewOp
():
"v"
,
v
=
pt
.
vector
()
[
v_test_value
=
np
.
arange
(
4
,
dtype
=
config
.
floatX
)
set_test_value
(
pt
.
vector
(),
np
.
arange
(
4
,
dtype
=
config
.
floatX
)),
],
)
def
test_ViewOp
(
v
):
g
=
ViewOp
()(
v
)
g
=
ViewOp
()(
v
)
g_fg
=
FunctionGraph
(
outputs
=
[
g
])
compare_numba_and_py
(
compare_numba_and_py
(
g_fg
,
[
v
],
[
[
g
],
i
.
tag
.
test_value
[
v_test_value
],
for
i
in
g_fg
.
inputs
if
not
isinstance
(
i
,
SharedVariable
|
Constant
)
],
)
)
...
@@ -489,20 +470,16 @@ def test_ViewOp(v):
...
@@ -489,20 +470,16 @@ def test_ViewOp(v):
[
[
(
(
[
[
set_test_value
(
(
pt
.
matrix
(),
rng
.
random
(
size
=
(
2
,
3
))
.
astype
(
config
.
floatX
)),
pt
.
matrix
(),
rng
.
random
(
size
=
(
2
,
3
))
.
astype
(
config
.
floatX
)
(
pt
.
lmatrix
(),
rng
.
poisson
(
size
=
(
2
,
3
))),
),
set_test_value
(
pt
.
lmatrix
(),
rng
.
poisson
(
size
=
(
2
,
3
))),
],
],
MySingleOut
,
MySingleOut
,
UserWarning
,
UserWarning
,
),
),
(
(
[
[
set_test_value
(
(
pt
.
matrix
(),
rng
.
random
(
size
=
(
2
,
3
))
.
astype
(
config
.
floatX
)),
pt
.
matrix
(),
rng
.
random
(
size
=
(
2
,
3
))
.
astype
(
config
.
floatX
)
(
pt
.
lmatrix
(),
rng
.
poisson
(
size
=
(
2
,
3
))),
),
set_test_value
(
pt
.
lmatrix
(),
rng
.
poisson
(
size
=
(
2
,
3
))),
],
],
MyMultiOut
,
MyMultiOut
,
UserWarning
,
UserWarning
,
...
@@ -510,38 +487,32 @@ def test_ViewOp(v):
...
@@ -510,38 +487,32 @@ def test_ViewOp(v):
],
],
)
)
def
test_perform
(
inputs
,
op
,
exc
):
def
test_perform
(
inputs
,
op
,
exc
):
inputs
,
test_values
=
zip
(
*
inputs
,
strict
=
True
)
g
=
op
()(
*
inputs
)
g
=
op
()(
*
inputs
)
if
isinstance
(
g
,
list
):
if
isinstance
(
g
,
list
):
g_fg
=
FunctionGraph
(
outputs
=
g
)
outputs
=
g
else
:
else
:
g_fg
=
FunctionGraph
(
outputs
=
[
g
])
outputs
=
[
g
]
cm
=
contextlib
.
suppress
()
if
exc
is
None
else
pytest
.
warns
(
exc
)
cm
=
contextlib
.
suppress
()
if
exc
is
None
else
pytest
.
warns
(
exc
)
with
cm
:
with
cm
:
compare_numba_and_py
(
compare_numba_and_py
(
g_fg
,
inputs
,
[
outputs
,
i
.
tag
.
test_value
test_values
,
for
i
in
g_fg
.
inputs
if
not
isinstance
(
i
,
SharedVariable
|
Constant
)
],
)
)
def
test_perform_params
():
def
test_perform_params
():
"""This tests for `Op.perform` implementations that require the `params` arguments."""
"""This tests for `Op.perform` implementations that require the `params` arguments."""
x
=
pt
.
vector
()
x
=
pt
.
vector
(
shape
=
(
2
,)
)
x
.
tag
.
test_value
=
np
.
array
([
1.0
,
2.0
],
dtype
=
config
.
floatX
)
x
_
test_value
=
np
.
array
([
1.0
,
2.0
],
dtype
=
config
.
floatX
)
out
=
assert_op
(
x
,
np
.
array
(
True
))
out
=
assert_op
(
x
,
np
.
array
(
True
))
if
not
isinstance
(
out
,
list
|
tuple
):
compare_numba_and_py
([
x
],
out
,
[
x_test_value
])
out
=
[
out
]
out_fg
=
FunctionGraph
([
x
],
out
)
compare_numba_and_py
(
out_fg
,
[
get_test_value
(
i
)
for
i
in
out_fg
.
inputs
])
def
test_perform_type_convert
():
def
test_perform_type_convert
():
...
@@ -552,59 +523,50 @@ def test_perform_type_convert():
...
@@ -552,59 +523,50 @@ def test_perform_type_convert():
"""
"""
x
=
pt
.
vector
()
x
=
pt
.
vector
()
x
.
tag
.
test_value
=
np
.
array
([
1.0
,
2.0
],
dtype
=
config
.
floatX
)
x
_
test_value
=
np
.
array
([
1.0
,
2.0
],
dtype
=
config
.
floatX
)
out
=
assert_op
(
x
.
sum
(),
np
.
array
(
True
))
out
=
assert_op
(
x
.
sum
(),
np
.
array
(
True
))
if
not
isinstance
(
out
,
list
|
tuple
):
compare_numba_and_py
([
x
],
out
,
[
x_test_value
])
out
=
[
out
]
out_fg
=
FunctionGraph
([
x
],
out
)
compare_numba_and_py
(
out_fg
,
[
get_test_value
(
i
)
for
i
in
out_fg
.
inputs
])
@pytest.mark.parametrize
(
@pytest.mark.parametrize
(
"x, y, exc"
,
"x, y, exc"
,
[
[
(
(
set_test_value
(
pt
.
matrix
(),
rng
.
random
(
size
=
(
3
,
2
))
.
astype
(
config
.
floatX
)),
(
pt
.
matrix
(),
rng
.
random
(
size
=
(
3
,
2
))
.
astype
(
config
.
floatX
)),
set_test_value
(
pt
.
vector
(),
rng
.
random
(
size
=
(
2
,))
.
astype
(
config
.
floatX
)),
(
pt
.
vector
(),
rng
.
random
(
size
=
(
2
,))
.
astype
(
config
.
floatX
)),
None
,
None
,
),
),
(
(
set_test_value
(
(
pt
.
matrix
(
dtype
=
"float64"
),
rng
.
random
(
size
=
(
3
,
2
))
.
astype
(
"float64"
)),
pt
.
matrix
(
dtype
=
"float64"
),
rng
.
random
(
size
=
(
3
,
2
))
.
astype
(
"float64"
)
(
pt
.
vector
(
dtype
=
"float32"
),
rng
.
random
(
size
=
(
2
,))
.
astype
(
"float32"
)),
),
set_test_value
(
pt
.
vector
(
dtype
=
"float32"
),
rng
.
random
(
size
=
(
2
,))
.
astype
(
"float32"
)
),
None
,
None
,
),
),
(
(
set_test_value
(
pt
.
lmatrix
(),
rng
.
poisson
(
size
=
(
3
,
2
))),
(
pt
.
lmatrix
(),
rng
.
poisson
(
size
=
(
3
,
2
))),
set_test_value
(
pt
.
fvector
(),
rng
.
random
(
size
=
(
2
,))
.
astype
(
"float32"
)),
(
pt
.
fvector
(),
rng
.
random
(
size
=
(
2
,))
.
astype
(
"float32"
)),
None
,
None
,
),
),
(
(
set_test_value
(
pt
.
lvector
(),
rng
.
random
(
size
=
(
2
,))
.
astype
(
np
.
int64
)),
(
pt
.
lvector
(),
rng
.
random
(
size
=
(
2
,))
.
astype
(
np
.
int64
)),
set_test_value
(
pt
.
lvector
(),
rng
.
random
(
size
=
(
2
,))
.
astype
(
np
.
int64
)),
(
pt
.
lvector
(),
rng
.
random
(
size
=
(
2
,))
.
astype
(
np
.
int64
)),
None
,
None
,
),
),
],
],
)
)
def
test_Dot
(
x
,
y
,
exc
):
def
test_Dot
(
x
,
y
,
exc
):
x
,
x_test_value
=
x
y
,
y_test_value
=
y
g
=
ptm
.
Dot
()(
x
,
y
)
g
=
ptm
.
Dot
()(
x
,
y
)
g_fg
=
FunctionGraph
(
outputs
=
[
g
])
cm
=
contextlib
.
suppress
()
if
exc
is
None
else
pytest
.
warns
(
exc
)
cm
=
contextlib
.
suppress
()
if
exc
is
None
else
pytest
.
warns
(
exc
)
with
cm
:
with
cm
:
compare_numba_and_py
(
compare_numba_and_py
(
g_fg
,
[
x
,
y
],
[
[
g
],
i
.
tag
.
test_value
[
x_test_value
,
y_test_value
],
for
i
in
g_fg
.
inputs
if
not
isinstance
(
i
,
SharedVariable
|
Constant
)
],
)
)
...
@@ -612,44 +574,41 @@ def test_Dot(x, y, exc):
...
@@ -612,44 +574,41 @@ def test_Dot(x, y, exc):
"x, exc"
,
"x, exc"
,
[
[
(
(
set_test_value
(
ps
.
float64
(),
np
.
array
(
0.0
,
dtype
=
"float64"
)),
(
ps
.
float64
(),
np
.
array
(
0.0
,
dtype
=
"float64"
)),
None
,
None
,
),
),
(
(
set_test_value
(
ps
.
float64
(),
np
.
array
(
-
32.0
,
dtype
=
"float64"
)),
(
ps
.
float64
(),
np
.
array
(
-
32.0
,
dtype
=
"float64"
)),
None
,
None
,
),
),
(
(
set_test_value
(
ps
.
float64
(),
np
.
array
(
-
40.0
,
dtype
=
"float64"
)),
(
ps
.
float64
(),
np
.
array
(
-
40.0
,
dtype
=
"float64"
)),
None
,
None
,
),
),
(
(
set_test_value
(
ps
.
float64
(),
np
.
array
(
32.0
,
dtype
=
"float64"
)),
(
ps
.
float64
(),
np
.
array
(
32.0
,
dtype
=
"float64"
)),
None
,
None
,
),
),
(
(
set_test_value
(
ps
.
float64
(),
np
.
array
(
40.0
,
dtype
=
"float64"
)),
(
ps
.
float64
(),
np
.
array
(
40.0
,
dtype
=
"float64"
)),
None
,
None
,
),
),
(
(
set_test_value
(
ps
.
int64
(),
np
.
array
(
32
,
dtype
=
"int64"
)),
(
ps
.
int64
(),
np
.
array
(
32
,
dtype
=
"int64"
)),
None
,
None
,
),
),
],
],
)
)
def
test_Softplus
(
x
,
exc
):
def
test_Softplus
(
x
,
exc
):
x
,
x_test_value
=
x
g
=
psm
.
Softplus
(
ps
.
upgrade_to_float
)(
x
)
g
=
psm
.
Softplus
(
ps
.
upgrade_to_float
)(
x
)
g_fg
=
FunctionGraph
(
outputs
=
[
g
])
cm
=
contextlib
.
suppress
()
if
exc
is
None
else
pytest
.
warns
(
exc
)
cm
=
contextlib
.
suppress
()
if
exc
is
None
else
pytest
.
warns
(
exc
)
with
cm
:
with
cm
:
compare_numba_and_py
(
compare_numba_and_py
(
g_fg
,
[
x
],
[
[
g
],
i
.
tag
.
test_value
[
x_test_value
],
for
i
in
g_fg
.
inputs
if
not
isinstance
(
i
,
SharedVariable
|
Constant
)
],
)
)
...
@@ -657,22 +616,22 @@ def test_Softplus(x, exc):
...
@@ -657,22 +616,22 @@ def test_Softplus(x, exc):
"x, y, exc"
,
"x, y, exc"
,
[
[
(
(
set_test_value
(
(
pt
.
dtensor3
(),
pt
.
dtensor3
(),
rng
.
random
(
size
=
(
2
,
3
,
3
))
.
astype
(
"float64"
),
rng
.
random
(
size
=
(
2
,
3
,
3
))
.
astype
(
"float64"
),
),
),
set_test_value
(
(
pt
.
dtensor3
(),
pt
.
dtensor3
(),
rng
.
random
(
size
=
(
2
,
3
,
3
))
.
astype
(
"float64"
),
rng
.
random
(
size
=
(
2
,
3
,
3
))
.
astype
(
"float64"
),
),
),
None
,
None
,
),
),
(
(
set_test_value
(
(
pt
.
dtensor3
(),
pt
.
dtensor3
(),
rng
.
random
(
size
=
(
2
,
3
,
3
))
.
astype
(
"float64"
),
rng
.
random
(
size
=
(
2
,
3
,
3
))
.
astype
(
"float64"
),
),
),
set_test_value
(
(
pt
.
ltensor3
(),
pt
.
ltensor3
(),
rng
.
poisson
(
size
=
(
2
,
3
,
3
))
.
astype
(
"int64"
),
rng
.
poisson
(
size
=
(
2
,
3
,
3
))
.
astype
(
"int64"
),
),
),
...
@@ -681,22 +640,17 @@ def test_Softplus(x, exc):
...
@@ -681,22 +640,17 @@ def test_Softplus(x, exc):
],
],
)
)
def
test_BatchedDot
(
x
,
y
,
exc
):
def
test_BatchedDot
(
x
,
y
,
exc
):
g
=
blas
.
BatchedDot
()(
x
,
y
)
x
,
x_test_value
=
x
y
,
y_test_value
=
y
if
isinstance
(
g
,
list
):
g
=
blas
.
BatchedDot
()(
x
,
y
)
g_fg
=
FunctionGraph
(
outputs
=
g
)
else
:
g_fg
=
FunctionGraph
(
outputs
=
[
g
])
cm
=
contextlib
.
suppress
()
if
exc
is
None
else
pytest
.
warns
(
exc
)
cm
=
contextlib
.
suppress
()
if
exc
is
None
else
pytest
.
warns
(
exc
)
with
cm
:
with
cm
:
compare_numba_and_py
(
compare_numba_and_py
(
g_fg
,
[
x
,
y
],
[
g
,
i
.
tag
.
test_value
[
x_test_value
,
y_test_value
],
for
i
in
g_fg
.
inputs
if
not
isinstance
(
i
,
SharedVariable
|
Constant
)
],
)
)
...
@@ -767,15 +721,15 @@ y = np.array(
...
@@ -767,15 +721,15 @@ y = np.array(
[
[
([],
lambda
:
np
.
array
(
True
),
np
.
r_
[
1
,
2
,
3
],
np
.
r_
[
-
1
,
-
2
,
-
3
]),
([],
lambda
:
np
.
array
(
True
),
np
.
r_
[
1
,
2
,
3
],
np
.
r_
[
-
1
,
-
2
,
-
3
]),
(
(
[
set_test_value
(
pt
.
dscalar
(),
np
.
array
(
0.2
,
dtype
=
np
.
float64
))],
[(
pt
.
dscalar
(),
np
.
array
(
0.2
,
dtype
=
np
.
float64
))],
lambda
x
:
x
<
0.5
,
lambda
x
:
x
<
0.5
,
np
.
r_
[
1
,
2
,
3
],
np
.
r_
[
1
,
2
,
3
],
np
.
r_
[
-
1
,
-
2
,
-
3
],
np
.
r_
[
-
1
,
-
2
,
-
3
],
),
),
(
(
[
[
set_test_value
(
pt
.
dscalar
(),
np
.
array
(
0.3
,
dtype
=
np
.
float64
)),
(
pt
.
dscalar
(),
np
.
array
(
0.3
,
dtype
=
np
.
float64
)),
set_test_value
(
pt
.
dscalar
(),
np
.
array
(
0.5
,
dtype
=
np
.
float64
)),
(
pt
.
dscalar
(),
np
.
array
(
0.5
,
dtype
=
np
.
float64
)),
],
],
lambda
x
,
y
:
x
>
y
,
lambda
x
,
y
:
x
>
y
,
x
,
x
,
...
@@ -783,8 +737,8 @@ y = np.array(
...
@@ -783,8 +737,8 @@ y = np.array(
),
),
(
(
[
[
set_test_value
(
pt
.
dvector
(),
np
.
array
([
0.3
,
0.1
],
dtype
=
np
.
float64
)),
(
pt
.
dvector
(),
np
.
array
([
0.3
,
0.1
],
dtype
=
np
.
float64
)),
set_test_value
(
pt
.
dvector
(),
np
.
array
([
0.5
,
0.9
],
dtype
=
np
.
float64
)),
(
pt
.
dvector
(),
np
.
array
([
0.5
,
0.9
],
dtype
=
np
.
float64
)),
],
],
lambda
x
,
y
:
pt
.
all
(
x
>
y
),
lambda
x
,
y
:
pt
.
all
(
x
>
y
),
x
,
x
,
...
@@ -792,8 +746,8 @@ y = np.array(
...
@@ -792,8 +746,8 @@ y = np.array(
),
),
(
(
[
[
set_test_value
(
pt
.
dvector
(),
np
.
array
([
0.3
,
0.1
],
dtype
=
np
.
float64
)),
(
pt
.
dvector
(),
np
.
array
([
0.3
,
0.1
],
dtype
=
np
.
float64
)),
set_test_value
(
pt
.
dvector
(),
np
.
array
([
0.5
,
0.9
],
dtype
=
np
.
float64
)),
(
pt
.
dvector
(),
np
.
array
([
0.5
,
0.9
],
dtype
=
np
.
float64
)),
],
],
lambda
x
,
y
:
pt
.
all
(
x
>
y
),
lambda
x
,
y
:
pt
.
all
(
x
>
y
),
[
x
,
2
*
x
],
[
x
,
2
*
x
],
...
@@ -801,8 +755,8 @@ y = np.array(
...
@@ -801,8 +755,8 @@ y = np.array(
),
),
(
(
[
[
set_test_value
(
pt
.
dvector
(),
np
.
array
([
0.5
,
0.9
],
dtype
=
np
.
float64
)),
(
pt
.
dvector
(),
np
.
array
([
0.5
,
0.9
],
dtype
=
np
.
float64
)),
set_test_value
(
pt
.
dvector
(),
np
.
array
([
0.3
,
0.1
],
dtype
=
np
.
float64
)),
(
pt
.
dvector
(),
np
.
array
([
0.3
,
0.1
],
dtype
=
np
.
float64
)),
],
],
lambda
x
,
y
:
pt
.
all
(
x
>
y
),
lambda
x
,
y
:
pt
.
all
(
x
>
y
),
[
x
,
2
*
x
],
[
x
,
2
*
x
],
...
@@ -811,14 +765,9 @@ y = np.array(
...
@@ -811,14 +765,9 @@ y = np.array(
],
],
)
)
def
test_IfElse
(
inputs
,
cond_fn
,
true_vals
,
false_vals
):
def
test_IfElse
(
inputs
,
cond_fn
,
true_vals
,
false_vals
):
inputs
,
test_values
=
zip
(
*
inputs
,
strict
=
True
)
if
inputs
else
([],
[])
out
=
ifelse
(
cond_fn
(
*
inputs
),
true_vals
,
false_vals
)
out
=
ifelse
(
cond_fn
(
*
inputs
),
true_vals
,
false_vals
)
compare_numba_and_py
(
inputs
,
out
,
test_values
)
if
not
isinstance
(
out
,
list
):
out
=
[
out
]
out_fg
=
FunctionGraph
(
inputs
,
out
)
compare_numba_and_py
(
out_fg
,
[
get_test_value
(
i
)
for
i
in
out_fg
.
inputs
])
@pytest.mark.xfail
(
reason
=
"https://github.com/numba/numba/issues/7409"
)
@pytest.mark.xfail
(
reason
=
"https://github.com/numba/numba/issues/7409"
)
...
@@ -883,7 +832,7 @@ def test_OpFromGraph():
...
@@ -883,7 +832,7 @@ def test_OpFromGraph():
yv
=
np
.
ones
((
2
,
2
),
dtype
=
config
.
floatX
)
*
3
yv
=
np
.
ones
((
2
,
2
),
dtype
=
config
.
floatX
)
*
3
zv
=
np
.
ones
((
2
,
2
),
dtype
=
config
.
floatX
)
*
5
zv
=
np
.
ones
((
2
,
2
),
dtype
=
config
.
floatX
)
*
5
compare_numba_and_py
(
((
x
,
y
,
z
),
(
out
,))
,
[
xv
,
yv
,
zv
])
compare_numba_and_py
(
[
x
,
y
,
z
],
[
out
]
,
[
xv
,
yv
,
zv
])
@pytest.mark.filterwarnings
(
"error"
)
@pytest.mark.filterwarnings
(
"error"
)
...
...
tests/link/numba/test_blockwise.py
浏览文件 @
cc8c4992
...
@@ -27,7 +27,8 @@ def test_blockwise(core_op, shape_opt):
...
@@ -27,7 +27,8 @@ def test_blockwise(core_op, shape_opt):
)
)
x_test
=
np
.
eye
(
3
)
*
np
.
arange
(
1
,
6
)[:,
None
,
None
]
x_test
=
np
.
eye
(
3
)
*
np
.
arange
(
1
,
6
)[:,
None
,
None
]
compare_numba_and_py
(
compare_numba_and_py
(
([
x
],
outs
),
[
x
],
outs
,
[
x_test
],
[
x_test
],
numba_mode
=
mode
,
numba_mode
=
mode
,
eval_obj_mode
=
False
,
eval_obj_mode
=
False
,
...
...
tests/link/numba/test_elemwise.py
浏览文件 @
cc8c4992
...
@@ -11,10 +11,7 @@ import pytensor.tensor.math as ptm
...
@@ -11,10 +11,7 @@ import pytensor.tensor.math as ptm
from
pytensor
import
config
,
function
from
pytensor
import
config
,
function
from
pytensor.compile
import
get_mode
from
pytensor.compile
import
get_mode
from
pytensor.compile.ops
import
deep_copy_op
from
pytensor.compile.ops
import
deep_copy_op
from
pytensor.compile.sharedvalue
import
SharedVariable
from
pytensor.gradient
import
grad
from
pytensor.gradient
import
grad
from
pytensor.graph.basic
import
Constant
from
pytensor.graph.fg
import
FunctionGraph
from
pytensor.scalar
import
float64
from
pytensor.scalar
import
float64
from
pytensor.tensor.elemwise
import
CAReduce
,
DimShuffle
,
Elemwise
from
pytensor.tensor.elemwise
import
CAReduce
,
DimShuffle
,
Elemwise
from
pytensor.tensor.math
import
All
,
Any
,
Max
,
Min
,
Prod
,
ProdWithoutZeros
,
Sum
from
pytensor.tensor.math
import
All
,
Any
,
Max
,
Min
,
Prod
,
ProdWithoutZeros
,
Sum
...
@@ -22,7 +19,6 @@ from pytensor.tensor.special import LogSoftmax, Softmax, SoftmaxGrad
...
@@ -22,7 +19,6 @@ from pytensor.tensor.special import LogSoftmax, Softmax, SoftmaxGrad
from
tests.link.numba.test_basic
import
(
from
tests.link.numba.test_basic
import
(
compare_numba_and_py
,
compare_numba_and_py
,
scalar_my_multi_out
,
scalar_my_multi_out
,
set_test_value
,
)
)
from
tests.tensor.test_elemwise
import
(
from
tests.tensor.test_elemwise
import
(
careduce_benchmark_tester
,
careduce_benchmark_tester
,
...
@@ -116,13 +112,13 @@ rng = np.random.default_rng(42849)
...
@@ -116,13 +112,13 @@ rng = np.random.default_rng(42849)
def
test_Elemwise
(
inputs
,
input_vals
,
output_fn
,
exc
):
def
test_Elemwise
(
inputs
,
input_vals
,
output_fn
,
exc
):
outputs
=
output_fn
(
*
inputs
)
outputs
=
output_fn
(
*
inputs
)
out_fg
=
FunctionGraph
(
outputs
=
[
outputs
]
if
not
isinstance
(
outputs
,
list
)
else
outputs
)
cm
=
contextlib
.
suppress
()
if
exc
is
None
else
pytest
.
raises
(
exc
)
cm
=
contextlib
.
suppress
()
if
exc
is
None
else
pytest
.
raises
(
exc
)
with
cm
:
with
cm
:
compare_numba_and_py
(
out_fg
,
input_vals
)
compare_numba_and_py
(
inputs
,
outputs
,
input_vals
,
)
@pytest.mark.xfail
(
reason
=
"Logic had to be reversed due to surprising segfaults"
)
@pytest.mark.xfail
(
reason
=
"Logic had to be reversed due to surprising segfaults"
)
...
@@ -135,7 +131,7 @@ def test_elemwise_runtime_broadcast():
...
@@ -135,7 +131,7 @@ def test_elemwise_runtime_broadcast():
[
[
# `{'drop': [], 'shuffle': [], 'augment': [0, 1]}`
# `{'drop': [], 'shuffle': [], 'augment': [0, 1]}`
(
(
set_test_value
(
(
pt
.
lscalar
(
name
=
"a"
),
pt
.
lscalar
(
name
=
"a"
),
np
.
array
(
1
,
dtype
=
np
.
int64
),
np
.
array
(
1
,
dtype
=
np
.
int64
),
),
),
...
@@ -144,21 +140,17 @@ def test_elemwise_runtime_broadcast():
...
@@ -144,21 +140,17 @@ def test_elemwise_runtime_broadcast():
# I.e. `a_pt.T`
# I.e. `a_pt.T`
# `{'drop': [], 'shuffle': [1, 0], 'augment': []}`
# `{'drop': [], 'shuffle': [1, 0], 'augment': []}`
(
(
set_test_value
(
(
pt
.
matrix
(
"a"
),
np
.
array
([[
1.0
,
2.0
],
[
3.0
,
4.0
]],
dtype
=
config
.
floatX
)),
pt
.
matrix
(
"a"
),
np
.
array
([[
1.0
,
2.0
],
[
3.0
,
4.0
]],
dtype
=
config
.
floatX
)
),
(
1
,
0
),
(
1
,
0
),
),
),
# `{'drop': [], 'shuffle': [0, 1], 'augment': [2]}`
# `{'drop': [], 'shuffle': [0, 1], 'augment': [2]}`
(
(
set_test_value
(
(
pt
.
matrix
(
"a"
),
np
.
array
([[
1.0
,
2.0
],
[
3.0
,
4.0
]],
dtype
=
config
.
floatX
)),
pt
.
matrix
(
"a"
),
np
.
array
([[
1.0
,
2.0
],
[
3.0
,
4.0
]],
dtype
=
config
.
floatX
)
),
(
1
,
0
,
"x"
),
(
1
,
0
,
"x"
),
),
),
# `{'drop': [1], 'shuffle': [2, 0], 'augment': [0, 2, 4]}`
# `{'drop': [1], 'shuffle': [2, 0], 'augment': [0, 2, 4]}`
(
(
set_test_value
(
(
pt
.
tensor
(
dtype
=
config
.
floatX
,
shape
=
(
None
,
1
,
None
),
name
=
"a"
),
pt
.
tensor
(
dtype
=
config
.
floatX
,
shape
=
(
None
,
1
,
None
),
name
=
"a"
),
np
.
array
([[[
1.0
,
2.0
]],
[[
3.0
,
4.0
]]],
dtype
=
config
.
floatX
),
np
.
array
([[[
1.0
,
2.0
]],
[[
3.0
,
4.0
]]],
dtype
=
config
.
floatX
),
),
),
...
@@ -167,21 +159,21 @@ def test_elemwise_runtime_broadcast():
...
@@ -167,21 +159,21 @@ def test_elemwise_runtime_broadcast():
# I.e. `a_pt.dimshuffle((0,))`
# I.e. `a_pt.dimshuffle((0,))`
# `{'drop': [1], 'shuffle': [0], 'augment': []}`
# `{'drop': [1], 'shuffle': [0], 'augment': []}`
(
(
set_test_value
(
(
pt
.
tensor
(
dtype
=
config
.
floatX
,
shape
=
(
None
,
1
),
name
=
"a"
),
pt
.
tensor
(
dtype
=
config
.
floatX
,
shape
=
(
None
,
1
),
name
=
"a"
),
np
.
array
([[
1.0
],
[
2.0
],
[
3.0
],
[
4.0
]],
dtype
=
config
.
floatX
),
np
.
array
([[
1.0
],
[
2.0
],
[
3.0
],
[
4.0
]],
dtype
=
config
.
floatX
),
),
),
(
0
,),
(
0
,),
),
),
(
(
set_test_value
(
(
pt
.
tensor
(
dtype
=
config
.
floatX
,
shape
=
(
None
,
1
),
name
=
"a"
),
pt
.
tensor
(
dtype
=
config
.
floatX
,
shape
=
(
None
,
1
),
name
=
"a"
),
np
.
array
([[
1.0
],
[
2.0
],
[
3.0
],
[
4.0
]],
dtype
=
config
.
floatX
),
np
.
array
([[
1.0
],
[
2.0
],
[
3.0
],
[
4.0
]],
dtype
=
config
.
floatX
),
),
),
(
0
,),
(
0
,),
),
),
(
(
set_test_value
(
(
pt
.
tensor
(
dtype
=
config
.
floatX
,
shape
=
(
1
,
1
,
1
),
name
=
"a"
),
pt
.
tensor
(
dtype
=
config
.
floatX
,
shape
=
(
1
,
1
,
1
),
name
=
"a"
),
np
.
array
([[[
1.0
]]],
dtype
=
config
.
floatX
),
np
.
array
([[[
1.0
]]],
dtype
=
config
.
floatX
),
),
),
...
@@ -190,15 +182,12 @@ def test_elemwise_runtime_broadcast():
...
@@ -190,15 +182,12 @@ def test_elemwise_runtime_broadcast():
],
],
)
)
def
test_Dimshuffle
(
v
,
new_order
):
def
test_Dimshuffle
(
v
,
new_order
):
v
,
v_test_value
=
v
g
=
v
.
dimshuffle
(
new_order
)
g
=
v
.
dimshuffle
(
new_order
)
g_fg
=
FunctionGraph
(
outputs
=
[
g
])
compare_numba_and_py
(
compare_numba_and_py
(
g_fg
,
[
v
],
[
[
g
],
i
.
tag
.
test_value
[
v_test_value
],
for
i
in
g_fg
.
inputs
if
not
isinstance
(
i
,
SharedVariable
|
Constant
)
],
)
)
...
@@ -229,79 +218,68 @@ def test_Dimshuffle_non_contiguous():
...
@@ -229,79 +218,68 @@ def test_Dimshuffle_non_contiguous():
axis
=
axis
,
dtype
=
dtype
,
acc_dtype
=
acc_dtype
axis
=
axis
,
dtype
=
dtype
,
acc_dtype
=
acc_dtype
)(
x
),
)(
x
),
0
,
0
,
set_test_value
(
pt
.
vector
(),
np
.
arange
(
3
,
dtype
=
config
.
floatX
)),
(
pt
.
vector
(),
np
.
arange
(
3
,
dtype
=
config
.
floatX
)),
),
),
(
(
lambda
x
,
axis
=
None
,
dtype
=
None
,
acc_dtype
=
None
:
All
(
axis
)(
x
),
lambda
x
,
axis
=
None
,
dtype
=
None
,
acc_dtype
=
None
:
All
(
axis
)(
x
),
0
,
0
,
set_test_value
(
pt
.
vector
(
dtype
=
"bool"
),
np
.
array
([
False
,
True
,
False
])),
(
pt
.
vector
(
dtype
=
"bool"
),
np
.
array
([
False
,
True
,
False
])),
),
),
(
(
lambda
x
,
axis
=
None
,
dtype
=
None
,
acc_dtype
=
None
:
Any
(
axis
)(
x
),
lambda
x
,
axis
=
None
,
dtype
=
None
,
acc_dtype
=
None
:
Any
(
axis
)(
x
),
0
,
0
,
set_test_value
(
pt
.
vector
(
dtype
=
"bool"
),
np
.
array
([
False
,
True
,
False
])),
(
pt
.
vector
(
dtype
=
"bool"
),
np
.
array
([
False
,
True
,
False
])),
),
),
(
(
lambda
x
,
axis
=
None
,
dtype
=
None
,
acc_dtype
=
None
:
Sum
(
lambda
x
,
axis
=
None
,
dtype
=
None
,
acc_dtype
=
None
:
Sum
(
axis
=
axis
,
dtype
=
dtype
,
acc_dtype
=
acc_dtype
axis
=
axis
,
dtype
=
dtype
,
acc_dtype
=
acc_dtype
)(
x
),
)(
x
),
0
,
0
,
set_test_value
(
(
pt
.
matrix
(),
np
.
arange
(
3
*
2
,
dtype
=
config
.
floatX
)
.
reshape
((
3
,
2
))),
pt
.
matrix
(),
np
.
arange
(
3
*
2
,
dtype
=
config
.
floatX
)
.
reshape
((
3
,
2
))
),
),
),
(
(
lambda
x
,
axis
=
None
,
dtype
=
None
,
acc_dtype
=
None
:
Sum
(
lambda
x
,
axis
=
None
,
dtype
=
None
,
acc_dtype
=
None
:
Sum
(
axis
=
axis
,
dtype
=
dtype
,
acc_dtype
=
acc_dtype
axis
=
axis
,
dtype
=
dtype
,
acc_dtype
=
acc_dtype
)(
x
),
)(
x
),
(
0
,
1
),
(
0
,
1
),
set_test_value
(
(
pt
.
matrix
(),
np
.
arange
(
3
*
2
,
dtype
=
config
.
floatX
)
.
reshape
((
3
,
2
))),
pt
.
matrix
(),
np
.
arange
(
3
*
2
,
dtype
=
config
.
floatX
)
.
reshape
((
3
,
2
))
),
),
),
(
(
lambda
x
,
axis
=
None
,
dtype
=
None
,
acc_dtype
=
None
:
Sum
(
lambda
x
,
axis
=
None
,
dtype
=
None
,
acc_dtype
=
None
:
Sum
(
axis
=
axis
,
dtype
=
dtype
,
acc_dtype
=
acc_dtype
axis
=
axis
,
dtype
=
dtype
,
acc_dtype
=
acc_dtype
)(
x
),
)(
x
),
(
1
,
0
),
(
1
,
0
),
set_test_value
(
(
pt
.
matrix
(),
np
.
arange
(
3
*
2
,
dtype
=
config
.
floatX
)
.
reshape
((
3
,
2
))),
pt
.
matrix
(),
np
.
arange
(
3
*
2
,
dtype
=
config
.
floatX
)
.
reshape
((
3
,
2
))
),
),
),
(
(
lambda
x
,
axis
=
None
,
dtype
=
None
,
acc_dtype
=
None
:
Sum
(
lambda
x
,
axis
=
None
,
dtype
=
None
,
acc_dtype
=
None
:
Sum
(
axis
=
axis
,
dtype
=
dtype
,
acc_dtype
=
acc_dtype
axis
=
axis
,
dtype
=
dtype
,
acc_dtype
=
acc_dtype
)(
x
),
)(
x
),
None
,
None
,
set_test_value
(
(
pt
.
matrix
(),
np
.
arange
(
3
*
2
,
dtype
=
config
.
floatX
)
.
reshape
((
3
,
2
))),
pt
.
matrix
(),
np
.
arange
(
3
*
2
,
dtype
=
config
.
floatX
)
.
reshape
((
3
,
2
))
),
),
),
(
(
lambda
x
,
axis
=
None
,
dtype
=
None
,
acc_dtype
=
None
:
Sum
(
lambda
x
,
axis
=
None
,
dtype
=
None
,
acc_dtype
=
None
:
Sum
(
axis
=
axis
,
dtype
=
dtype
,
acc_dtype
=
acc_dtype
axis
=
axis
,
dtype
=
dtype
,
acc_dtype
=
acc_dtype
)(
x
),
)(
x
),
1
,
1
,
set_test_value
(
(
pt
.
matrix
(),
np
.
arange
(
3
*
2
,
dtype
=
config
.
floatX
)
.
reshape
((
3
,
2
))),
pt
.
matrix
(),
np
.
arange
(
3
*
2
,
dtype
=
config
.
floatX
)
.
reshape
((
3
,
2
))
),
),
),
(
(
lambda
x
,
axis
=
None
,
dtype
=
None
,
acc_dtype
=
None
:
Prod
(
lambda
x
,
axis
=
None
,
dtype
=
None
,
acc_dtype
=
None
:
Prod
(
axis
=
axis
,
dtype
=
dtype
,
acc_dtype
=
acc_dtype
axis
=
axis
,
dtype
=
dtype
,
acc_dtype
=
acc_dtype
)(
x
),
)(
x
),
(),
# Empty axes would normally be rewritten away, but we want to test it still works
(),
# Empty axes would normally be rewritten away, but we want to test it still works
set_test_value
(
(
pt
.
matrix
(),
np
.
arange
(
3
*
2
,
dtype
=
config
.
floatX
)
.
reshape
((
3
,
2
))),
pt
.
matrix
(),
np
.
arange
(
3
*
2
,
dtype
=
config
.
floatX
)
.
reshape
((
3
,
2
))
),
),
),
(
(
lambda
x
,
axis
=
None
,
dtype
=
None
,
acc_dtype
=
None
:
Prod
(
lambda
x
,
axis
=
None
,
dtype
=
None
,
acc_dtype
=
None
:
Prod
(
axis
=
axis
,
dtype
=
dtype
,
acc_dtype
=
acc_dtype
axis
=
axis
,
dtype
=
dtype
,
acc_dtype
=
acc_dtype
)(
x
),
)(
x
),
None
,
None
,
set_test_value
(
(
pt
.
scalar
(),
np
.
array
(
99.0
,
dtype
=
config
.
floatX
)
pt
.
scalar
(),
np
.
array
(
99.0
,
dtype
=
config
.
floatX
),
),
# Scalar input would normally be rewritten away, but we want to test it still works
),
# Scalar input would normally be rewritten away, but we want to test it still works
),
),
(
(
...
@@ -309,77 +287,62 @@ def test_Dimshuffle_non_contiguous():
...
@@ -309,77 +287,62 @@ def test_Dimshuffle_non_contiguous():
axis
=
axis
,
dtype
=
dtype
,
acc_dtype
=
acc_dtype
axis
=
axis
,
dtype
=
dtype
,
acc_dtype
=
acc_dtype
)(
x
),
)(
x
),
0
,
0
,
set_test_value
(
pt
.
vector
(),
np
.
arange
(
3
,
dtype
=
config
.
floatX
)),
(
pt
.
vector
(),
np
.
arange
(
3
,
dtype
=
config
.
floatX
)),
),
),
(
(
lambda
x
,
axis
=
None
,
dtype
=
None
,
acc_dtype
=
None
:
ProdWithoutZeros
(
lambda
x
,
axis
=
None
,
dtype
=
None
,
acc_dtype
=
None
:
ProdWithoutZeros
(
axis
=
axis
,
dtype
=
dtype
,
acc_dtype
=
acc_dtype
axis
=
axis
,
dtype
=
dtype
,
acc_dtype
=
acc_dtype
)(
x
),
)(
x
),
0
,
0
,
set_test_value
(
pt
.
vector
(),
np
.
arange
(
3
,
dtype
=
config
.
floatX
)),
(
pt
.
vector
(),
np
.
arange
(
3
,
dtype
=
config
.
floatX
)),
),
),
(
(
lambda
x
,
axis
=
None
,
dtype
=
None
,
acc_dtype
=
None
:
Prod
(
lambda
x
,
axis
=
None
,
dtype
=
None
,
acc_dtype
=
None
:
Prod
(
axis
=
axis
,
dtype
=
dtype
,
acc_dtype
=
acc_dtype
axis
=
axis
,
dtype
=
dtype
,
acc_dtype
=
acc_dtype
)(
x
),
)(
x
),
0
,
0
,
set_test_value
(
(
pt
.
matrix
(),
np
.
arange
(
3
*
2
,
dtype
=
config
.
floatX
)
.
reshape
((
3
,
2
))),
pt
.
matrix
(),
np
.
arange
(
3
*
2
,
dtype
=
config
.
floatX
)
.
reshape
((
3
,
2
))
),
),
),
(
(
lambda
x
,
axis
=
None
,
dtype
=
None
,
acc_dtype
=
None
:
Prod
(
lambda
x
,
axis
=
None
,
dtype
=
None
,
acc_dtype
=
None
:
Prod
(
axis
=
axis
,
dtype
=
dtype
,
acc_dtype
=
acc_dtype
axis
=
axis
,
dtype
=
dtype
,
acc_dtype
=
acc_dtype
)(
x
),
)(
x
),
1
,
1
,
set_test_value
(
(
pt
.
matrix
(),
np
.
arange
(
3
*
2
,
dtype
=
config
.
floatX
)
.
reshape
((
3
,
2
))),
pt
.
matrix
(),
np
.
arange
(
3
*
2
,
dtype
=
config
.
floatX
)
.
reshape
((
3
,
2
))
),
),
),
(
(
lambda
x
,
axis
=
None
,
dtype
=
None
,
acc_dtype
=
None
:
Max
(
axis
)(
x
),
lambda
x
,
axis
=
None
,
dtype
=
None
,
acc_dtype
=
None
:
Max
(
axis
)(
x
),
None
,
None
,
set_test_value
(
(
pt
.
matrix
(),
np
.
arange
(
3
*
2
,
dtype
=
config
.
floatX
)
.
reshape
((
3
,
2
))),
pt
.
matrix
(),
np
.
arange
(
3
*
2
,
dtype
=
config
.
floatX
)
.
reshape
((
3
,
2
))
),
),
),
(
(
lambda
x
,
axis
=
None
,
dtype
=
None
,
acc_dtype
=
None
:
Max
(
axis
)(
x
),
lambda
x
,
axis
=
None
,
dtype
=
None
,
acc_dtype
=
None
:
Max
(
axis
)(
x
),
None
,
None
,
set_test_value
(
(
pt
.
lmatrix
(),
np
.
arange
(
3
*
2
,
dtype
=
np
.
int64
)
.
reshape
((
3
,
2
))),
pt
.
lmatrix
(),
np
.
arange
(
3
*
2
,
dtype
=
np
.
int64
)
.
reshape
((
3
,
2
))
),
),
),
(
(
lambda
x
,
axis
=
None
,
dtype
=
None
,
acc_dtype
=
None
:
Min
(
axis
)(
x
),
lambda
x
,
axis
=
None
,
dtype
=
None
,
acc_dtype
=
None
:
Min
(
axis
)(
x
),
None
,
None
,
set_test_value
(
(
pt
.
matrix
(),
np
.
arange
(
3
*
2
,
dtype
=
config
.
floatX
)
.
reshape
((
3
,
2
))),
pt
.
matrix
(),
np
.
arange
(
3
*
2
,
dtype
=
config
.
floatX
)
.
reshape
((
3
,
2
))
),
),
),
(
(
lambda
x
,
axis
=
None
,
dtype
=
None
,
acc_dtype
=
None
:
Min
(
axis
)(
x
),
lambda
x
,
axis
=
None
,
dtype
=
None
,
acc_dtype
=
None
:
Min
(
axis
)(
x
),
None
,
None
,
set_test_value
(
(
pt
.
lmatrix
(),
np
.
arange
(
3
*
2
,
dtype
=
np
.
int64
)
.
reshape
((
3
,
2
))),
pt
.
lmatrix
(),
np
.
arange
(
3
*
2
,
dtype
=
np
.
int64
)
.
reshape
((
3
,
2
))
),
),
),
],
],
)
)
def
test_CAReduce
(
careduce_fn
,
axis
,
v
):
def
test_CAReduce
(
careduce_fn
,
axis
,
v
):
v
,
v_test_value
=
v
g
=
careduce_fn
(
v
,
axis
=
axis
)
g
=
careduce_fn
(
v
,
axis
=
axis
)
g_fg
=
FunctionGraph
(
outputs
=
[
g
])
fn
,
_
=
compare_numba_and_py
(
fn
,
_
=
compare_numba_and_py
(
g_fg
,
[
v
],
[
[
g
],
i
.
tag
.
test_value
[
v_test_value
],
for
i
in
g_fg
.
inputs
if
not
isinstance
(
i
,
SharedVariable
|
Constant
)
],
)
)
# Confirm CAReduce is in the compiled function
# Confirm CAReduce is in the compiled function
fn
.
dprint
()
#
fn.dprint()
[
node
]
=
fn
.
maker
.
fgraph
.
apply_nodes
[
node
]
=
fn
.
maker
.
fgraph
.
apply_nodes
assert
isinstance
(
node
.
op
,
CAReduce
)
assert
isinstance
(
node
.
op
,
CAReduce
)
...
@@ -387,102 +350,91 @@ def test_CAReduce(careduce_fn, axis, v):
...
@@ -387,102 +350,91 @@ def test_CAReduce(careduce_fn, axis, v):
def
test_scalar_Elemwise_Clip
():
def
test_scalar_Elemwise_Clip
():
a
=
pt
.
scalar
(
"a"
)
a
=
pt
.
scalar
(
"a"
)
b
=
pt
.
scalar
(
"b"
)
b
=
pt
.
scalar
(
"b"
)
inputs
=
[
a
,
b
]
z
=
pt
.
switch
(
1
,
a
,
b
)
z
=
pt
.
switch
(
1
,
a
,
b
)
c
=
pt
.
clip
(
z
,
1
,
3
)
c
=
pt
.
clip
(
z
,
1
,
3
)
c_fg
=
FunctionGraph
(
outputs
=
[
c
])
compare_numba_and_py
(
c_fg
,
[
1
,
1
])
compare_numba_and_py
(
inputs
,
[
c
]
,
[
1
,
1
])
@pytest.mark.parametrize
(
@pytest.mark.parametrize
(
"dy, sm, axis, exc"
,
"dy, sm, axis, exc"
,
[
[
(
(
set_test_value
(
(
pt
.
matrix
(),
np
.
array
([[
1
,
1
,
1
],
[
0
,
0
,
0
]],
dtype
=
config
.
floatX
)),
pt
.
matrix
(),
np
.
array
([[
1
,
1
,
1
],
[
0
,
0
,
0
]],
dtype
=
config
.
floatX
)
(
pt
.
matrix
(),
rng
.
random
(
size
=
(
2
,
3
))
.
astype
(
config
.
floatX
)),
),
set_test_value
(
pt
.
matrix
(),
rng
.
random
(
size
=
(
2
,
3
))
.
astype
(
config
.
floatX
)),
None
,
None
,
None
,
None
,
),
),
(
(
set_test_value
(
(
pt
.
matrix
(),
np
.
array
([[
1
,
1
,
1
],
[
0
,
0
,
0
]],
dtype
=
config
.
floatX
)),
pt
.
matrix
(),
np
.
array
([[
1
,
1
,
1
],
[
0
,
0
,
0
]],
dtype
=
config
.
floatX
)
(
pt
.
matrix
(),
rng
.
random
(
size
=
(
2
,
3
))
.
astype
(
config
.
floatX
)),
),
set_test_value
(
pt
.
matrix
(),
rng
.
random
(
size
=
(
2
,
3
))
.
astype
(
config
.
floatX
)),
0
,
0
,
None
,
None
,
),
),
(
(
set_test_value
(
(
pt
.
matrix
(),
np
.
array
([[
1
,
1
,
1
],
[
0
,
0
,
0
]],
dtype
=
config
.
floatX
)),
pt
.
matrix
(),
np
.
array
([[
1
,
1
,
1
],
[
0
,
0
,
0
]],
dtype
=
config
.
floatX
)
(
pt
.
matrix
(),
rng
.
random
(
size
=
(
2
,
3
))
.
astype
(
config
.
floatX
)),
),
set_test_value
(
pt
.
matrix
(),
rng
.
random
(
size
=
(
2
,
3
))
.
astype
(
config
.
floatX
)),
1
,
1
,
None
,
None
,
),
),
],
],
)
)
def
test_SoftmaxGrad
(
dy
,
sm
,
axis
,
exc
):
def
test_SoftmaxGrad
(
dy
,
sm
,
axis
,
exc
):
dy
,
dy_test_value
=
dy
sm
,
sm_test_value
=
sm
g
=
SoftmaxGrad
(
axis
=
axis
)(
dy
,
sm
)
g
=
SoftmaxGrad
(
axis
=
axis
)(
dy
,
sm
)
g_fg
=
FunctionGraph
(
outputs
=
[
g
])
cm
=
contextlib
.
suppress
()
if
exc
is
None
else
pytest
.
warns
(
exc
)
cm
=
contextlib
.
suppress
()
if
exc
is
None
else
pytest
.
warns
(
exc
)
with
cm
:
with
cm
:
compare_numba_and_py
(
compare_numba_and_py
(
g_fg
,
[
dy
,
sm
],
[
[
g
],
i
.
tag
.
test_value
[
dy_test_value
,
sm_test_value
],
for
i
in
g_fg
.
inputs
if
not
isinstance
(
i
,
SharedVariable
|
Constant
)
],
)
)
def
test_SoftMaxGrad_constant_dy
():
def
test_SoftMaxGrad_constant_dy
():
dy
=
pt
.
constant
(
np
.
zeros
((
3
,),
dtype
=
config
.
floatX
))
dy
=
pt
.
constant
(
np
.
zeros
((
3
,),
dtype
=
config
.
floatX
))
sm
=
pt
.
vector
(
shape
=
(
3
,))
sm
=
pt
.
vector
(
shape
=
(
3
,))
inputs
=
[
sm
]
g
=
SoftmaxGrad
(
axis
=
None
)(
dy
,
sm
)
g
=
SoftmaxGrad
(
axis
=
None
)(
dy
,
sm
)
g_fg
=
FunctionGraph
(
outputs
=
[
g
])
compare_numba_and_py
(
g_fg
,
[
np
.
ones
((
3
,),
dtype
=
config
.
floatX
)])
compare_numba_and_py
(
inputs
,
[
g
]
,
[
np
.
ones
((
3
,),
dtype
=
config
.
floatX
)])
@pytest.mark.parametrize
(
@pytest.mark.parametrize
(
"x, axis, exc"
,
"x, axis, exc"
,
[
[
(
(
set_test_value
(
pt
.
vector
(),
rng
.
random
(
size
=
(
2
,))
.
astype
(
config
.
floatX
)),
(
pt
.
vector
(),
rng
.
random
(
size
=
(
2
,))
.
astype
(
config
.
floatX
)),
None
,
None
,
None
,
None
,
),
),
(
(
set_test_value
(
pt
.
matrix
(),
rng
.
random
(
size
=
(
2
,
3
))
.
astype
(
config
.
floatX
)),
(
pt
.
matrix
(),
rng
.
random
(
size
=
(
2
,
3
))
.
astype
(
config
.
floatX
)),
None
,
None
,
None
,
None
,
),
),
(
(
set_test_value
(
pt
.
matrix
(),
rng
.
random
(
size
=
(
2
,
3
))
.
astype
(
config
.
floatX
)),
(
pt
.
matrix
(),
rng
.
random
(
size
=
(
2
,
3
))
.
astype
(
config
.
floatX
)),
0
,
0
,
None
,
None
,
),
),
],
],
)
)
def
test_Softmax
(
x
,
axis
,
exc
):
def
test_Softmax
(
x
,
axis
,
exc
):
x
,
x_test_value
=
x
g
=
Softmax
(
axis
=
axis
)(
x
)
g
=
Softmax
(
axis
=
axis
)(
x
)
g_fg
=
FunctionGraph
(
outputs
=
[
g
])
cm
=
contextlib
.
suppress
()
if
exc
is
None
else
pytest
.
warns
(
exc
)
cm
=
contextlib
.
suppress
()
if
exc
is
None
else
pytest
.
warns
(
exc
)
with
cm
:
with
cm
:
compare_numba_and_py
(
compare_numba_and_py
(
g_fg
,
[
x
],
[
[
g
],
i
.
tag
.
test_value
[
x_test_value
],
for
i
in
g_fg
.
inputs
if
not
isinstance
(
i
,
SharedVariable
|
Constant
)
],
)
)
...
@@ -490,35 +442,32 @@ def test_Softmax(x, axis, exc):
...
@@ -490,35 +442,32 @@ def test_Softmax(x, axis, exc):
"x, axis, exc"
,
"x, axis, exc"
,
[
[
(
(
set_test_value
(
pt
.
vector
(),
rng
.
random
(
size
=
(
2
,))
.
astype
(
config
.
floatX
)),
(
pt
.
vector
(),
rng
.
random
(
size
=
(
2
,))
.
astype
(
config
.
floatX
)),
None
,
None
,
None
,
None
,
),
),
(
(
set_test_value
(
pt
.
matrix
(),
rng
.
random
(
size
=
(
2
,
3
))
.
astype
(
config
.
floatX
)),
(
pt
.
matrix
(),
rng
.
random
(
size
=
(
2
,
3
))
.
astype
(
config
.
floatX
)),
0
,
0
,
None
,
None
,
),
),
(
(
set_test_value
(
pt
.
matrix
(),
rng
.
random
(
size
=
(
2
,
3
))
.
astype
(
config
.
floatX
)),
(
pt
.
matrix
(),
rng
.
random
(
size
=
(
2
,
3
))
.
astype
(
config
.
floatX
)),
1
,
1
,
None
,
None
,
),
),
],
],
)
)
def
test_LogSoftmax
(
x
,
axis
,
exc
):
def
test_LogSoftmax
(
x
,
axis
,
exc
):
x
,
x_test_value
=
x
g
=
LogSoftmax
(
axis
=
axis
)(
x
)
g
=
LogSoftmax
(
axis
=
axis
)(
x
)
g_fg
=
FunctionGraph
(
outputs
=
[
g
])
cm
=
contextlib
.
suppress
()
if
exc
is
None
else
pytest
.
warns
(
exc
)
cm
=
contextlib
.
suppress
()
if
exc
is
None
else
pytest
.
warns
(
exc
)
with
cm
:
with
cm
:
compare_numba_and_py
(
compare_numba_and_py
(
g_fg
,
[
x
],
[
[
g
],
i
.
tag
.
test_value
[
x_test_value
],
for
i
in
g_fg
.
inputs
if
not
isinstance
(
i
,
SharedVariable
|
Constant
)
],
)
)
...
@@ -526,44 +475,37 @@ def test_LogSoftmax(x, axis, exc):
...
@@ -526,44 +475,37 @@ def test_LogSoftmax(x, axis, exc):
"x, axes, exc"
,
"x, axes, exc"
,
[
[
(
(
set_test_value
(
pt
.
dscalar
(),
np
.
array
(
0.0
,
dtype
=
"float64"
)),
(
pt
.
dscalar
(),
np
.
array
(
0.0
,
dtype
=
"float64"
)),
[],
[],
None
,
None
,
),
),
(
(
set_test_value
(
pt
.
dvector
(),
rng
.
random
(
size
=
(
3
,))
.
astype
(
"float64"
)),
(
pt
.
dvector
(),
rng
.
random
(
size
=
(
3
,))
.
astype
(
"float64"
)),
[
0
],
[
0
],
None
,
None
,
),
),
(
(
set_test_value
(
pt
.
dmatrix
(),
rng
.
random
(
size
=
(
3
,
2
))
.
astype
(
"float64"
)),
(
pt
.
dmatrix
(),
rng
.
random
(
size
=
(
3
,
2
))
.
astype
(
"float64"
)),
[
0
],
[
0
],
None
,
None
,
),
),
(
(
set_test_value
(
pt
.
dmatrix
(),
rng
.
random
(
size
=
(
3
,
2
))
.
astype
(
"float64"
)),
(
pt
.
dmatrix
(),
rng
.
random
(
size
=
(
3
,
2
))
.
astype
(
"float64"
)),
[
0
,
1
],
[
0
,
1
],
None
,
None
,
),
),
],
],
)
)
def
test_Max
(
x
,
axes
,
exc
):
def
test_Max
(
x
,
axes
,
exc
):
x
,
x_test_value
=
x
g
=
ptm
.
Max
(
axes
)(
x
)
g
=
ptm
.
Max
(
axes
)(
x
)
if
isinstance
(
g
,
list
):
g_fg
=
FunctionGraph
(
outputs
=
g
)
else
:
g_fg
=
FunctionGraph
(
outputs
=
[
g
])
cm
=
contextlib
.
suppress
()
if
exc
is
None
else
pytest
.
warns
(
exc
)
cm
=
contextlib
.
suppress
()
if
exc
is
None
else
pytest
.
warns
(
exc
)
with
cm
:
with
cm
:
compare_numba_and_py
(
compare_numba_and_py
(
g_fg
,
[
x
],
[
[
g
],
i
.
tag
.
test_value
[
x_test_value
],
for
i
in
g_fg
.
inputs
if
not
isinstance
(
i
,
SharedVariable
|
Constant
)
],
)
)
...
@@ -571,44 +513,37 @@ def test_Max(x, axes, exc):
...
@@ -571,44 +513,37 @@ def test_Max(x, axes, exc):
"x, axes, exc"
,
"x, axes, exc"
,
[
[
(
(
set_test_value
(
pt
.
dscalar
(),
np
.
array
(
0.0
,
dtype
=
"float64"
)),
(
pt
.
dscalar
(),
np
.
array
(
0.0
,
dtype
=
"float64"
)),
[],
[],
None
,
None
,
),
),
(
(
set_test_value
(
pt
.
dvector
(),
rng
.
random
(
size
=
(
3
,))
.
astype
(
"float64"
)),
(
pt
.
dvector
(),
rng
.
random
(
size
=
(
3
,))
.
astype
(
"float64"
)),
[
0
],
[
0
],
None
,
None
,
),
),
(
(
set_test_value
(
pt
.
dmatrix
(),
rng
.
random
(
size
=
(
3
,
2
))
.
astype
(
"float64"
)),
(
pt
.
dmatrix
(),
rng
.
random
(
size
=
(
3
,
2
))
.
astype
(
"float64"
)),
[
0
],
[
0
],
None
,
None
,
),
),
(
(
set_test_value
(
pt
.
dmatrix
(),
rng
.
random
(
size
=
(
3
,
2
))
.
astype
(
"float64"
)),
(
pt
.
dmatrix
(),
rng
.
random
(
size
=
(
3
,
2
))
.
astype
(
"float64"
)),
[
0
,
1
],
[
0
,
1
],
None
,
None
,
),
),
],
],
)
)
def
test_Argmax
(
x
,
axes
,
exc
):
def
test_Argmax
(
x
,
axes
,
exc
):
x
,
x_test_value
=
x
g
=
ptm
.
Argmax
(
axes
)(
x
)
g
=
ptm
.
Argmax
(
axes
)(
x
)
if
isinstance
(
g
,
list
):
g_fg
=
FunctionGraph
(
outputs
=
g
)
else
:
g_fg
=
FunctionGraph
(
outputs
=
[
g
])
cm
=
contextlib
.
suppress
()
if
exc
is
None
else
pytest
.
warns
(
exc
)
cm
=
contextlib
.
suppress
()
if
exc
is
None
else
pytest
.
warns
(
exc
)
with
cm
:
with
cm
:
compare_numba_and_py
(
compare_numba_and_py
(
g_fg
,
[
x
],
[
[
g
],
i
.
tag
.
test_value
[
x_test_value
],
for
i
in
g_fg
.
inputs
if
not
isinstance
(
i
,
SharedVariable
|
Constant
)
],
)
)
...
@@ -636,7 +571,8 @@ def test_scalar_loop():
...
@@ -636,7 +571,8 @@ def test_scalar_loop():
with
pytest
.
warns
(
UserWarning
,
match
=
"object mode"
):
with
pytest
.
warns
(
UserWarning
,
match
=
"object mode"
):
compare_numba_and_py
(
compare_numba_and_py
(
([
x
],
[
elemwise_loop
]),
[
x
],
[
elemwise_loop
],
(
np
.
array
([
1
,
2
,
3
],
dtype
=
"float64"
),),
(
np
.
array
([
1
,
2
,
3
],
dtype
=
"float64"
),),
)
)
...
...
tests/link/numba/test_extra_ops.py
浏览文件 @
cc8c4992
...
@@ -5,11 +5,8 @@ import pytest
...
@@ -5,11 +5,8 @@ import pytest
import
pytensor.tensor
as
pt
import
pytensor.tensor
as
pt
from
pytensor
import
config
from
pytensor
import
config
from
pytensor.compile.sharedvalue
import
SharedVariable
from
pytensor.graph.basic
import
Constant
from
pytensor.graph.fg
import
FunctionGraph
from
pytensor.tensor
import
extra_ops
from
pytensor.tensor
import
extra_ops
from
tests.link.numba.test_basic
import
compare_numba_and_py
,
set_test_value
from
tests.link.numba.test_basic
import
compare_numba_and_py
rng
=
np
.
random
.
default_rng
(
42849
)
rng
=
np
.
random
.
default_rng
(
42849
)
...
@@ -18,20 +15,17 @@ rng = np.random.default_rng(42849)
...
@@ -18,20 +15,17 @@ rng = np.random.default_rng(42849)
@pytest.mark.parametrize
(
@pytest.mark.parametrize
(
"val"
,
"val"
,
[
[
set_test_value
(
pt
.
lscalar
(),
np
.
array
(
6
,
dtype
=
"int64"
)),
(
pt
.
lscalar
(),
np
.
array
(
6
,
dtype
=
"int64"
)),
],
],
)
)
def
test_Bartlett
(
val
):
def
test_Bartlett
(
val
):
val
,
test_val
=
val
g
=
extra_ops
.
bartlett
(
val
)
g
=
extra_ops
.
bartlett
(
val
)
g_fg
=
FunctionGraph
(
outputs
=
[
g
])
compare_numba_and_py
(
compare_numba_and_py
(
g_fg
,
[
val
],
[
g
,
i
.
tag
.
test_value
[
test_val
],
for
i
in
g_fg
.
inputs
if
not
isinstance
(
i
,
SharedVariable
|
Constant
)
],
assert_fn
=
lambda
x
,
y
:
np
.
testing
.
assert_allclose
(
x
,
y
,
atol
=
1e-15
),
assert_fn
=
lambda
x
,
y
:
np
.
testing
.
assert_allclose
(
x
,
y
,
atol
=
1e-15
),
)
)
...
@@ -40,97 +34,71 @@ def test_Bartlett(val):
...
@@ -40,97 +34,71 @@ def test_Bartlett(val):
"val, axis, mode"
,
"val, axis, mode"
,
[
[
(
(
set_test_value
(
(
pt
.
matrix
(),
np
.
arange
(
3
,
dtype
=
config
.
floatX
)
.
reshape
((
3
,
1
))),
pt
.
matrix
(),
np
.
arange
(
3
,
dtype
=
config
.
floatX
)
.
reshape
((
3
,
1
))
),
1
,
1
,
"add"
,
"add"
,
),
),
(
(
set_test_value
(
(
pt
.
dtensor3
(),
np
.
arange
(
30
,
dtype
=
config
.
floatX
)
.
reshape
((
2
,
3
,
5
))),
pt
.
dtensor3
(),
np
.
arange
(
30
,
dtype
=
config
.
floatX
)
.
reshape
((
2
,
3
,
5
))
),
-
1
,
-
1
,
"add"
,
"add"
,
),
),
(
(
set_test_value
(
(
pt
.
matrix
(),
np
.
arange
(
6
,
dtype
=
config
.
floatX
)
.
reshape
((
3
,
2
))),
pt
.
matrix
(),
np
.
arange
(
6
,
dtype
=
config
.
floatX
)
.
reshape
((
3
,
2
))
),
0
,
0
,
"add"
,
"add"
,
),
),
(
(
set_test_value
(
(
pt
.
matrix
(),
np
.
arange
(
6
,
dtype
=
config
.
floatX
)
.
reshape
((
3
,
2
))),
pt
.
matrix
(),
np
.
arange
(
6
,
dtype
=
config
.
floatX
)
.
reshape
((
3
,
2
))
),
1
,
1
,
"add"
,
"add"
,
),
),
(
(
set_test_value
(
(
pt
.
matrix
(),
np
.
arange
(
6
,
dtype
=
config
.
floatX
)
.
reshape
((
3
,
2
))),
pt
.
matrix
(),
np
.
arange
(
6
,
dtype
=
config
.
floatX
)
.
reshape
((
3
,
2
))
),
None
,
None
,
"add"
,
"add"
,
),
),
(
(
set_test_value
(
(
pt
.
matrix
(),
np
.
arange
(
6
,
dtype
=
config
.
floatX
)
.
reshape
((
3
,
2
))),
pt
.
matrix
(),
np
.
arange
(
6
,
dtype
=
config
.
floatX
)
.
reshape
((
3
,
2
))
),
0
,
0
,
"mul"
,
"mul"
,
),
),
(
(
set_test_value
(
(
pt
.
matrix
(),
np
.
arange
(
6
,
dtype
=
config
.
floatX
)
.
reshape
((
3
,
2
))),
pt
.
matrix
(),
np
.
arange
(
6
,
dtype
=
config
.
floatX
)
.
reshape
((
3
,
2
))
),
1
,
1
,
"mul"
,
"mul"
,
),
),
(
(
set_test_value
(
(
pt
.
matrix
(),
np
.
arange
(
6
,
dtype
=
config
.
floatX
)
.
reshape
((
3
,
2
))),
pt
.
matrix
(),
np
.
arange
(
6
,
dtype
=
config
.
floatX
)
.
reshape
((
3
,
2
))
),
None
,
None
,
"mul"
,
"mul"
,
),
),
],
],
)
)
def
test_CumOp
(
val
,
axis
,
mode
):
def
test_CumOp
(
val
,
axis
,
mode
):
val
,
test_val
=
val
g
=
extra_ops
.
CumOp
(
axis
=
axis
,
mode
=
mode
)(
val
)
g
=
extra_ops
.
CumOp
(
axis
=
axis
,
mode
=
mode
)(
val
)
g_fg
=
FunctionGraph
(
outputs
=
[
g
])
compare_numba_and_py
(
compare_numba_and_py
(
g_fg
,
[
val
],
[
g
,
i
.
tag
.
test_value
[
test_val
],
for
i
in
g_fg
.
inputs
if
not
isinstance
(
i
,
SharedVariable
|
Constant
)
],
)
)
@pytest.mark.parametrize
(
def
test_FillDiagonal
():
"a, val"
,
a
=
pt
.
lmatrix
(
"a"
)
[
test_a
=
np
.
zeros
((
10
,
2
),
dtype
=
"int64"
)
(
set_test_value
(
pt
.
lmatrix
(),
np
.
zeros
((
10
,
2
),
dtype
=
"int64"
)),
val
=
pt
.
lscalar
(
"val"
)
set_test_value
(
pt
.
lscalar
(),
np
.
array
(
1
,
dtype
=
"int64"
)),
test_val
=
np
.
array
(
1
,
dtype
=
"int64"
)
)
],
)
def
test_FillDiagonal
(
a
,
val
):
g
=
extra_ops
.
FillDiagonal
()(
a
,
val
)
g
=
extra_ops
.
FillDiagonal
()(
a
,
val
)
g_fg
=
FunctionGraph
(
outputs
=
[
g
])
compare_numba_and_py
(
compare_numba_and_py
(
g_fg
,
[
a
,
val
],
[
g
,
i
.
tag
.
test_value
[
test_a
,
test_val
],
for
i
in
g_fg
.
inputs
if
not
isinstance
(
i
,
SharedVariable
|
Constant
)
],
)
)
...
@@ -138,33 +106,32 @@ def test_FillDiagonal(a, val):
...
@@ -138,33 +106,32 @@ def test_FillDiagonal(a, val):
"a, val, offset"
,
"a, val, offset"
,
[
[
(
(
set_test_value
(
pt
.
lmatrix
(),
np
.
zeros
((
10
,
2
),
dtype
=
"int64"
)),
(
pt
.
lmatrix
(),
np
.
zeros
((
10
,
2
),
dtype
=
"int64"
)),
set_test_value
(
pt
.
lscalar
(),
np
.
array
(
1
,
dtype
=
"int64"
)),
(
pt
.
lscalar
(),
np
.
array
(
1
,
dtype
=
"int64"
)),
set_test_value
(
pt
.
lscalar
(),
np
.
array
(
-
1
,
dtype
=
"int64"
)),
(
pt
.
lscalar
(),
np
.
array
(
-
1
,
dtype
=
"int64"
)),
),
),
(
(
set_test_value
(
pt
.
lmatrix
(),
np
.
zeros
((
10
,
2
),
dtype
=
"int64"
)),
(
pt
.
lmatrix
(),
np
.
zeros
((
10
,
2
),
dtype
=
"int64"
)),
set_test_value
(
pt
.
lscalar
(),
np
.
array
(
1
,
dtype
=
"int64"
)),
(
pt
.
lscalar
(),
np
.
array
(
1
,
dtype
=
"int64"
)),
set_test_value
(
pt
.
lscalar
(),
np
.
array
(
0
,
dtype
=
"int64"
)),
(
pt
.
lscalar
(),
np
.
array
(
0
,
dtype
=
"int64"
)),
),
),
(
(
set_test_value
(
pt
.
lmatrix
(),
np
.
zeros
((
10
,
3
),
dtype
=
"int64"
)),
(
pt
.
lmatrix
(),
np
.
zeros
((
10
,
3
),
dtype
=
"int64"
)),
set_test_value
(
pt
.
lscalar
(),
np
.
array
(
1
,
dtype
=
"int64"
)),
(
pt
.
lscalar
(),
np
.
array
(
1
,
dtype
=
"int64"
)),
set_test_value
(
pt
.
lscalar
(),
np
.
array
(
1
,
dtype
=
"int64"
)),
(
pt
.
lscalar
(),
np
.
array
(
1
,
dtype
=
"int64"
)),
),
),
],
],
)
)
def
test_FillDiagonalOffset
(
a
,
val
,
offset
):
def
test_FillDiagonalOffset
(
a
,
val
,
offset
):
a
,
test_a
=
a
val
,
test_val
=
val
offset
,
test_offset
=
offset
g
=
extra_ops
.
FillDiagonalOffset
()(
a
,
val
,
offset
)
g
=
extra_ops
.
FillDiagonalOffset
()(
a
,
val
,
offset
)
g_fg
=
FunctionGraph
(
outputs
=
[
g
])
compare_numba_and_py
(
compare_numba_and_py
(
g_fg
,
[
a
,
val
,
offset
],
[
g
,
i
.
tag
.
test_value
[
test_a
,
test_val
,
test_offset
],
for
i
in
g_fg
.
inputs
if
not
isinstance
(
i
,
SharedVariable
|
Constant
)
],
)
)
...
@@ -172,65 +139,56 @@ def test_FillDiagonalOffset(a, val, offset):
...
@@ -172,65 +139,56 @@ def test_FillDiagonalOffset(a, val, offset):
"arr, shape, mode, order, exc"
,
"arr, shape, mode, order, exc"
,
[
[
(
(
tuple
(
set_test_value
(
pt
.
lscalar
(),
v
)
for
v
in
np
.
array
([
0
])),
tuple
((
pt
.
lscalar
(),
v
)
for
v
in
np
.
array
([
0
])),
set_test_value
(
pt
.
lvector
(),
np
.
array
([
2
])),
(
pt
.
lvector
(),
np
.
array
([
2
])),
"raise"
,
"raise"
,
"C"
,
"C"
,
None
,
None
,
),
),
(
(
tuple
(
set_test_value
(
pt
.
lscalar
(),
v
)
for
v
in
np
.
array
([
0
,
0
,
3
])),
tuple
((
pt
.
lscalar
(),
v
)
for
v
in
np
.
array
([
0
,
0
,
3
])),
set_test_value
(
pt
.
lvector
(),
np
.
array
([
2
,
3
,
4
])),
(
pt
.
lvector
(),
np
.
array
([
2
,
3
,
4
])),
"raise"
,
"raise"
,
"C"
,
"C"
,
None
,
None
,
),
),
(
(
tuple
(
tuple
((
pt
.
lvector
(),
v
)
for
v
in
np
.
array
([[
0
,
1
],
[
2
,
0
],
[
1
,
3
]])),
set_test_value
(
pt
.
lvector
(),
v
)
(
pt
.
lvector
(),
np
.
array
([
2
,
3
,
4
])),
for
v
in
np
.
array
([[
0
,
1
],
[
2
,
0
],
[
1
,
3
]])
),
set_test_value
(
pt
.
lvector
(),
np
.
array
([
2
,
3
,
4
])),
"raise"
,
"raise"
,
"C"
,
"C"
,
None
,
None
,
),
),
(
(
tuple
(
tuple
((
pt
.
lvector
(),
v
)
for
v
in
np
.
array
([[
0
,
1
],
[
2
,
0
],
[
1
,
3
]])),
set_test_value
(
pt
.
lvector
(),
v
)
(
pt
.
lvector
(),
np
.
array
([
2
,
3
,
4
])),
for
v
in
np
.
array
([[
0
,
1
],
[
2
,
0
],
[
1
,
3
]])
),
set_test_value
(
pt
.
lvector
(),
np
.
array
([
2
,
3
,
4
])),
"raise"
,
"raise"
,
"F"
,
"F"
,
NotImplementedError
,
NotImplementedError
,
),
),
(
(
tuple
(
tuple
(
set_test_value
(
pt
.
lvector
(),
v
)
(
pt
.
lvector
(),
v
)
for
v
in
np
.
array
([[
0
,
1
,
2
],
[
2
,
0
,
3
],
[
1
,
3
,
5
]])
for
v
in
np
.
array
([[
0
,
1
,
2
],
[
2
,
0
,
3
],
[
1
,
3
,
5
]])
),
),
set_test_value
(
pt
.
lvector
(),
np
.
array
([
2
,
3
,
4
])),
(
pt
.
lvector
(),
np
.
array
([
2
,
3
,
4
])),
"raise"
,
"raise"
,
"C"
,
"C"
,
ValueError
,
ValueError
,
),
),
(
(
tuple
(
tuple
(
set_test_value
(
pt
.
lvector
(),
v
)
(
pt
.
lvector
(),
v
)
for
v
in
np
.
array
([[
0
,
1
,
2
],
[
2
,
0
,
3
],
[
1
,
3
,
5
]])
for
v
in
np
.
array
([[
0
,
1
,
2
],
[
2
,
0
,
3
],
[
1
,
3
,
5
]])
),
),
set_test_value
(
pt
.
lvector
(),
np
.
array
([
2
,
3
,
4
])),
(
pt
.
lvector
(),
np
.
array
([
2
,
3
,
4
])),
"wrap"
,
"wrap"
,
"C"
,
"C"
,
None
,
None
,
),
),
(
(
tuple
(
tuple
(
set_test_value
(
pt
.
lvector
(),
v
)
(
pt
.
lvector
(),
v
)
for
v
in
np
.
array
([[
0
,
1
,
2
],
[
2
,
0
,
3
],
[
1
,
3
,
5
]])
for
v
in
np
.
array
([[
0
,
1
,
2
],
[
2
,
0
,
3
],
[
1
,
3
,
5
]])
),
),
set_test_value
(
pt
.
lvector
(),
np
.
array
([
2
,
3
,
4
])),
(
pt
.
lvector
(),
np
.
array
([
2
,
3
,
4
])),
"clip"
,
"clip"
,
"C"
,
"C"
,
None
,
None
,
...
@@ -238,18 +196,16 @@ def test_FillDiagonalOffset(a, val, offset):
...
@@ -238,18 +196,16 @@ def test_FillDiagonalOffset(a, val, offset):
],
],
)
)
def
test_RavelMultiIndex
(
arr
,
shape
,
mode
,
order
,
exc
):
def
test_RavelMultiIndex
(
arr
,
shape
,
mode
,
order
,
exc
):
g
=
extra_ops
.
RavelMultiIndex
(
mode
,
order
)(
*
((
*
arr
,
shape
)))
arr
,
test_arr
=
zip
(
*
arr
,
strict
=
True
)
g_fg
=
FunctionGraph
(
outputs
=
[
g
])
shape
,
test_shape
=
shape
g
=
extra_ops
.
RavelMultiIndex
(
mode
,
order
)(
*
arr
,
shape
)
cm
=
contextlib
.
suppress
()
if
exc
is
None
else
pytest
.
raises
(
exc
)
cm
=
contextlib
.
suppress
()
if
exc
is
None
else
pytest
.
raises
(
exc
)
with
cm
:
with
cm
:
compare_numba_and_py
(
compare_numba_and_py
(
g_fg
,
[
*
arr
,
shape
],
[
g
,
i
.
tag
.
test_value
[
*
test_arr
,
test_shape
],
for
i
in
g_fg
.
inputs
if
not
isinstance
(
i
,
SharedVariable
|
Constant
)
],
)
)
...
@@ -257,44 +213,42 @@ def test_RavelMultiIndex(arr, shape, mode, order, exc):
...
@@ -257,44 +213,42 @@ def test_RavelMultiIndex(arr, shape, mode, order, exc):
"x, repeats, axis, exc"
,
"x, repeats, axis, exc"
,
[
[
(
(
set_test_value
(
pt
.
lscalar
(),
np
.
array
(
1
,
dtype
=
"int64"
)),
(
pt
.
lscalar
(),
np
.
array
(
1
,
dtype
=
"int64"
)),
set_test_value
(
pt
.
lscalar
(),
np
.
array
(
0
,
dtype
=
"int64"
)),
(
pt
.
lscalar
(),
np
.
array
(
0
,
dtype
=
"int64"
)),
None
,
None
,
None
,
None
,
),
),
(
(
set_test_value
(
pt
.
lmatrix
(),
np
.
zeros
((
2
,
2
),
dtype
=
"int64"
)),
(
pt
.
lmatrix
(),
np
.
zeros
((
2
,
2
),
dtype
=
"int64"
)),
set_test_value
(
pt
.
lscalar
(),
np
.
array
(
1
,
dtype
=
"int64"
)),
(
pt
.
lscalar
(),
np
.
array
(
1
,
dtype
=
"int64"
)),
None
,
None
,
None
,
None
,
),
),
(
(
set_test_value
(
pt
.
lvector
(),
np
.
arange
(
2
,
dtype
=
"int64"
)),
(
pt
.
lvector
(),
np
.
arange
(
2
,
dtype
=
"int64"
)),
set_test_value
(
pt
.
lvector
(),
np
.
array
([
1
,
1
],
dtype
=
"int64"
)),
(
pt
.
lvector
(),
np
.
array
([
1
,
1
],
dtype
=
"int64"
)),
None
,
None
,
None
,
None
,
),
),
(
(
set_test_value
(
pt
.
lmatrix
(),
np
.
zeros
((
2
,
2
),
dtype
=
"int64"
)),
(
pt
.
lmatrix
(),
np
.
zeros
((
2
,
2
),
dtype
=
"int64"
)),
set_test_value
(
pt
.
lscalar
(),
np
.
array
(
1
,
dtype
=
"int64"
)),
(
pt
.
lscalar
(),
np
.
array
(
1
,
dtype
=
"int64"
)),
0
,
0
,
UserWarning
,
UserWarning
,
),
),
],
],
)
)
def
test_Repeat
(
x
,
repeats
,
axis
,
exc
):
def
test_Repeat
(
x
,
repeats
,
axis
,
exc
):
x
,
test_x
=
x
repeats
,
test_repeats
=
repeats
g
=
extra_ops
.
Repeat
(
axis
)(
x
,
repeats
)
g
=
extra_ops
.
Repeat
(
axis
)(
x
,
repeats
)
g_fg
=
FunctionGraph
(
outputs
=
[
g
])
cm
=
contextlib
.
suppress
()
if
exc
is
None
else
pytest
.
warns
(
exc
)
cm
=
contextlib
.
suppress
()
if
exc
is
None
else
pytest
.
warns
(
exc
)
with
cm
:
with
cm
:
compare_numba_and_py
(
compare_numba_and_py
(
g_fg
,
[
x
,
repeats
],
[
g
,
i
.
tag
.
test_value
[
test_x
,
test_repeats
],
for
i
in
g_fg
.
inputs
if
not
isinstance
(
i
,
SharedVariable
|
Constant
)
],
)
)
...
@@ -302,7 +256,7 @@ def test_Repeat(x, repeats, axis, exc):
...
@@ -302,7 +256,7 @@ def test_Repeat(x, repeats, axis, exc):
"x, axis, return_index, return_inverse, return_counts, exc"
,
"x, axis, return_index, return_inverse, return_counts, exc"
,
[
[
(
(
set_test_value
(
pt
.
lscalar
(),
np
.
array
(
1
,
dtype
=
"int64"
)),
(
pt
.
lscalar
(),
np
.
array
(
1
,
dtype
=
"int64"
)),
None
,
None
,
False
,
False
,
False
,
False
,
...
@@ -310,7 +264,7 @@ def test_Repeat(x, repeats, axis, exc):
...
@@ -310,7 +264,7 @@ def test_Repeat(x, repeats, axis, exc):
None
,
None
,
),
),
(
(
set_test_value
(
pt
.
lvector
(),
np
.
array
([
1
,
1
,
2
],
dtype
=
"int64"
)),
(
pt
.
lvector
(),
np
.
array
([
1
,
1
,
2
],
dtype
=
"int64"
)),
None
,
None
,
False
,
False
,
False
,
False
,
...
@@ -318,7 +272,7 @@ def test_Repeat(x, repeats, axis, exc):
...
@@ -318,7 +272,7 @@ def test_Repeat(x, repeats, axis, exc):
None
,
None
,
),
),
(
(
set_test_value
(
pt
.
lmatrix
(),
np
.
array
([[
1
,
1
],
[
2
,
2
]],
dtype
=
"int64"
)),
(
pt
.
lmatrix
(),
np
.
array
([[
1
,
1
],
[
2
,
2
]],
dtype
=
"int64"
)),
None
,
None
,
False
,
False
,
False
,
False
,
...
@@ -326,9 +280,7 @@ def test_Repeat(x, repeats, axis, exc):
...
@@ -326,9 +280,7 @@ def test_Repeat(x, repeats, axis, exc):
None
,
None
,
),
),
(
(
set_test_value
(
(
pt
.
lmatrix
(),
np
.
array
([[
1
,
1
],
[
1
,
1
],
[
2
,
2
]],
dtype
=
"int64"
)),
pt
.
lmatrix
(),
np
.
array
([[
1
,
1
],
[
1
,
1
],
[
2
,
2
]],
dtype
=
"int64"
)
),
0
,
0
,
False
,
False
,
False
,
False
,
...
@@ -336,9 +288,7 @@ def test_Repeat(x, repeats, axis, exc):
...
@@ -336,9 +288,7 @@ def test_Repeat(x, repeats, axis, exc):
UserWarning
,
UserWarning
,
),
),
(
(
set_test_value
(
(
pt
.
lmatrix
(),
np
.
array
([[
1
,
1
],
[
1
,
1
],
[
2
,
2
]],
dtype
=
"int64"
)),
pt
.
lmatrix
(),
np
.
array
([[
1
,
1
],
[
1
,
1
],
[
2
,
2
]],
dtype
=
"int64"
)
),
0
,
0
,
True
,
True
,
True
,
True
,
...
@@ -348,22 +298,15 @@ def test_Repeat(x, repeats, axis, exc):
...
@@ -348,22 +298,15 @@ def test_Repeat(x, repeats, axis, exc):
],
],
)
)
def
test_Unique
(
x
,
axis
,
return_index
,
return_inverse
,
return_counts
,
exc
):
def
test_Unique
(
x
,
axis
,
return_index
,
return_inverse
,
return_counts
,
exc
):
x
,
test_x
=
x
g
=
extra_ops
.
Unique
(
return_index
,
return_inverse
,
return_counts
,
axis
)(
x
)
g
=
extra_ops
.
Unique
(
return_index
,
return_inverse
,
return_counts
,
axis
)(
x
)
if
isinstance
(
g
,
list
):
g_fg
=
FunctionGraph
(
outputs
=
g
)
else
:
g_fg
=
FunctionGraph
(
outputs
=
[
g
])
cm
=
contextlib
.
suppress
()
if
exc
is
None
else
pytest
.
warns
(
exc
)
cm
=
contextlib
.
suppress
()
if
exc
is
None
else
pytest
.
warns
(
exc
)
with
cm
:
with
cm
:
compare_numba_and_py
(
compare_numba_and_py
(
g_fg
,
[
x
],
[
g
,
i
.
tag
.
test_value
[
test_x
],
for
i
in
g_fg
.
inputs
if
not
isinstance
(
i
,
SharedVariable
|
Constant
)
],
)
)
...
@@ -371,19 +314,19 @@ def test_Unique(x, axis, return_index, return_inverse, return_counts, exc):
...
@@ -371,19 +314,19 @@ def test_Unique(x, axis, return_index, return_inverse, return_counts, exc):
"arr, shape, order, exc"
,
"arr, shape, order, exc"
,
[
[
(
(
set_test_value
(
pt
.
lvector
(),
np
.
array
([
9
,
15
,
1
],
dtype
=
"int64"
)),
(
pt
.
lvector
(),
np
.
array
([
9
,
15
,
1
],
dtype
=
"int64"
)),
pt
.
as_tensor
([
2
,
3
,
4
]),
pt
.
as_tensor
([
2
,
3
,
4
]),
"C"
,
"C"
,
None
,
None
,
),
),
(
(
set_test_value
(
pt
.
lvector
(),
np
.
array
([
1
,
0
],
dtype
=
"int64"
)),
(
pt
.
lvector
(),
np
.
array
([
1
,
0
],
dtype
=
"int64"
)),
pt
.
as_tensor
([
2
]),
pt
.
as_tensor
([
2
]),
"C"
,
"C"
,
None
,
None
,
),
),
(
(
set_test_value
(
pt
.
lvector
(),
np
.
array
([
9
,
15
,
1
],
dtype
=
"int64"
)),
(
pt
.
lvector
(),
np
.
array
([
9
,
15
,
1
],
dtype
=
"int64"
)),
pt
.
as_tensor
([
2
,
3
,
4
]),
pt
.
as_tensor
([
2
,
3
,
4
]),
"F"
,
"F"
,
NotImplementedError
,
NotImplementedError
,
...
@@ -391,22 +334,15 @@ def test_Unique(x, axis, return_index, return_inverse, return_counts, exc):
...
@@ -391,22 +334,15 @@ def test_Unique(x, axis, return_index, return_inverse, return_counts, exc):
],
],
)
)
def
test_UnravelIndex
(
arr
,
shape
,
order
,
exc
):
def
test_UnravelIndex
(
arr
,
shape
,
order
,
exc
):
arr
,
test_arr
=
arr
g
=
extra_ops
.
UnravelIndex
(
order
)(
arr
,
shape
)
g
=
extra_ops
.
UnravelIndex
(
order
)(
arr
,
shape
)
if
isinstance
(
g
,
list
):
g_fg
=
FunctionGraph
(
outputs
=
g
)
else
:
g_fg
=
FunctionGraph
(
outputs
=
[
g
])
cm
=
contextlib
.
suppress
()
if
exc
is
None
else
pytest
.
raises
(
exc
)
cm
=
contextlib
.
suppress
()
if
exc
is
None
else
pytest
.
raises
(
exc
)
with
cm
:
with
cm
:
compare_numba_and_py
(
compare_numba_and_py
(
g_fg
,
[
arr
],
[
g
,
i
.
tag
.
test_value
[
test_arr
],
for
i
in
g_fg
.
inputs
if
not
isinstance
(
i
,
SharedVariable
|
Constant
)
],
)
)
...
@@ -414,18 +350,18 @@ def test_UnravelIndex(arr, shape, order, exc):
...
@@ -414,18 +350,18 @@ def test_UnravelIndex(arr, shape, order, exc):
"a, v, side, sorter, exc"
,
"a, v, side, sorter, exc"
,
[
[
(
(
set_test_value
(
pt
.
vector
(),
np
.
array
([
1.0
,
2.0
,
3.0
],
dtype
=
config
.
floatX
)),
(
pt
.
vector
(),
np
.
array
([
1.0
,
2.0
,
3.0
],
dtype
=
config
.
floatX
)),
set_test_value
(
pt
.
matrix
(),
rng
.
random
((
3
,
2
))
.
astype
(
config
.
floatX
)),
(
pt
.
matrix
(),
rng
.
random
((
3
,
2
))
.
astype
(
config
.
floatX
)),
"left"
,
"left"
,
None
,
None
,
None
,
None
,
),
),
pytest
.
param
(
pytest
.
param
(
set_test_value
(
(
pt
.
vector
(),
pt
.
vector
(),
np
.
array
([
0.29769574
,
0.71649186
,
0.20475563
])
.
astype
(
config
.
floatX
),
np
.
array
([
0.29769574
,
0.71649186
,
0.20475563
])
.
astype
(
config
.
floatX
),
),
),
set_test_value
(
(
pt
.
matrix
(),
pt
.
matrix
(),
np
.
array
(
np
.
array
(
[
[
...
@@ -440,25 +376,26 @@ def test_UnravelIndex(arr, shape, order, exc):
...
@@ -440,25 +376,26 @@ def test_UnravelIndex(arr, shape, order, exc):
None
,
None
,
),
),
(
(
set_test_value
(
pt
.
vector
(),
np
.
array
([
1.0
,
2.0
,
3.0
],
dtype
=
config
.
floatX
)),
(
pt
.
vector
(),
np
.
array
([
1.0
,
2.0
,
3.0
],
dtype
=
config
.
floatX
)),
set_test_value
(
pt
.
matrix
(),
rng
.
random
((
3
,
2
))
.
astype
(
config
.
floatX
)),
(
pt
.
matrix
(),
rng
.
random
((
3
,
2
))
.
astype
(
config
.
floatX
)),
"right"
,
"right"
,
set_test_value
(
pt
.
lvector
(),
np
.
array
([
0
,
2
,
1
])),
(
pt
.
lvector
(),
np
.
array
([
0
,
2
,
1
])),
UserWarning
,
UserWarning
,
),
),
],
],
)
)
def
test_Searchsorted
(
a
,
v
,
side
,
sorter
,
exc
):
def
test_Searchsorted
(
a
,
v
,
side
,
sorter
,
exc
):
a
,
test_a
=
a
v
,
test_v
=
v
if
sorter
is
not
None
:
sorter
,
test_sorter
=
sorter
g
=
extra_ops
.
SearchsortedOp
(
side
)(
a
,
v
,
sorter
)
g
=
extra_ops
.
SearchsortedOp
(
side
)(
a
,
v
,
sorter
)
g_fg
=
FunctionGraph
(
outputs
=
[
g
])
cm
=
contextlib
.
suppress
()
if
exc
is
None
else
pytest
.
warns
(
exc
)
cm
=
contextlib
.
suppress
()
if
exc
is
None
else
pytest
.
warns
(
exc
)
with
cm
:
with
cm
:
compare_numba_and_py
(
compare_numba_and_py
(
g_fg
,
[
a
,
v
]
if
sorter
is
None
else
[
a
,
v
,
sorter
],
[
g
,
i
.
tag
.
test_value
[
test_a
,
test_v
]
if
sorter
is
None
else
[
test_a
,
test_v
,
test_sorter
],
for
i
in
g_fg
.
inputs
if
not
isinstance
(
i
,
SharedVariable
|
Constant
)
],
)
)
tests/link/numba/test_nlinalg.py
浏览文件 @
cc8c4992
...
@@ -4,11 +4,8 @@ import numpy as np
...
@@ -4,11 +4,8 @@ import numpy as np
import
pytest
import
pytest
import
pytensor.tensor
as
pt
import
pytensor.tensor
as
pt
from
pytensor.compile.sharedvalue
import
SharedVariable
from
pytensor.graph.basic
import
Constant
from
pytensor.graph.fg
import
FunctionGraph
from
pytensor.tensor
import
nlinalg
from
pytensor.tensor
import
nlinalg
from
tests.link.numba.test_basic
import
compare_numba_and_py
,
set_test_value
from
tests.link.numba.test_basic
import
compare_numba_and_py
rng
=
np
.
random
.
default_rng
(
42849
)
rng
=
np
.
random
.
default_rng
(
42849
)
...
@@ -18,14 +15,14 @@ rng = np.random.default_rng(42849)
...
@@ -18,14 +15,14 @@ rng = np.random.default_rng(42849)
"x, exc"
,
"x, exc"
,
[
[
(
(
set_test_value
(
(
pt
.
dmatrix
(),
pt
.
dmatrix
(),
(
lambda
x
:
x
.
T
.
dot
(
x
))(
rng
.
random
(
size
=
(
3
,
3
))
.
astype
(
"float64"
)),
(
lambda
x
:
x
.
T
.
dot
(
x
))(
rng
.
random
(
size
=
(
3
,
3
))
.
astype
(
"float64"
)),
),
),
None
,
None
,
),
),
(
(
set_test_value
(
(
pt
.
lmatrix
(),
pt
.
lmatrix
(),
(
lambda
x
:
x
.
T
.
dot
(
x
))(
rng
.
poisson
(
size
=
(
3
,
3
))
.
astype
(
"int64"
)),
(
lambda
x
:
x
.
T
.
dot
(
x
))(
rng
.
poisson
(
size
=
(
3
,
3
))
.
astype
(
"int64"
)),
),
),
...
@@ -34,18 +31,15 @@ rng = np.random.default_rng(42849)
...
@@ -34,18 +31,15 @@ rng = np.random.default_rng(42849)
],
],
)
)
def
test_Det
(
x
,
exc
):
def
test_Det
(
x
,
exc
):
x
,
test_x
=
x
g
=
nlinalg
.
Det
()(
x
)
g
=
nlinalg
.
Det
()(
x
)
g_fg
=
FunctionGraph
(
outputs
=
[
g
])
cm
=
contextlib
.
suppress
()
if
exc
is
None
else
pytest
.
warns
(
exc
)
cm
=
contextlib
.
suppress
()
if
exc
is
None
else
pytest
.
warns
(
exc
)
with
cm
:
with
cm
:
compare_numba_and_py
(
compare_numba_and_py
(
g_fg
,
[
x
],
[
g
,
i
.
tag
.
test_value
[
test_x
],
for
i
in
g_fg
.
inputs
if
not
isinstance
(
i
,
SharedVariable
|
Constant
)
],
)
)
...
@@ -53,14 +47,14 @@ def test_Det(x, exc):
...
@@ -53,14 +47,14 @@ def test_Det(x, exc):
"x, exc"
,
"x, exc"
,
[
[
(
(
set_test_value
(
(
pt
.
dmatrix
(),
pt
.
dmatrix
(),
(
lambda
x
:
x
.
T
.
dot
(
x
))(
rng
.
random
(
size
=
(
3
,
3
))
.
astype
(
"float64"
)),
(
lambda
x
:
x
.
T
.
dot
(
x
))(
rng
.
random
(
size
=
(
3
,
3
))
.
astype
(
"float64"
)),
),
),
None
,
None
,
),
),
(
(
set_test_value
(
(
pt
.
lmatrix
(),
pt
.
lmatrix
(),
(
lambda
x
:
x
.
T
.
dot
(
x
))(
rng
.
poisson
(
size
=
(
3
,
3
))
.
astype
(
"int64"
)),
(
lambda
x
:
x
.
T
.
dot
(
x
))(
rng
.
poisson
(
size
=
(
3
,
3
))
.
astype
(
"int64"
)),
),
),
...
@@ -69,18 +63,15 @@ def test_Det(x, exc):
...
@@ -69,18 +63,15 @@ def test_Det(x, exc):
],
],
)
)
def
test_SLogDet
(
x
,
exc
):
def
test_SLogDet
(
x
,
exc
):
x
,
test_x
=
x
g
=
nlinalg
.
SLogDet
()(
x
)
g
=
nlinalg
.
SLogDet
()(
x
)
g_fg
=
FunctionGraph
(
outputs
=
g
)
cm
=
contextlib
.
suppress
()
if
exc
is
None
else
pytest
.
warns
(
exc
)
cm
=
contextlib
.
suppress
()
if
exc
is
None
else
pytest
.
warns
(
exc
)
with
cm
:
with
cm
:
compare_numba_and_py
(
compare_numba_and_py
(
g_fg
,
[
x
],
[
g
,
i
.
tag
.
test_value
[
test_x
],
for
i
in
g_fg
.
inputs
if
not
isinstance
(
i
,
SharedVariable
|
Constant
)
],
)
)
...
@@ -112,21 +103,21 @@ y = np.array(
...
@@ -112,21 +103,21 @@ y = np.array(
"x, exc"
,
"x, exc"
,
[
[
(
(
set_test_value
(
(
pt
.
dmatrix
(),
pt
.
dmatrix
(),
(
lambda
x
:
x
.
T
.
dot
(
x
))(
x
),
(
lambda
x
:
x
.
T
.
dot
(
x
))(
x
),
),
),
None
,
None
,
),
),
(
(
set_test_value
(
(
pt
.
dmatrix
(),
pt
.
dmatrix
(),
(
lambda
x
:
x
.
T
.
dot
(
x
))(
y
),
(
lambda
x
:
x
.
T
.
dot
(
x
))(
y
),
),
),
None
,
None
,
),
),
(
(
set_test_value
(
(
pt
.
lmatrix
(),
pt
.
lmatrix
(),
(
lambda
x
:
x
.
T
.
dot
(
x
))(
(
lambda
x
:
x
.
T
.
dot
(
x
))(
rng
.
integers
(
1
,
10
,
size
=
(
3
,
3
))
.
astype
(
"int64"
)
rng
.
integers
(
1
,
10
,
size
=
(
3
,
3
))
.
astype
(
"int64"
)
...
@@ -137,22 +128,15 @@ y = np.array(
...
@@ -137,22 +128,15 @@ y = np.array(
],
],
)
)
def
test_Eig
(
x
,
exc
):
def
test_Eig
(
x
,
exc
):
x
,
test_x
=
x
g
=
nlinalg
.
Eig
()(
x
)
g
=
nlinalg
.
Eig
()(
x
)
if
isinstance
(
g
,
list
):
g_fg
=
FunctionGraph
(
outputs
=
g
)
else
:
g_fg
=
FunctionGraph
(
outputs
=
[
g
])
cm
=
contextlib
.
suppress
()
if
exc
is
None
else
pytest
.
warns
(
exc
)
cm
=
contextlib
.
suppress
()
if
exc
is
None
else
pytest
.
warns
(
exc
)
with
cm
:
with
cm
:
compare_numba_and_py
(
compare_numba_and_py
(
g_fg
,
[
x
],
[
g
,
i
.
tag
.
test_value
[
test_x
],
for
i
in
g_fg
.
inputs
if
not
isinstance
(
i
,
SharedVariable
|
Constant
)
],
)
)
...
@@ -160,7 +144,7 @@ def test_Eig(x, exc):
...
@@ -160,7 +144,7 @@ def test_Eig(x, exc):
"x, uplo, exc"
,
"x, uplo, exc"
,
[
[
(
(
set_test_value
(
(
pt
.
dmatrix
(),
pt
.
dmatrix
(),
(
lambda
x
:
x
.
T
.
dot
(
x
))(
rng
.
random
(
size
=
(
3
,
3
))
.
astype
(
"float64"
)),
(
lambda
x
:
x
.
T
.
dot
(
x
))(
rng
.
random
(
size
=
(
3
,
3
))
.
astype
(
"float64"
)),
),
),
...
@@ -168,7 +152,7 @@ def test_Eig(x, exc):
...
@@ -168,7 +152,7 @@ def test_Eig(x, exc):
None
,
None
,
),
),
(
(
set_test_value
(
(
pt
.
lmatrix
(),
pt
.
lmatrix
(),
(
lambda
x
:
x
.
T
.
dot
(
x
))(
(
lambda
x
:
x
.
T
.
dot
(
x
))(
rng
.
integers
(
1
,
10
,
size
=
(
3
,
3
))
.
astype
(
"int64"
)
rng
.
integers
(
1
,
10
,
size
=
(
3
,
3
))
.
astype
(
"int64"
)
...
@@ -180,22 +164,15 @@ def test_Eig(x, exc):
...
@@ -180,22 +164,15 @@ def test_Eig(x, exc):
],
],
)
)
def
test_Eigh
(
x
,
uplo
,
exc
):
def
test_Eigh
(
x
,
uplo
,
exc
):
x
,
test_x
=
x
g
=
nlinalg
.
Eigh
(
uplo
)(
x
)
g
=
nlinalg
.
Eigh
(
uplo
)(
x
)
if
isinstance
(
g
,
list
):
g_fg
=
FunctionGraph
(
outputs
=
g
)
else
:
g_fg
=
FunctionGraph
(
outputs
=
[
g
])
cm
=
contextlib
.
suppress
()
if
exc
is
None
else
pytest
.
warns
(
exc
)
cm
=
contextlib
.
suppress
()
if
exc
is
None
else
pytest
.
warns
(
exc
)
with
cm
:
with
cm
:
compare_numba_and_py
(
compare_numba_and_py
(
g_fg
,
[
x
],
[
g
,
i
.
tag
.
test_value
[
test_x
],
for
i
in
g_fg
.
inputs
if
not
isinstance
(
i
,
SharedVariable
|
Constant
)
],
)
)
...
@@ -204,7 +181,7 @@ def test_Eigh(x, uplo, exc):
...
@@ -204,7 +181,7 @@ def test_Eigh(x, uplo, exc):
[
[
(
(
nlinalg
.
MatrixInverse
,
nlinalg
.
MatrixInverse
,
set_test_value
(
(
pt
.
dmatrix
(),
pt
.
dmatrix
(),
(
lambda
x
:
x
.
T
.
dot
(
x
))(
rng
.
random
(
size
=
(
3
,
3
))
.
astype
(
"float64"
)),
(
lambda
x
:
x
.
T
.
dot
(
x
))(
rng
.
random
(
size
=
(
3
,
3
))
.
astype
(
"float64"
)),
),
),
...
@@ -213,7 +190,7 @@ def test_Eigh(x, uplo, exc):
...
@@ -213,7 +190,7 @@ def test_Eigh(x, uplo, exc):
),
),
(
(
nlinalg
.
MatrixInverse
,
nlinalg
.
MatrixInverse
,
set_test_value
(
(
pt
.
lmatrix
(),
pt
.
lmatrix
(),
(
lambda
x
:
x
.
T
.
dot
(
x
))(
(
lambda
x
:
x
.
T
.
dot
(
x
))(
rng
.
integers
(
1
,
10
,
size
=
(
3
,
3
))
.
astype
(
"int64"
)
rng
.
integers
(
1
,
10
,
size
=
(
3
,
3
))
.
astype
(
"int64"
)
...
@@ -224,7 +201,7 @@ def test_Eigh(x, uplo, exc):
...
@@ -224,7 +201,7 @@ def test_Eigh(x, uplo, exc):
),
),
(
(
nlinalg
.
MatrixPinv
,
nlinalg
.
MatrixPinv
,
set_test_value
(
(
pt
.
dmatrix
(),
pt
.
dmatrix
(),
(
lambda
x
:
x
.
T
.
dot
(
x
))(
rng
.
random
(
size
=
(
3
,
3
))
.
astype
(
"float64"
)),
(
lambda
x
:
x
.
T
.
dot
(
x
))(
rng
.
random
(
size
=
(
3
,
3
))
.
astype
(
"float64"
)),
),
),
...
@@ -233,7 +210,7 @@ def test_Eigh(x, uplo, exc):
...
@@ -233,7 +210,7 @@ def test_Eigh(x, uplo, exc):
),
),
(
(
nlinalg
.
MatrixPinv
,
nlinalg
.
MatrixPinv
,
set_test_value
(
(
pt
.
lmatrix
(),
pt
.
lmatrix
(),
(
lambda
x
:
x
.
T
.
dot
(
x
))(
(
lambda
x
:
x
.
T
.
dot
(
x
))(
rng
.
integers
(
1
,
10
,
size
=
(
3
,
3
))
.
astype
(
"int64"
)
rng
.
integers
(
1
,
10
,
size
=
(
3
,
3
))
.
astype
(
"int64"
)
...
@@ -245,18 +222,15 @@ def test_Eigh(x, uplo, exc):
...
@@ -245,18 +222,15 @@ def test_Eigh(x, uplo, exc):
],
],
)
)
def
test_matrix_inverses
(
op
,
x
,
exc
,
op_args
):
def
test_matrix_inverses
(
op
,
x
,
exc
,
op_args
):
x
,
test_x
=
x
g
=
op
(
*
op_args
)(
x
)
g
=
op
(
*
op_args
)(
x
)
g_fg
=
FunctionGraph
(
outputs
=
[
g
])
cm
=
contextlib
.
suppress
()
if
exc
is
None
else
pytest
.
warns
(
exc
)
cm
=
contextlib
.
suppress
()
if
exc
is
None
else
pytest
.
warns
(
exc
)
with
cm
:
with
cm
:
compare_numba_and_py
(
compare_numba_and_py
(
g_fg
,
[
x
],
[
g
,
i
.
tag
.
test_value
[
test_x
],
for
i
in
g_fg
.
inputs
if
not
isinstance
(
i
,
SharedVariable
|
Constant
)
],
)
)
...
@@ -264,7 +238,7 @@ def test_matrix_inverses(op, x, exc, op_args):
...
@@ -264,7 +238,7 @@ def test_matrix_inverses(op, x, exc, op_args):
"x, mode, exc"
,
"x, mode, exc"
,
[
[
(
(
set_test_value
(
(
pt
.
dmatrix
(),
pt
.
dmatrix
(),
(
lambda
x
:
x
.
T
.
dot
(
x
))(
rng
.
random
(
size
=
(
3
,
3
))
.
astype
(
"float64"
)),
(
lambda
x
:
x
.
T
.
dot
(
x
))(
rng
.
random
(
size
=
(
3
,
3
))
.
astype
(
"float64"
)),
),
),
...
@@ -272,7 +246,7 @@ def test_matrix_inverses(op, x, exc, op_args):
...
@@ -272,7 +246,7 @@ def test_matrix_inverses(op, x, exc, op_args):
None
,
None
,
),
),
(
(
set_test_value
(
(
pt
.
dmatrix
(),
pt
.
dmatrix
(),
(
lambda
x
:
x
.
T
.
dot
(
x
))(
rng
.
random
(
size
=
(
3
,
3
))
.
astype
(
"float64"
)),
(
lambda
x
:
x
.
T
.
dot
(
x
))(
rng
.
random
(
size
=
(
3
,
3
))
.
astype
(
"float64"
)),
),
),
...
@@ -280,7 +254,7 @@ def test_matrix_inverses(op, x, exc, op_args):
...
@@ -280,7 +254,7 @@ def test_matrix_inverses(op, x, exc, op_args):
None
,
None
,
),
),
(
(
set_test_value
(
(
pt
.
lmatrix
(),
pt
.
lmatrix
(),
(
lambda
x
:
x
.
T
.
dot
(
x
))(
(
lambda
x
:
x
.
T
.
dot
(
x
))(
rng
.
integers
(
1
,
10
,
size
=
(
3
,
3
))
.
astype
(
"int64"
)
rng
.
integers
(
1
,
10
,
size
=
(
3
,
3
))
.
astype
(
"int64"
)
...
@@ -290,7 +264,7 @@ def test_matrix_inverses(op, x, exc, op_args):
...
@@ -290,7 +264,7 @@ def test_matrix_inverses(op, x, exc, op_args):
None
,
None
,
),
),
(
(
set_test_value
(
(
pt
.
lmatrix
(),
pt
.
lmatrix
(),
(
lambda
x
:
x
.
T
.
dot
(
x
))(
(
lambda
x
:
x
.
T
.
dot
(
x
))(
rng
.
integers
(
1
,
10
,
size
=
(
3
,
3
))
.
astype
(
"int64"
)
rng
.
integers
(
1
,
10
,
size
=
(
3
,
3
))
.
astype
(
"int64"
)
...
@@ -302,22 +276,15 @@ def test_matrix_inverses(op, x, exc, op_args):
...
@@ -302,22 +276,15 @@ def test_matrix_inverses(op, x, exc, op_args):
],
],
)
)
def
test_QRFull
(
x
,
mode
,
exc
):
def
test_QRFull
(
x
,
mode
,
exc
):
x
,
test_x
=
x
g
=
nlinalg
.
QRFull
(
mode
)(
x
)
g
=
nlinalg
.
QRFull
(
mode
)(
x
)
if
isinstance
(
g
,
list
):
g_fg
=
FunctionGraph
(
outputs
=
g
)
else
:
g_fg
=
FunctionGraph
(
outputs
=
[
g
])
cm
=
contextlib
.
suppress
()
if
exc
is
None
else
pytest
.
warns
(
exc
)
cm
=
contextlib
.
suppress
()
if
exc
is
None
else
pytest
.
warns
(
exc
)
with
cm
:
with
cm
:
compare_numba_and_py
(
compare_numba_and_py
(
g_fg
,
[
x
],
[
g
,
i
.
tag
.
test_value
[
test_x
],
for
i
in
g_fg
.
inputs
if
not
isinstance
(
i
,
SharedVariable
|
Constant
)
],
)
)
...
@@ -325,7 +292,7 @@ def test_QRFull(x, mode, exc):
...
@@ -325,7 +292,7 @@ def test_QRFull(x, mode, exc):
"x, full_matrices, compute_uv, exc"
,
"x, full_matrices, compute_uv, exc"
,
[
[
(
(
set_test_value
(
(
pt
.
dmatrix
(),
pt
.
dmatrix
(),
(
lambda
x
:
x
.
T
.
dot
(
x
))(
rng
.
random
(
size
=
(
3
,
3
))
.
astype
(
"float64"
)),
(
lambda
x
:
x
.
T
.
dot
(
x
))(
rng
.
random
(
size
=
(
3
,
3
))
.
astype
(
"float64"
)),
),
),
...
@@ -334,7 +301,7 @@ def test_QRFull(x, mode, exc):
...
@@ -334,7 +301,7 @@ def test_QRFull(x, mode, exc):
None
,
None
,
),
),
(
(
set_test_value
(
(
pt
.
dmatrix
(),
pt
.
dmatrix
(),
(
lambda
x
:
x
.
T
.
dot
(
x
))(
rng
.
random
(
size
=
(
3
,
3
))
.
astype
(
"float64"
)),
(
lambda
x
:
x
.
T
.
dot
(
x
))(
rng
.
random
(
size
=
(
3
,
3
))
.
astype
(
"float64"
)),
),
),
...
@@ -343,7 +310,7 @@ def test_QRFull(x, mode, exc):
...
@@ -343,7 +310,7 @@ def test_QRFull(x, mode, exc):
None
,
None
,
),
),
(
(
set_test_value
(
(
pt
.
lmatrix
(),
pt
.
lmatrix
(),
(
lambda
x
:
x
.
T
.
dot
(
x
))(
(
lambda
x
:
x
.
T
.
dot
(
x
))(
rng
.
integers
(
1
,
10
,
size
=
(
3
,
3
))
.
astype
(
"int64"
)
rng
.
integers
(
1
,
10
,
size
=
(
3
,
3
))
.
astype
(
"int64"
)
...
@@ -354,7 +321,7 @@ def test_QRFull(x, mode, exc):
...
@@ -354,7 +321,7 @@ def test_QRFull(x, mode, exc):
None
,
None
,
),
),
(
(
set_test_value
(
(
pt
.
lmatrix
(),
pt
.
lmatrix
(),
(
lambda
x
:
x
.
T
.
dot
(
x
))(
(
lambda
x
:
x
.
T
.
dot
(
x
))(
rng
.
integers
(
1
,
10
,
size
=
(
3
,
3
))
.
astype
(
"int64"
)
rng
.
integers
(
1
,
10
,
size
=
(
3
,
3
))
.
astype
(
"int64"
)
...
@@ -367,20 +334,13 @@ def test_QRFull(x, mode, exc):
...
@@ -367,20 +334,13 @@ def test_QRFull(x, mode, exc):
],
],
)
)
def
test_SVD
(
x
,
full_matrices
,
compute_uv
,
exc
):
def
test_SVD
(
x
,
full_matrices
,
compute_uv
,
exc
):
x
,
test_x
=
x
g
=
nlinalg
.
SVD
(
full_matrices
,
compute_uv
)(
x
)
g
=
nlinalg
.
SVD
(
full_matrices
,
compute_uv
)(
x
)
if
isinstance
(
g
,
list
):
g_fg
=
FunctionGraph
(
outputs
=
g
)
else
:
g_fg
=
FunctionGraph
(
outputs
=
[
g
])
cm
=
contextlib
.
suppress
()
if
exc
is
None
else
pytest
.
warns
(
exc
)
cm
=
contextlib
.
suppress
()
if
exc
is
None
else
pytest
.
warns
(
exc
)
with
cm
:
with
cm
:
compare_numba_and_py
(
compare_numba_and_py
(
g_fg
,
[
x
],
[
g
,
i
.
tag
.
test_value
[
test_x
],
for
i
in
g_fg
.
inputs
if
not
isinstance
(
i
,
SharedVariable
|
Constant
)
],
)
)
tests/link/numba/test_pad.py
浏览文件 @
cc8c4992
...
@@ -3,7 +3,6 @@ import pytest
...
@@ -3,7 +3,6 @@ import pytest
import
pytensor.tensor
as
pt
import
pytensor.tensor
as
pt
from
pytensor
import
config
from
pytensor
import
config
from
pytensor.graph
import
FunctionGraph
from
pytensor.tensor.pad
import
PadMode
from
pytensor.tensor.pad
import
PadMode
from
tests.link.numba.test_basic
import
compare_numba_and_py
from
tests.link.numba.test_basic
import
compare_numba_and_py
...
@@ -58,10 +57,10 @@ def test_numba_pad(mode: PadMode, kwargs):
...
@@ -58,10 +57,10 @@ def test_numba_pad(mode: PadMode, kwargs):
x
=
np
.
random
.
normal
(
size
=
(
3
,
3
))
x
=
np
.
random
.
normal
(
size
=
(
3
,
3
))
res
=
pt
.
pad
(
x_pt
,
mode
=
mode
,
pad_width
=
3
,
**
kwargs
)
res
=
pt
.
pad
(
x_pt
,
mode
=
mode
,
pad_width
=
3
,
**
kwargs
)
res_fg
=
FunctionGraph
([
x_pt
],
[
res
])
compare_numba_and_py
(
compare_numba_and_py
(
res_fg
,
[
x_pt
],
[
res
],
[
x
],
[
x
],
assert_fn
=
lambda
x
,
y
:
np
.
testing
.
assert_allclose
(
x
,
y
,
rtol
=
RTOL
,
atol
=
ATOL
),
assert_fn
=
lambda
x
,
y
:
np
.
testing
.
assert_allclose
(
x
,
y
,
rtol
=
RTOL
,
atol
=
ATOL
),
py_mode
=
"FAST_RUN"
,
py_mode
=
"FAST_RUN"
,
...
...
tests/link/numba/test_random.py
浏览文件 @
cc8c4992
...
@@ -10,13 +10,9 @@ import pytensor.tensor.random.basic as ptr
...
@@ -10,13 +10,9 @@ import pytensor.tensor.random.basic as ptr
from
pytensor
import
shared
from
pytensor
import
shared
from
pytensor.compile.builders
import
OpFromGraph
from
pytensor.compile.builders
import
OpFromGraph
from
pytensor.compile.function
import
function
from
pytensor.compile.function
import
function
from
pytensor.compile.sharedvalue
import
SharedVariable
from
pytensor.graph.basic
import
Constant
from
pytensor.graph.fg
import
FunctionGraph
from
tests.link.numba.test_basic
import
(
from
tests.link.numba.test_basic
import
(
compare_numba_and_py
,
compare_numba_and_py
,
numba_mode
,
numba_mode
,
set_test_value
,
)
)
from
tests.tensor.random.test_basic
import
(
from
tests.tensor.random.test_basic
import
(
batched_permutation_tester
,
batched_permutation_tester
,
...
@@ -159,11 +155,11 @@ test_mvnormal_cov_decomposition_method = create_mvnormal_cov_decomposition_metho
...
@@ -159,11 +155,11 @@ test_mvnormal_cov_decomposition_method = create_mvnormal_cov_decomposition_metho
(
(
ptr
.
uniform
,
ptr
.
uniform
,
[
[
set_test_value
(
(
pt
.
dscalar
(),
pt
.
dscalar
(),
np
.
array
(
1.0
,
dtype
=
np
.
float64
),
np
.
array
(
1.0
,
dtype
=
np
.
float64
),
),
),
set_test_value
(
(
pt
.
dvector
(),
pt
.
dvector
(),
np
.
array
([
1.0
,
2.0
],
dtype
=
np
.
float64
),
np
.
array
([
1.0
,
2.0
],
dtype
=
np
.
float64
),
),
),
...
@@ -173,15 +169,15 @@ test_mvnormal_cov_decomposition_method = create_mvnormal_cov_decomposition_metho
...
@@ -173,15 +169,15 @@ test_mvnormal_cov_decomposition_method = create_mvnormal_cov_decomposition_metho
(
(
ptr
.
triangular
,
ptr
.
triangular
,
[
[
set_test_value
(
(
pt
.
dscalar
(),
pt
.
dscalar
(),
np
.
array
(
-
5.0
,
dtype
=
np
.
float64
),
np
.
array
(
-
5.0
,
dtype
=
np
.
float64
),
),
),
set_test_value
(
(
pt
.
dscalar
(),
pt
.
dscalar
(),
np
.
array
(
1.0
,
dtype
=
np
.
float64
),
np
.
array
(
1.0
,
dtype
=
np
.
float64
),
),
),
set_test_value
(
(
pt
.
dscalar
(),
pt
.
dscalar
(),
np
.
array
(
5.0
,
dtype
=
np
.
float64
),
np
.
array
(
5.0
,
dtype
=
np
.
float64
),
),
),
...
@@ -191,11 +187,11 @@ test_mvnormal_cov_decomposition_method = create_mvnormal_cov_decomposition_metho
...
@@ -191,11 +187,11 @@ test_mvnormal_cov_decomposition_method = create_mvnormal_cov_decomposition_metho
(
(
ptr
.
lognormal
,
ptr
.
lognormal
,
[
[
set_test_value
(
(
pt
.
dvector
(),
pt
.
dvector
(),
np
.
array
([
1.0
,
2.0
],
dtype
=
np
.
float64
),
np
.
array
([
1.0
,
2.0
],
dtype
=
np
.
float64
),
),
),
set_test_value
(
(
pt
.
dscalar
(),
pt
.
dscalar
(),
np
.
array
(
1.0
,
dtype
=
np
.
float64
),
np
.
array
(
1.0
,
dtype
=
np
.
float64
),
),
),
...
@@ -205,11 +201,11 @@ test_mvnormal_cov_decomposition_method = create_mvnormal_cov_decomposition_metho
...
@@ -205,11 +201,11 @@ test_mvnormal_cov_decomposition_method = create_mvnormal_cov_decomposition_metho
(
(
ptr
.
pareto
,
ptr
.
pareto
,
[
[
set_test_value
(
(
pt
.
dvector
(),
pt
.
dvector
(),
np
.
array
([
1.0
,
2.0
],
dtype
=
np
.
float64
),
np
.
array
([
1.0
,
2.0
],
dtype
=
np
.
float64
),
),
),
set_test_value
(
(
pt
.
dvector
(),
pt
.
dvector
(),
np
.
array
([
2.0
,
10.0
],
dtype
=
np
.
float64
),
np
.
array
([
2.0
,
10.0
],
dtype
=
np
.
float64
),
),
),
...
@@ -219,7 +215,7 @@ test_mvnormal_cov_decomposition_method = create_mvnormal_cov_decomposition_metho
...
@@ -219,7 +215,7 @@ test_mvnormal_cov_decomposition_method = create_mvnormal_cov_decomposition_metho
(
(
ptr
.
exponential
,
ptr
.
exponential
,
[
[
set_test_value
(
(
pt
.
dvector
(),
pt
.
dvector
(),
np
.
array
([
1.0
,
2.0
],
dtype
=
np
.
float64
),
np
.
array
([
1.0
,
2.0
],
dtype
=
np
.
float64
),
),
),
...
@@ -229,7 +225,7 @@ test_mvnormal_cov_decomposition_method = create_mvnormal_cov_decomposition_metho
...
@@ -229,7 +225,7 @@ test_mvnormal_cov_decomposition_method = create_mvnormal_cov_decomposition_metho
(
(
ptr
.
weibull
,
ptr
.
weibull
,
[
[
set_test_value
(
(
pt
.
dvector
(),
pt
.
dvector
(),
np
.
array
([
1.0
,
2.0
],
dtype
=
np
.
float64
),
np
.
array
([
1.0
,
2.0
],
dtype
=
np
.
float64
),
),
),
...
@@ -239,11 +235,11 @@ test_mvnormal_cov_decomposition_method = create_mvnormal_cov_decomposition_metho
...
@@ -239,11 +235,11 @@ test_mvnormal_cov_decomposition_method = create_mvnormal_cov_decomposition_metho
(
(
ptr
.
logistic
,
ptr
.
logistic
,
[
[
set_test_value
(
(
pt
.
dvector
(),
pt
.
dvector
(),
np
.
array
([
1.0
,
2.0
],
dtype
=
np
.
float64
),
np
.
array
([
1.0
,
2.0
],
dtype
=
np
.
float64
),
),
),
set_test_value
(
(
pt
.
dscalar
(),
pt
.
dscalar
(),
np
.
array
(
1.0
,
dtype
=
np
.
float64
),
np
.
array
(
1.0
,
dtype
=
np
.
float64
),
),
),
...
@@ -253,7 +249,7 @@ test_mvnormal_cov_decomposition_method = create_mvnormal_cov_decomposition_metho
...
@@ -253,7 +249,7 @@ test_mvnormal_cov_decomposition_method = create_mvnormal_cov_decomposition_metho
(
(
ptr
.
geometric
,
ptr
.
geometric
,
[
[
set_test_value
(
(
pt
.
dvector
(),
pt
.
dvector
(),
np
.
array
([
0.3
,
0.4
],
dtype
=
np
.
float64
),
np
.
array
([
0.3
,
0.4
],
dtype
=
np
.
float64
),
),
),
...
@@ -263,15 +259,15 @@ test_mvnormal_cov_decomposition_method = create_mvnormal_cov_decomposition_metho
...
@@ -263,15 +259,15 @@ test_mvnormal_cov_decomposition_method = create_mvnormal_cov_decomposition_metho
pytest
.
param
(
pytest
.
param
(
ptr
.
hypergeometric
,
ptr
.
hypergeometric
,
[
[
set_test_value
(
(
pt
.
lscalar
(),
pt
.
lscalar
(),
np
.
array
(
7
,
dtype
=
np
.
int64
),
np
.
array
(
7
,
dtype
=
np
.
int64
),
),
),
set_test_value
(
(
pt
.
lscalar
(),
pt
.
lscalar
(),
np
.
array
(
8
,
dtype
=
np
.
int64
),
np
.
array
(
8
,
dtype
=
np
.
int64
),
),
),
set_test_value
(
(
pt
.
lscalar
(),
pt
.
lscalar
(),
np
.
array
(
15
,
dtype
=
np
.
int64
),
np
.
array
(
15
,
dtype
=
np
.
int64
),
),
),
...
@@ -282,11 +278,11 @@ test_mvnormal_cov_decomposition_method = create_mvnormal_cov_decomposition_metho
...
@@ -282,11 +278,11 @@ test_mvnormal_cov_decomposition_method = create_mvnormal_cov_decomposition_metho
(
(
ptr
.
wald
,
ptr
.
wald
,
[
[
set_test_value
(
(
pt
.
dvector
(),
pt
.
dvector
(),
np
.
array
([
1.0
,
2.0
],
dtype
=
np
.
float64
),
np
.
array
([
1.0
,
2.0
],
dtype
=
np
.
float64
),
),
),
set_test_value
(
(
pt
.
dscalar
(),
pt
.
dscalar
(),
np
.
array
(
1.0
,
dtype
=
np
.
float64
),
np
.
array
(
1.0
,
dtype
=
np
.
float64
),
),
),
...
@@ -296,11 +292,11 @@ test_mvnormal_cov_decomposition_method = create_mvnormal_cov_decomposition_metho
...
@@ -296,11 +292,11 @@ test_mvnormal_cov_decomposition_method = create_mvnormal_cov_decomposition_metho
(
(
ptr
.
laplace
,
ptr
.
laplace
,
[
[
set_test_value
(
(
pt
.
dvector
(),
pt
.
dvector
(),
np
.
array
([
1.0
,
2.0
],
dtype
=
np
.
float64
),
np
.
array
([
1.0
,
2.0
],
dtype
=
np
.
float64
),
),
),
set_test_value
(
(
pt
.
dscalar
(),
pt
.
dscalar
(),
np
.
array
(
1.0
,
dtype
=
np
.
float64
),
np
.
array
(
1.0
,
dtype
=
np
.
float64
),
),
),
...
@@ -310,11 +306,11 @@ test_mvnormal_cov_decomposition_method = create_mvnormal_cov_decomposition_metho
...
@@ -310,11 +306,11 @@ test_mvnormal_cov_decomposition_method = create_mvnormal_cov_decomposition_metho
(
(
ptr
.
binomial
,
ptr
.
binomial
,
[
[
set_test_value
(
(
pt
.
lvector
(),
pt
.
lvector
(),
np
.
array
([
1
,
2
],
dtype
=
np
.
int64
),
np
.
array
([
1
,
2
],
dtype
=
np
.
int64
),
),
),
set_test_value
(
(
pt
.
dscalar
(),
pt
.
dscalar
(),
np
.
array
(
0.9
,
dtype
=
np
.
float64
),
np
.
array
(
0.9
,
dtype
=
np
.
float64
),
),
),
...
@@ -324,21 +320,21 @@ test_mvnormal_cov_decomposition_method = create_mvnormal_cov_decomposition_metho
...
@@ -324,21 +320,21 @@ test_mvnormal_cov_decomposition_method = create_mvnormal_cov_decomposition_metho
(
(
ptr
.
normal
,
ptr
.
normal
,
[
[
set_test_value
(
(
pt
.
lvector
(),
pt
.
lvector
(),
np
.
array
([
1
,
2
],
dtype
=
np
.
int64
),
np
.
array
([
1
,
2
],
dtype
=
np
.
int64
),
),
),
set_test_value
(
(
pt
.
dscalar
(),
pt
.
dscalar
(),
np
.
array
(
1.0
,
dtype
=
np
.
float64
),
np
.
array
(
1.0
,
dtype
=
np
.
float64
),
),
),
],
],
pt
.
as_tensor
(
tuple
(
set_test_value
(
pt
.
lscalar
(),
v
)
for
v
in
[
3
,
2
])
),
pt
.
as_tensor
(
[
3
,
2
]
),
),
),
(
(
ptr
.
poisson
,
ptr
.
poisson
,
[
[
set_test_value
(
(
pt
.
dvector
(),
pt
.
dvector
(),
np
.
array
([
1.0
,
2.0
],
dtype
=
np
.
float64
),
np
.
array
([
1.0
,
2.0
],
dtype
=
np
.
float64
),
),
),
...
@@ -348,11 +344,11 @@ test_mvnormal_cov_decomposition_method = create_mvnormal_cov_decomposition_metho
...
@@ -348,11 +344,11 @@ test_mvnormal_cov_decomposition_method = create_mvnormal_cov_decomposition_metho
(
(
ptr
.
halfnormal
,
ptr
.
halfnormal
,
[
[
set_test_value
(
(
pt
.
lvector
(),
pt
.
lvector
(),
np
.
array
([
1
,
2
],
dtype
=
np
.
int64
),
np
.
array
([
1
,
2
],
dtype
=
np
.
int64
),
),
),
set_test_value
(
(
pt
.
dscalar
(),
pt
.
dscalar
(),
np
.
array
(
1.0
,
dtype
=
np
.
float64
),
np
.
array
(
1.0
,
dtype
=
np
.
float64
),
),
),
...
@@ -362,7 +358,7 @@ test_mvnormal_cov_decomposition_method = create_mvnormal_cov_decomposition_metho
...
@@ -362,7 +358,7 @@ test_mvnormal_cov_decomposition_method = create_mvnormal_cov_decomposition_metho
(
(
ptr
.
bernoulli
,
ptr
.
bernoulli
,
[
[
set_test_value
(
(
pt
.
dvector
(),
pt
.
dvector
(),
np
.
array
([
0.1
,
0.9
],
dtype
=
np
.
float64
),
np
.
array
([
0.1
,
0.9
],
dtype
=
np
.
float64
),
),
),
...
@@ -372,11 +368,11 @@ test_mvnormal_cov_decomposition_method = create_mvnormal_cov_decomposition_metho
...
@@ -372,11 +368,11 @@ test_mvnormal_cov_decomposition_method = create_mvnormal_cov_decomposition_metho
(
(
ptr
.
beta
,
ptr
.
beta
,
[
[
set_test_value
(
(
pt
.
dvector
(),
pt
.
dvector
(),
np
.
array
([
1.0
,
2.0
],
dtype
=
np
.
float64
),
np
.
array
([
1.0
,
2.0
],
dtype
=
np
.
float64
),
),
),
set_test_value
(
(
pt
.
dscalar
(),
pt
.
dscalar
(),
np
.
array
(
1.0
,
dtype
=
np
.
float64
),
np
.
array
(
1.0
,
dtype
=
np
.
float64
),
),
),
...
@@ -386,11 +382,11 @@ test_mvnormal_cov_decomposition_method = create_mvnormal_cov_decomposition_metho
...
@@ -386,11 +382,11 @@ test_mvnormal_cov_decomposition_method = create_mvnormal_cov_decomposition_metho
(
(
ptr
.
_gamma
,
ptr
.
_gamma
,
[
[
set_test_value
(
(
pt
.
dvector
(),
pt
.
dvector
(),
np
.
array
([
1.0
,
2.0
],
dtype
=
np
.
float64
),
np
.
array
([
1.0
,
2.0
],
dtype
=
np
.
float64
),
),
),
set_test_value
(
(
pt
.
dvector
(),
pt
.
dvector
(),
np
.
array
([
0.5
,
3.0
],
dtype
=
np
.
float64
),
np
.
array
([
0.5
,
3.0
],
dtype
=
np
.
float64
),
),
),
...
@@ -400,7 +396,7 @@ test_mvnormal_cov_decomposition_method = create_mvnormal_cov_decomposition_metho
...
@@ -400,7 +396,7 @@ test_mvnormal_cov_decomposition_method = create_mvnormal_cov_decomposition_metho
(
(
ptr
.
chisquare
,
ptr
.
chisquare
,
[
[
set_test_value
(
(
pt
.
dvector
(),
pt
.
dvector
(),
np
.
array
([
1.0
,
2.0
],
dtype
=
np
.
float64
),
np
.
array
([
1.0
,
2.0
],
dtype
=
np
.
float64
),
)
)
...
@@ -410,11 +406,11 @@ test_mvnormal_cov_decomposition_method = create_mvnormal_cov_decomposition_metho
...
@@ -410,11 +406,11 @@ test_mvnormal_cov_decomposition_method = create_mvnormal_cov_decomposition_metho
(
(
ptr
.
negative_binomial
,
ptr
.
negative_binomial
,
[
[
set_test_value
(
(
pt
.
lvector
(),
pt
.
lvector
(),
np
.
array
([
100
,
200
],
dtype
=
np
.
int64
),
np
.
array
([
100
,
200
],
dtype
=
np
.
int64
),
),
),
set_test_value
(
(
pt
.
dscalar
(),
pt
.
dscalar
(),
np
.
array
(
0.09
,
dtype
=
np
.
float64
),
np
.
array
(
0.09
,
dtype
=
np
.
float64
),
),
),
...
@@ -424,11 +420,11 @@ test_mvnormal_cov_decomposition_method = create_mvnormal_cov_decomposition_metho
...
@@ -424,11 +420,11 @@ test_mvnormal_cov_decomposition_method = create_mvnormal_cov_decomposition_metho
(
(
ptr
.
vonmises
,
ptr
.
vonmises
,
[
[
set_test_value
(
(
pt
.
dvector
(),
pt
.
dvector
(),
np
.
array
([
-
0.5
,
0.5
],
dtype
=
np
.
float64
),
np
.
array
([
-
0.5
,
0.5
],
dtype
=
np
.
float64
),
),
),
set_test_value
(
(
pt
.
dscalar
(),
pt
.
dscalar
(),
np
.
array
(
1.0
,
dtype
=
np
.
float64
),
np
.
array
(
1.0
,
dtype
=
np
.
float64
),
),
),
...
@@ -438,14 +434,14 @@ test_mvnormal_cov_decomposition_method = create_mvnormal_cov_decomposition_metho
...
@@ -438,14 +434,14 @@ test_mvnormal_cov_decomposition_method = create_mvnormal_cov_decomposition_metho
(
(
ptr
.
permutation
,
ptr
.
permutation
,
[
[
set_test_value
(
pt
.
dmatrix
(),
np
.
eye
(
5
,
dtype
=
np
.
float64
)),
(
pt
.
dmatrix
(),
np
.
eye
(
5
,
dtype
=
np
.
float64
)),
],
],
(),
(),
),
),
(
(
partial
(
ptr
.
choice
,
replace
=
True
),
partial
(
ptr
.
choice
,
replace
=
True
),
[
[
set_test_value
(
pt
.
dmatrix
(),
np
.
eye
(
5
,
dtype
=
np
.
float64
)),
(
pt
.
dmatrix
(),
np
.
eye
(
5
,
dtype
=
np
.
float64
)),
],
],
pt
.
as_tensor
([
2
]),
pt
.
as_tensor
([
2
]),
),
),
...
@@ -455,17 +451,15 @@ test_mvnormal_cov_decomposition_method = create_mvnormal_cov_decomposition_metho
...
@@ -455,17 +451,15 @@ test_mvnormal_cov_decomposition_method = create_mvnormal_cov_decomposition_metho
a
,
p
=
p
,
size
=
size
,
replace
=
True
,
rng
=
rng
a
,
p
=
p
,
size
=
size
,
replace
=
True
,
rng
=
rng
),
),
[
[
set_test_value
(
pt
.
dmatrix
(),
np
.
eye
(
3
,
dtype
=
np
.
float64
)),
(
pt
.
dmatrix
(),
np
.
eye
(
3
,
dtype
=
np
.
float64
)),
set_test_value
(
(
pt
.
dvector
(),
np
.
array
([
0.25
,
0.5
,
0.25
],
dtype
=
np
.
float64
)),
pt
.
dvector
(),
np
.
array
([
0.25
,
0.5
,
0.25
],
dtype
=
np
.
float64
)
),
],
],
(
pt
.
as_tensor
([
2
,
3
])),
(
pt
.
as_tensor
([
2
,
3
])),
),
),
pytest
.
param
(
pytest
.
param
(
partial
(
ptr
.
choice
,
replace
=
False
),
partial
(
ptr
.
choice
,
replace
=
False
),
[
[
set_test_value
(
pt
.
dvector
(),
np
.
arange
(
5
,
dtype
=
np
.
float64
)),
(
pt
.
dvector
(),
np
.
arange
(
5
,
dtype
=
np
.
float64
)),
],
],
pt
.
as_tensor
([
2
]),
pt
.
as_tensor
([
2
]),
marks
=
pytest
.
mark
.
xfail
(
marks
=
pytest
.
mark
.
xfail
(
...
@@ -476,7 +470,7 @@ test_mvnormal_cov_decomposition_method = create_mvnormal_cov_decomposition_metho
...
@@ -476,7 +470,7 @@ test_mvnormal_cov_decomposition_method = create_mvnormal_cov_decomposition_metho
pytest
.
param
(
pytest
.
param
(
partial
(
ptr
.
choice
,
replace
=
False
),
partial
(
ptr
.
choice
,
replace
=
False
),
[
[
set_test_value
(
pt
.
dmatrix
(),
np
.
eye
(
5
,
dtype
=
np
.
float64
)),
(
pt
.
dmatrix
(),
np
.
eye
(
5
,
dtype
=
np
.
float64
)),
],
],
pt
.
as_tensor
([
2
]),
pt
.
as_tensor
([
2
]),
marks
=
pytest
.
mark
.
xfail
(
marks
=
pytest
.
mark
.
xfail
(
...
@@ -490,8 +484,8 @@ test_mvnormal_cov_decomposition_method = create_mvnormal_cov_decomposition_metho
...
@@ -490,8 +484,8 @@ test_mvnormal_cov_decomposition_method = create_mvnormal_cov_decomposition_metho
a
,
p
=
p
,
size
=
size
,
replace
=
False
,
rng
=
rng
a
,
p
=
p
,
size
=
size
,
replace
=
False
,
rng
=
rng
),
),
[
[
set_test_value
(
pt
.
vector
(),
np
.
arange
(
5
,
dtype
=
np
.
float64
)),
(
pt
.
vector
(),
np
.
arange
(
5
,
dtype
=
np
.
float64
)),
set_test_value
(
(
pt
.
dvector
(),
pt
.
dvector
(),
np
.
array
([
0.5
,
0.0
,
0.25
,
0.0
,
0.25
],
dtype
=
np
.
float64
),
np
.
array
([
0.5
,
0.0
,
0.25
,
0.0
,
0.25
],
dtype
=
np
.
float64
),
),
),
...
@@ -504,10 +498,8 @@ test_mvnormal_cov_decomposition_method = create_mvnormal_cov_decomposition_metho
...
@@ -504,10 +498,8 @@ test_mvnormal_cov_decomposition_method = create_mvnormal_cov_decomposition_metho
a
,
p
=
p
,
size
=
size
,
replace
=
False
,
rng
=
rng
a
,
p
=
p
,
size
=
size
,
replace
=
False
,
rng
=
rng
),
),
[
[
set_test_value
(
pt
.
dmatrix
(),
np
.
eye
(
3
,
dtype
=
np
.
float64
)),
(
pt
.
dmatrix
(),
np
.
eye
(
3
,
dtype
=
np
.
float64
)),
set_test_value
(
(
pt
.
dvector
(),
np
.
array
([
0.25
,
0.5
,
0.25
],
dtype
=
np
.
float64
)),
pt
.
dvector
(),
np
.
array
([
0.25
,
0.5
,
0.25
],
dtype
=
np
.
float64
)
),
],
],
(),
(),
),
),
...
@@ -517,10 +509,8 @@ test_mvnormal_cov_decomposition_method = create_mvnormal_cov_decomposition_metho
...
@@ -517,10 +509,8 @@ test_mvnormal_cov_decomposition_method = create_mvnormal_cov_decomposition_metho
a
,
p
=
p
,
size
=
size
,
replace
=
False
,
rng
=
rng
a
,
p
=
p
,
size
=
size
,
replace
=
False
,
rng
=
rng
),
),
[
[
set_test_value
(
pt
.
dmatrix
(),
np
.
eye
(
3
,
dtype
=
np
.
float64
)),
(
pt
.
dmatrix
(),
np
.
eye
(
3
,
dtype
=
np
.
float64
)),
set_test_value
(
(
pt
.
dvector
(),
np
.
array
([
0.25
,
0.5
,
0.25
],
dtype
=
np
.
float64
)),
pt
.
dvector
(),
np
.
array
([
0.25
,
0.5
,
0.25
],
dtype
=
np
.
float64
)
),
],
],
(
pt
.
as_tensor
([
2
,
1
])),
(
pt
.
as_tensor
([
2
,
1
])),
),
),
...
@@ -529,17 +519,14 @@ test_mvnormal_cov_decomposition_method = create_mvnormal_cov_decomposition_metho
...
@@ -529,17 +519,14 @@ test_mvnormal_cov_decomposition_method = create_mvnormal_cov_decomposition_metho
)
)
def
test_aligned_RandomVariable
(
rv_op
,
dist_args
,
size
):
def
test_aligned_RandomVariable
(
rv_op
,
dist_args
,
size
):
"""Tests for Numba samplers that are one-to-one with PyTensor's/NumPy's samplers."""
"""Tests for Numba samplers that are one-to-one with PyTensor's/NumPy's samplers."""
dist_args
,
test_dist_args
=
zip
(
*
dist_args
,
strict
=
True
)
rng
=
shared
(
np
.
random
.
default_rng
(
29402
))
rng
=
shared
(
np
.
random
.
default_rng
(
29402
))
g
=
rv_op
(
*
dist_args
,
size
=
size
,
rng
=
rng
)
g
=
rv_op
(
*
dist_args
,
size
=
size
,
rng
=
rng
)
g_fg
=
FunctionGraph
(
outputs
=
[
g
])
compare_numba_and_py
(
compare_numba_and_py
(
g_fg
,
dist_args
,
[
[
g
],
i
.
tag
.
test_value
test_dist_args
,
for
i
in
g_fg
.
inputs
if
not
isinstance
(
i
,
SharedVariable
|
Constant
)
],
eval_obj_mode
=
False
,
# No python impl
eval_obj_mode
=
False
,
# No python impl
)
)
...
@@ -550,11 +537,11 @@ def test_aligned_RandomVariable(rv_op, dist_args, size):
...
@@ -550,11 +537,11 @@ def test_aligned_RandomVariable(rv_op, dist_args, size):
(
(
ptr
.
cauchy
,
ptr
.
cauchy
,
[
[
set_test_value
(
(
pt
.
dvector
(),
pt
.
dvector
(),
np
.
array
([
1.0
,
2.0
],
dtype
=
np
.
float64
),
np
.
array
([
1.0
,
2.0
],
dtype
=
np
.
float64
),
),
),
set_test_value
(
(
pt
.
dscalar
(),
pt
.
dscalar
(),
np
.
array
(
1.0
,
dtype
=
np
.
float64
),
np
.
array
(
1.0
,
dtype
=
np
.
float64
),
),
),
...
@@ -566,11 +553,11 @@ def test_aligned_RandomVariable(rv_op, dist_args, size):
...
@@ -566,11 +553,11 @@ def test_aligned_RandomVariable(rv_op, dist_args, size):
(
(
ptr
.
gumbel
,
ptr
.
gumbel
,
[
[
set_test_value
(
(
pt
.
dvector
(),
pt
.
dvector
(),
np
.
array
([
1.0
,
2.0
],
dtype
=
np
.
float64
),
np
.
array
([
1.0
,
2.0
],
dtype
=
np
.
float64
),
),
),
set_test_value
(
(
pt
.
dscalar
(),
pt
.
dscalar
(),
np
.
array
(
1.0
,
dtype
=
np
.
float64
),
np
.
array
(
1.0
,
dtype
=
np
.
float64
),
),
),
...
@@ -583,18 +570,13 @@ def test_aligned_RandomVariable(rv_op, dist_args, size):
...
@@ -583,18 +570,13 @@ def test_aligned_RandomVariable(rv_op, dist_args, size):
)
)
def
test_unaligned_RandomVariable
(
rv_op
,
dist_args
,
base_size
,
cdf_name
,
params_conv
):
def
test_unaligned_RandomVariable
(
rv_op
,
dist_args
,
base_size
,
cdf_name
,
params_conv
):
"""Tests for Numba samplers that are not one-to-one with PyTensor's/NumPy's samplers."""
"""Tests for Numba samplers that are not one-to-one with PyTensor's/NumPy's samplers."""
dist_args
,
test_dist_args
=
zip
(
*
dist_args
,
strict
=
True
)
rng
=
shared
(
np
.
random
.
default_rng
(
29402
))
rng
=
shared
(
np
.
random
.
default_rng
(
29402
))
g
=
rv_op
(
*
dist_args
,
size
=
(
2000
,
*
base_size
),
rng
=
rng
)
g
=
rv_op
(
*
dist_args
,
size
=
(
2000
,
*
base_size
),
rng
=
rng
)
g_fn
=
function
(
dist_args
,
g
,
mode
=
numba_mode
)
g_fn
=
function
(
dist_args
,
g
,
mode
=
numba_mode
)
samples
=
g_fn
(
samples
=
g_fn
(
*
test_dist_args
)
*
[
i
.
tag
.
test_value
for
i
in
g_fn
.
maker
.
fgraph
.
inputs
if
not
isinstance
(
i
,
SharedVariable
|
Constant
)
]
)
bcast_dist_args
=
np
.
broadcast_arrays
(
*
[
i
.
tag
.
test_value
for
i
in
dist_args
]
)
bcast_dist_args
=
np
.
broadcast_arrays
(
*
test_dist_args
)
for
idx
in
np
.
ndindex
(
*
base_size
):
for
idx
in
np
.
ndindex
(
*
base_size
):
cdf_params
=
params_conv
(
*
(
arg
[
idx
]
for
arg
in
bcast_dist_args
))
cdf_params
=
params_conv
(
*
(
arg
[
idx
]
for
arg
in
bcast_dist_args
))
...
@@ -608,7 +590,7 @@ def test_unaligned_RandomVariable(rv_op, dist_args, base_size, cdf_name, params_
...
@@ -608,7 +590,7 @@ def test_unaligned_RandomVariable(rv_op, dist_args, base_size, cdf_name, params_
"a, size, cm"
,
"a, size, cm"
,
[
[
pytest
.
param
(
pytest
.
param
(
set_test_value
(
(
pt
.
dvector
(),
pt
.
dvector
(),
np
.
array
([
100000
,
1
,
1
],
dtype
=
np
.
float64
),
np
.
array
([
100000
,
1
,
1
],
dtype
=
np
.
float64
),
),
),
...
@@ -616,7 +598,7 @@ def test_unaligned_RandomVariable(rv_op, dist_args, base_size, cdf_name, params_
...
@@ -616,7 +598,7 @@ def test_unaligned_RandomVariable(rv_op, dist_args, base_size, cdf_name, params_
contextlib
.
suppress
(),
contextlib
.
suppress
(),
),
),
pytest
.
param
(
pytest
.
param
(
set_test_value
(
(
pt
.
dmatrix
(),
pt
.
dmatrix
(),
np
.
array
(
np
.
array
(
[[
100000
,
1
,
1
],
[
1
,
100000
,
1
],
[
1
,
1
,
100000
]],
[[
100000
,
1
,
1
],
[
1
,
100000
,
1
],
[
1
,
1
,
100000
]],
...
@@ -627,7 +609,7 @@ def test_unaligned_RandomVariable(rv_op, dist_args, base_size, cdf_name, params_
...
@@ -627,7 +609,7 @@ def test_unaligned_RandomVariable(rv_op, dist_args, base_size, cdf_name, params_
contextlib
.
suppress
(),
contextlib
.
suppress
(),
),
),
pytest
.
param
(
pytest
.
param
(
set_test_value
(
(
pt
.
dmatrix
(),
pt
.
dmatrix
(),
np
.
array
(
np
.
array
(
[[
100000
,
1
,
1
],
[
1
,
100000
,
1
],
[
1
,
1
,
100000
]],
[[
100000
,
1
,
1
],
[
1
,
100000
,
1
],
[
1
,
1
,
100000
]],
...
@@ -643,13 +625,12 @@ def test_unaligned_RandomVariable(rv_op, dist_args, base_size, cdf_name, params_
...
@@ -643,13 +625,12 @@ def test_unaligned_RandomVariable(rv_op, dist_args, base_size, cdf_name, params_
],
],
)
)
def
test_DirichletRV
(
a
,
size
,
cm
):
def
test_DirichletRV
(
a
,
size
,
cm
):
a
,
a_val
=
a
rng
=
shared
(
np
.
random
.
default_rng
(
29402
))
rng
=
shared
(
np
.
random
.
default_rng
(
29402
))
g
=
ptr
.
dirichlet
(
a
,
size
=
size
,
rng
=
rng
)
g
=
ptr
.
dirichlet
(
a
,
size
=
size
,
rng
=
rng
)
g_fn
=
function
([
a
],
g
,
mode
=
numba_mode
)
g_fn
=
function
([
a
],
g
,
mode
=
numba_mode
)
with
cm
:
with
cm
:
a_val
=
a
.
tag
.
test_value
all_samples
=
[]
all_samples
=
[]
for
i
in
range
(
1000
):
for
i
in
range
(
1000
):
samples
=
g_fn
(
a_val
)
samples
=
g_fn
(
a_val
)
...
...
tests/link/numba/test_scalar.py
浏览文件 @
cc8c4992
...
@@ -5,13 +5,10 @@ import pytensor.scalar as ps
...
@@ -5,13 +5,10 @@ import pytensor.scalar as ps
import
pytensor.scalar.basic
as
psb
import
pytensor.scalar.basic
as
psb
import
pytensor.tensor
as
pt
import
pytensor.tensor
as
pt
from
pytensor
import
config
from
pytensor
import
config
from
pytensor.compile.sharedvalue
import
SharedVariable
from
pytensor.graph.basic
import
Constant
from
pytensor.graph.fg
import
FunctionGraph
from
pytensor.scalar.basic
import
Composite
from
pytensor.scalar.basic
import
Composite
from
pytensor.tensor
import
tensor
from
pytensor.tensor
import
tensor
from
pytensor.tensor.elemwise
import
Elemwise
from
pytensor.tensor.elemwise
import
Elemwise
from
tests.link.numba.test_basic
import
compare_numba_and_py
,
set_test_value
from
tests.link.numba.test_basic
import
compare_numba_and_py
rng
=
np
.
random
.
default_rng
(
42849
)
rng
=
np
.
random
.
default_rng
(
42849
)
...
@@ -21,48 +18,43 @@ rng = np.random.default_rng(42849)
...
@@ -21,48 +18,43 @@ rng = np.random.default_rng(42849)
"x, y"
,
"x, y"
,
[
[
(
(
set_test_value
(
pt
.
lvector
(),
np
.
arange
(
4
,
dtype
=
"int64"
)),
(
pt
.
lvector
(),
np
.
arange
(
4
,
dtype
=
"int64"
)),
set_test_value
(
pt
.
dvector
(),
np
.
arange
(
4
,
dtype
=
"float64"
)),
(
pt
.
dvector
(),
np
.
arange
(
4
,
dtype
=
"float64"
)),
),
),
(
(
set_test_value
(
pt
.
dmatrix
(),
np
.
arange
(
4
,
dtype
=
"float64"
)
.
reshape
((
2
,
2
))),
(
pt
.
dmatrix
(),
np
.
arange
(
4
,
dtype
=
"float64"
)
.
reshape
((
2
,
2
))),
set_test_value
(
pt
.
lscalar
(),
np
.
array
(
4
,
dtype
=
"int64"
)),
(
pt
.
lscalar
(),
np
.
array
(
4
,
dtype
=
"int64"
)),
),
),
],
],
)
)
def
test_Second
(
x
,
y
):
def
test_Second
(
x
,
y
):
x
,
x_test
=
x
y
,
y_test
=
y
# We use the `Elemwise`-wrapped version of `Second`
# We use the `Elemwise`-wrapped version of `Second`
g
=
pt
.
second
(
x
,
y
)
g
=
pt
.
second
(
x
,
y
)
g_fg
=
FunctionGraph
(
outputs
=
[
g
])
compare_numba_and_py
(
compare_numba_and_py
(
g_fg
,
[
x
,
y
],
[
g
,
i
.
tag
.
test_value
[
x_test
,
y_test
],
for
i
in
g_fg
.
inputs
if
not
isinstance
(
i
,
SharedVariable
|
Constant
)
],
)
)
@pytest.mark.parametrize
(
@pytest.mark.parametrize
(
"v, min, max"
,
"v, min, max"
,
[
[
(
set_test_value
(
pt
.
scalar
(),
np
.
array
(
10
,
dtype
=
config
.
floatX
)),
3.0
,
7.0
),
((
pt
.
scalar
(),
np
.
array
(
10
,
dtype
=
config
.
floatX
)),
3.0
,
7.0
),
(
set_test_value
(
pt
.
scalar
(),
np
.
array
(
1
,
dtype
=
config
.
floatX
)),
3.0
,
7.0
),
((
pt
.
scalar
(),
np
.
array
(
1
,
dtype
=
config
.
floatX
)),
3.0
,
7.0
),
(
set_test_value
(
pt
.
scalar
(),
np
.
array
(
10
,
dtype
=
config
.
floatX
)),
7.0
,
3.0
),
((
pt
.
scalar
(),
np
.
array
(
10
,
dtype
=
config
.
floatX
)),
7.0
,
3.0
),
],
],
)
)
def
test_Clip
(
v
,
min
,
max
):
def
test_Clip
(
v
,
min
,
max
):
v
,
v_test
=
v
g
=
ps
.
clip
(
v
,
min
,
max
)
g
=
ps
.
clip
(
v
,
min
,
max
)
g_fg
=
FunctionGraph
(
outputs
=
[
g
])
compare_numba_and_py
(
compare_numba_and_py
(
g_fg
,
[
v
],
[
[
g
],
i
.
tag
.
test_value
[
v_test
],
for
i
in
g_fg
.
inputs
if
not
isinstance
(
i
,
SharedVariable
|
Constant
)
],
)
)
...
@@ -100,46 +92,39 @@ def test_Clip(v, min, max):
...
@@ -100,46 +92,39 @@ def test_Clip(v, min, max):
def
test_Composite
(
inputs
,
input_values
,
scalar_fn
):
def
test_Composite
(
inputs
,
input_values
,
scalar_fn
):
composite_inputs
=
[
ps
.
ScalarType
(
config
.
floatX
)(
name
=
i
.
name
)
for
i
in
inputs
]
composite_inputs
=
[
ps
.
ScalarType
(
config
.
floatX
)(
name
=
i
.
name
)
for
i
in
inputs
]
comp_op
=
Elemwise
(
Composite
(
composite_inputs
,
[
scalar_fn
(
*
composite_inputs
)]))
comp_op
=
Elemwise
(
Composite
(
composite_inputs
,
[
scalar_fn
(
*
composite_inputs
)]))
out_fg
=
FunctionGraph
(
inputs
,
[
comp_op
(
*
inputs
)])
compare_numba_and_py
(
inputs
,
[
comp_op
(
*
inputs
)],
input_values
)
compare_numba_and_py
(
out_fg
,
input_values
)
@pytest.mark.parametrize
(
@pytest.mark.parametrize
(
"v, dtype"
,
"v, dtype"
,
[
[
(
set_test_value
(
pt
.
fscalar
(),
np
.
array
(
1.0
,
dtype
=
"float32"
)),
psb
.
float64
),
((
pt
.
fscalar
(),
np
.
array
(
1.0
,
dtype
=
"float32"
)),
psb
.
float64
),
(
set_test_value
(
pt
.
dscalar
(),
np
.
array
(
1.0
,
dtype
=
"float64"
)),
psb
.
float32
),
((
pt
.
dscalar
(),
np
.
array
(
1.0
,
dtype
=
"float64"
)),
psb
.
float32
),
],
],
)
)
def
test_Cast
(
v
,
dtype
):
def
test_Cast
(
v
,
dtype
):
v
,
v_test
=
v
g
=
psb
.
Cast
(
dtype
)(
v
)
g
=
psb
.
Cast
(
dtype
)(
v
)
g_fg
=
FunctionGraph
(
outputs
=
[
g
])
compare_numba_and_py
(
compare_numba_and_py
(
g_fg
,
[
v
],
[
[
g
],
i
.
tag
.
test_value
[
v_test
],
for
i
in
g_fg
.
inputs
if
not
isinstance
(
i
,
SharedVariable
|
Constant
)
],
)
)
@pytest.mark.parametrize
(
@pytest.mark.parametrize
(
"v, dtype"
,
"v, dtype"
,
[
[
(
set_test_value
(
pt
.
iscalar
(),
np
.
array
(
10
,
dtype
=
"int32"
)),
psb
.
float64
),
((
pt
.
iscalar
(),
np
.
array
(
10
,
dtype
=
"int32"
)),
psb
.
float64
),
],
],
)
)
def
test_reciprocal
(
v
,
dtype
):
def
test_reciprocal
(
v
,
dtype
):
v
,
v_test
=
v
g
=
psb
.
reciprocal
(
v
)
g
=
psb
.
reciprocal
(
v
)
g_fg
=
FunctionGraph
(
outputs
=
[
g
])
compare_numba_and_py
(
compare_numba_and_py
(
g_fg
,
[
v
],
[
[
g
],
i
.
tag
.
test_value
[
v_test
],
for
i
in
g_fg
.
inputs
if
not
isinstance
(
i
,
SharedVariable
|
Constant
)
],
)
)
...
@@ -156,6 +141,7 @@ def test_isnan(composite):
...
@@ -156,6 +141,7 @@ def test_isnan(composite):
out
=
pt
.
isnan
(
x
)
out
=
pt
.
isnan
(
x
)
compare_numba_and_py
(
compare_numba_and_py
(
([
x
],
[
out
]),
[
x
],
[
out
],
[
np
.
array
([
1
,
0
],
dtype
=
"float64"
)],
[
np
.
array
([
1
,
0
],
dtype
=
"float64"
)],
)
)
tests/link/numba/test_scan.py
浏览文件 @
cc8c4992
...
@@ -5,7 +5,6 @@ import pytensor
...
@@ -5,7 +5,6 @@ import pytensor
import
pytensor.tensor
as
pt
import
pytensor.tensor
as
pt
from
pytensor
import
config
,
function
,
grad
from
pytensor
import
config
,
function
,
grad
from
pytensor.compile.mode
import
Mode
,
get_mode
from
pytensor.compile.mode
import
Mode
,
get_mode
from
pytensor.graph.fg
import
FunctionGraph
from
pytensor.scalar
import
Log1p
from
pytensor.scalar
import
Log1p
from
pytensor.scan.basic
import
scan
from
pytensor.scan.basic
import
scan
from
pytensor.scan.op
import
Scan
from
pytensor.scan.op
import
Scan
...
@@ -147,7 +146,7 @@ def test_xit_xot_types(
...
@@ -147,7 +146,7 @@ def test_xit_xot_types(
if
output_vals
is
None
:
if
output_vals
is
None
:
compare_numba_and_py
(
compare_numba_and_py
(
(
sequences
+
non_sequences
,
res
)
,
input_vals
,
updates
=
updates
sequences
+
non_sequences
,
res
,
input_vals
,
updates
=
updates
)
)
else
:
else
:
numba_mode
=
get_mode
(
"NUMBA"
)
numba_mode
=
get_mode
(
"NUMBA"
)
...
@@ -217,10 +216,7 @@ def test_scan_multiple_output(benchmark):
...
@@ -217,10 +216,7 @@ def test_scan_multiple_output(benchmark):
logp_c_all
.
name
=
"C_t_logp"
logp_c_all
.
name
=
"C_t_logp"
logp_d_all
.
name
=
"D_t_logp"
logp_d_all
.
name
=
"D_t_logp"
out_fg
=
FunctionGraph
(
out
=
[
st
,
et
,
it
,
logp_c_all
,
logp_d_all
]
[
pt_C
,
pt_D
,
st0
,
et0
,
it0
,
logp_c
,
logp_d
,
beta
,
gamma
,
delta
],
[
st
,
et
,
it
,
logp_c_all
,
logp_d_all
],
)
s0
,
e0
,
i0
=
100
,
50
,
25
s0
,
e0
,
i0
=
100
,
50
,
25
logp_c0
=
np
.
array
(
0.0
,
dtype
=
config
.
floatX
)
logp_c0
=
np
.
array
(
0.0
,
dtype
=
config
.
floatX
)
...
@@ -243,21 +239,21 @@ def test_scan_multiple_output(benchmark):
...
@@ -243,21 +239,21 @@ def test_scan_multiple_output(benchmark):
gamma_val
,
gamma_val
,
delta_val
,
delta_val
,
]
]
scan_fn
,
_
=
compare_numba_and_py
(
out_fg
,
test_input_vals
)
scan_fn
,
_
=
compare_numba_and_py
(
[
pt_C
,
pt_D
,
st0
,
et0
,
it0
,
logp_c
,
logp_d
,
beta
,
gamma
,
delta
],
out
,
test_input_vals
,
)
benchmark
(
scan_fn
,
*
test_input_vals
)
benchmark
(
scan_fn
,
*
test_input_vals
)
@config.change_flags
(
compute_test_value
=
"raise"
)
def
test_scan_tap_output
():
def
test_scan_tap_output
():
a_pt
=
pt
.
scalar
(
"a"
)
a_pt
=
pt
.
scalar
(
"a"
)
a_pt
.
tag
.
test_value
=
10.0
b_pt
=
pt
.
arange
(
11
)
.
astype
(
config
.
floatX
)
b_pt
=
pt
.
vector
(
"b"
)
b_pt
.
name
=
"b"
c_pt
=
pt
.
arange
(
20
,
31
,
dtype
=
config
.
floatX
)
c_pt
=
pt
.
vector
(
"c"
)
c_pt
.
name
=
"c"
def
input_step_fn
(
b
,
b2
,
c
,
x_tm1
,
y_tm1
,
y_tm3
,
a
):
def
input_step_fn
(
b
,
b2
,
c
,
x_tm1
,
y_tm1
,
y_tm3
,
a
):
x_tm1
.
name
=
"x_tm1"
x_tm1
.
name
=
"x_tm1"
...
@@ -301,14 +297,12 @@ def test_scan_tap_output():
...
@@ -301,14 +297,12 @@ def test_scan_tap_output():
strict
=
True
,
strict
=
True
,
)
)
out_fg
=
FunctionGraph
([
a_pt
,
b_pt
,
c_pt
],
scan_res
)
test_input_vals
=
[
test_input_vals
=
[
np
.
array
(
10.0
)
.
astype
(
config
.
floatX
),
np
.
array
(
10.0
)
.
astype
(
config
.
floatX
),
np
.
arange
(
11
,
dtype
=
config
.
floatX
),
np
.
arange
(
11
,
dtype
=
config
.
floatX
),
np
.
arange
(
20
,
31
,
dtype
=
config
.
floatX
),
np
.
arange
(
20
,
31
,
dtype
=
config
.
floatX
),
]
]
compare_numba_and_py
(
out_fg
,
test_input_vals
)
compare_numba_and_py
(
[
a_pt
,
b_pt
,
c_pt
],
scan_res
,
test_input_vals
)
def
test_scan_while
():
def
test_scan_while
():
...
@@ -323,12 +317,10 @@ def test_scan_while():
...
@@ -323,12 +317,10 @@ def test_scan_while():
n_steps
=
1024
,
n_steps
=
1024
,
)
)
out_fg
=
FunctionGraph
([
max_value
],
[
values
])
test_input_vals
=
[
test_input_vals
=
[
np
.
array
(
45
)
.
astype
(
config
.
floatX
),
np
.
array
(
45
)
.
astype
(
config
.
floatX
),
]
]
compare_numba_and_py
(
out_fg
,
test_input_vals
)
compare_numba_and_py
(
[
max_value
],
[
values
]
,
test_input_vals
)
def
test_scan_multiple_none_output
():
def
test_scan_multiple_none_output
():
...
@@ -343,11 +335,8 @@ def test_scan_multiple_none_output():
...
@@ -343,11 +335,8 @@ def test_scan_multiple_none_output():
outputs_info
=
[
pt
.
ones_like
(
A
),
None
,
None
],
outputs_info
=
[
pt
.
ones_like
(
A
),
None
,
None
],
n_steps
=
3
,
n_steps
=
3
,
)
)
out_fg
=
FunctionGraph
([
A
],
result
)
test_input_vals
=
(
np
.
array
([
1.0
,
2.0
]),)
test_input_vals
=
(
np
.
array
([
1.0
,
2.0
]),)
compare_numba_and_py
([
A
],
result
,
test_input_vals
)
compare_numba_and_py
(
out_fg
,
test_input_vals
)
@pytest.mark.parametrize
(
"n_steps_val"
,
[
1
,
5
])
@pytest.mark.parametrize
(
"n_steps_val"
,
[
1
,
5
])
...
@@ -372,11 +361,14 @@ def test_scan_save_mem_basic(n_steps_val):
...
@@ -372,11 +361,14 @@ def test_scan_save_mem_basic(n_steps_val):
numba_mode
=
get_mode
(
"NUMBA"
)
.
including
(
"scan_save_mem"
)
numba_mode
=
get_mode
(
"NUMBA"
)
.
including
(
"scan_save_mem"
)
py_mode
=
Mode
(
"py"
)
.
including
(
"scan_save_mem"
)
py_mode
=
Mode
(
"py"
)
.
including
(
"scan_save_mem"
)
out_fg
=
FunctionGraph
([
init_x
,
n_steps
],
[
output
])
test_input_vals
=
(
state_val
,
n_steps_val
)
test_input_vals
=
(
state_val
,
n_steps_val
)
compare_numba_and_py
(
compare_numba_and_py
(
out_fg
,
test_input_vals
,
numba_mode
=
numba_mode
,
py_mode
=
py_mode
[
init_x
,
n_steps
],
[
output
],
test_input_vals
,
numba_mode
=
numba_mode
,
py_mode
=
py_mode
,
)
)
...
@@ -410,14 +402,12 @@ def test_mitmots_basic():
...
@@ -410,14 +402,12 @@ def test_mitmots_basic():
numba_mode
=
get_mode
(
"NUMBA"
)
.
including
(
"scan_save_mem"
)
numba_mode
=
get_mode
(
"NUMBA"
)
.
including
(
"scan_save_mem"
)
py_mode
=
Mode
(
"py"
)
.
including
(
"scan_save_mem"
)
py_mode
=
Mode
(
"py"
)
.
including
(
"scan_save_mem"
)
out_fg
=
FunctionGraph
([
seq
,
init_x
],
g_outs
)
seq_val
=
np
.
arange
(
3
)
seq_val
=
np
.
arange
(
3
)
init_x_val
=
np
.
r_
[
-
2
,
-
1
]
init_x_val
=
np
.
r_
[
-
2
,
-
1
]
test_input_vals
=
(
seq_val
,
init_x_val
)
test_input_vals
=
(
seq_val
,
init_x_val
)
compare_numba_and_py
(
compare_numba_and_py
(
out_fg
,
test_input_vals
,
numba_mode
=
numba_mode
,
py_mode
=
py_mode
[
seq
,
init_x
],
g_outs
,
test_input_vals
,
numba_mode
=
numba_mode
,
py_mode
=
py_mode
)
)
...
...
tests/link/numba/test_slinalg.py
浏览文件 @
cc8c4992
...
@@ -9,14 +9,14 @@ from scipy import linalg as scipy_linalg
...
@@ -9,14 +9,14 @@ from scipy import linalg as scipy_linalg
import
pytensor
import
pytensor
import
pytensor.tensor
as
pt
import
pytensor.tensor
as
pt
from
pytensor
.graph
import
FunctionGraph
from
pytensor
import
config
from
tests
import
unittest_tools
as
utt
from
tests
import
unittest_tools
as
utt
from
tests.link.numba.test_basic
import
compare_numba_and_py
from
tests.link.numba.test_basic
import
compare_numba_and_py
numba
=
pytest
.
importorskip
(
"numba"
)
numba
=
pytest
.
importorskip
(
"numba"
)
floatX
=
pytensor
.
config
.
floatX
floatX
=
config
.
floatX
rng
=
np
.
random
.
default_rng
(
42849
)
rng
=
np
.
random
.
default_rng
(
42849
)
...
@@ -88,7 +88,12 @@ def test_solve_triangular(b_shape: tuple[int], lower, trans, unit_diag, is_compl
...
@@ -88,7 +88,12 @@ def test_solve_triangular(b_shape: tuple[int], lower, trans, unit_diag, is_compl
np
.
testing
.
assert_allclose
(
test_input
@
X_np
,
b_val
,
atol
=
ATOL
,
rtol
=
RTOL
)
np
.
testing
.
assert_allclose
(
test_input
@
X_np
,
b_val
,
atol
=
ATOL
,
rtol
=
RTOL
)
compare_numba_and_py
(
f
.
maker
.
fgraph
,
[
A_func
(
A_val
.
copy
()),
b_val
.
copy
()])
compiled_fgraph
=
f
.
maker
.
fgraph
compare_numba_and_py
(
compiled_fgraph
.
inputs
,
compiled_fgraph
.
outputs
,
[
A_func
(
A_val
.
copy
()),
b_val
.
copy
()],
)
@pytest.mark.parametrize
(
@pytest.mark.parametrize
(
...
@@ -159,12 +164,10 @@ def test_numba_Cholesky(lower, trans):
...
@@ -159,12 +164,10 @@ def test_numba_Cholesky(lower, trans):
cov_
=
cov
cov_
=
cov
chol
=
pt
.
linalg
.
cholesky
(
cov_
,
lower
=
lower
)
chol
=
pt
.
linalg
.
cholesky
(
cov_
,
lower
=
lower
)
fg
=
FunctionGraph
(
outputs
=
[
chol
])
x
=
np
.
array
([
0.1
,
0.2
,
0.3
])
.
astype
(
floatX
)
x
=
np
.
array
([
0.1
,
0.2
,
0.3
])
.
astype
(
floatX
)
val
=
np
.
eye
(
3
)
.
astype
(
floatX
)
+
x
[
None
,
:]
*
x
[:,
None
]
val
=
np
.
eye
(
3
)
.
astype
(
floatX
)
+
x
[
None
,
:]
*
x
[:,
None
]
compare_numba_and_py
(
fg
,
[
val
])
compare_numba_and_py
(
[
cov
],
[
chol
]
,
[
val
])
def
test_numba_Cholesky_raises_on_nan_input
():
def
test_numba_Cholesky_raises_on_nan_input
():
...
@@ -218,8 +221,7 @@ def test_block_diag():
...
@@ -218,8 +221,7 @@ def test_block_diag():
B_val
=
np
.
random
.
normal
(
size
=
(
3
,
3
))
.
astype
(
floatX
)
B_val
=
np
.
random
.
normal
(
size
=
(
3
,
3
))
.
astype
(
floatX
)
C_val
=
np
.
random
.
normal
(
size
=
(
2
,
2
))
.
astype
(
floatX
)
C_val
=
np
.
random
.
normal
(
size
=
(
2
,
2
))
.
astype
(
floatX
)
D_val
=
np
.
random
.
normal
(
size
=
(
4
,
4
))
.
astype
(
floatX
)
D_val
=
np
.
random
.
normal
(
size
=
(
4
,
4
))
.
astype
(
floatX
)
out_fg
=
pytensor
.
graph
.
FunctionGraph
([
A
,
B
,
C
,
D
],
[
X
])
compare_numba_and_py
([
A
,
B
,
C
,
D
],
[
X
],
[
A_val
,
B_val
,
C_val
,
D_val
])
compare_numba_and_py
(
out_fg
,
[
A_val
,
B_val
,
C_val
,
D_val
])
def
test_lamch
():
def
test_lamch
():
...
@@ -390,7 +392,7 @@ def test_solve(b_shape: tuple[int], assume_a: Literal["gen", "sym", "pos"]):
...
@@ -390,7 +392,7 @@ def test_solve(b_shape: tuple[int], assume_a: Literal["gen", "sym", "pos"]):
)
)
op
=
f
.
maker
.
fgraph
.
outputs
[
0
]
.
owner
.
op
op
=
f
.
maker
.
fgraph
.
outputs
[
0
]
.
owner
.
op
compare_numba_and_py
(
([
A
,
b
],
[
X
]),
inputs
=
[
A_val
,
b_val
],
inplace
=
True
)
compare_numba_and_py
(
[
A
,
b
],
[
X
],
test_
inputs
=
[
A_val
,
b_val
],
inplace
=
True
)
# Calling this is destructive and will rewrite b_val to be the answer. Store copies of the inputs first.
# Calling this is destructive and will rewrite b_val to be the answer. Store copies of the inputs first.
A_val_copy
=
A_val
.
copy
()
A_val_copy
=
A_val
.
copy
()
...
...
tests/link/numba/test_sparse.py
浏览文件 @
cc8c4992
...
@@ -100,4 +100,4 @@ def test_sparse_objmode():
...
@@ -100,4 +100,4 @@ def test_sparse_objmode():
UserWarning
,
UserWarning
,
match
=
"Numba will use object mode to run SparseDot's perform method"
,
match
=
"Numba will use object mode to run SparseDot's perform method"
,
):
):
compare_numba_and_py
(
((
x
,
y
),
(
out
,))
,
[
x_val
,
y_val
])
compare_numba_and_py
(
[
x
,
y
],
out
,
[
x_val
,
y_val
])
tests/link/numba/test_subtensor.py
浏览文件 @
cc8c4992
...
@@ -4,7 +4,6 @@ import numpy as np
...
@@ -4,7 +4,6 @@ import numpy as np
import
pytest
import
pytest
import
pytensor.tensor
as
pt
import
pytensor.tensor
as
pt
from
pytensor.graph
import
FunctionGraph
from
pytensor.tensor
import
as_tensor
from
pytensor.tensor
import
as_tensor
from
pytensor.tensor.subtensor
import
(
from
pytensor.tensor.subtensor
import
(
AdvancedIncSubtensor
,
AdvancedIncSubtensor
,
...
@@ -44,8 +43,7 @@ def test_Subtensor(x, indices):
...
@@ -44,8 +43,7 @@ def test_Subtensor(x, indices):
"""Test NumPy's basic indexing."""
"""Test NumPy's basic indexing."""
out_pt
=
x
[
indices
]
out_pt
=
x
[
indices
]
assert
isinstance
(
out_pt
.
owner
.
op
,
Subtensor
)
assert
isinstance
(
out_pt
.
owner
.
op
,
Subtensor
)
out_fg
=
FunctionGraph
([],
[
out_pt
])
compare_numba_and_py
([],
[
out_pt
],
[])
compare_numba_and_py
(
out_fg
,
[])
@pytest.mark.parametrize
(
@pytest.mark.parametrize
(
...
@@ -59,16 +57,14 @@ def test_AdvancedSubtensor1(x, indices):
...
@@ -59,16 +57,14 @@ def test_AdvancedSubtensor1(x, indices):
"""Test NumPy's advanced indexing in one dimension."""
"""Test NumPy's advanced indexing in one dimension."""
out_pt
=
advanced_subtensor1
(
x
,
*
indices
)
out_pt
=
advanced_subtensor1
(
x
,
*
indices
)
assert
isinstance
(
out_pt
.
owner
.
op
,
AdvancedSubtensor1
)
assert
isinstance
(
out_pt
.
owner
.
op
,
AdvancedSubtensor1
)
out_fg
=
FunctionGraph
([],
[
out_pt
])
compare_numba_and_py
([],
[
out_pt
],
[])
compare_numba_and_py
(
out_fg
,
[])
def
test_AdvancedSubtensor1_out_of_bounds
():
def
test_AdvancedSubtensor1_out_of_bounds
():
out_pt
=
advanced_subtensor1
(
np
.
arange
(
3
),
[
4
])
out_pt
=
advanced_subtensor1
(
np
.
arange
(
3
),
[
4
])
assert
isinstance
(
out_pt
.
owner
.
op
,
AdvancedSubtensor1
)
assert
isinstance
(
out_pt
.
owner
.
op
,
AdvancedSubtensor1
)
out_fg
=
FunctionGraph
([],
[
out_pt
])
with
pytest
.
raises
(
IndexError
):
with
pytest
.
raises
(
IndexError
):
compare_numba_and_py
(
out_fg
,
[])
compare_numba_and_py
(
[],
[
out_pt
]
,
[])
@pytest.mark.parametrize
(
@pytest.mark.parametrize
(
...
@@ -151,7 +147,6 @@ def test_AdvancedSubtensor(x, indices, objmode_needed):
...
@@ -151,7 +147,6 @@ def test_AdvancedSubtensor(x, indices, objmode_needed):
x_pt
=
x
.
type
()
x_pt
=
x
.
type
()
out_pt
=
x_pt
[
indices
]
out_pt
=
x_pt
[
indices
]
assert
isinstance
(
out_pt
.
owner
.
op
,
AdvancedSubtensor
)
assert
isinstance
(
out_pt
.
owner
.
op
,
AdvancedSubtensor
)
out_fg
=
FunctionGraph
([
x_pt
],
[
out_pt
])
with
(
with
(
pytest
.
warns
(
pytest
.
warns
(
UserWarning
,
UserWarning
,
...
@@ -161,7 +156,8 @@ def test_AdvancedSubtensor(x, indices, objmode_needed):
...
@@ -161,7 +156,8 @@ def test_AdvancedSubtensor(x, indices, objmode_needed):
else
contextlib
.
nullcontext
()
else
contextlib
.
nullcontext
()
):
):
compare_numba_and_py
(
compare_numba_and_py
(
out_fg
,
[
x_pt
],
[
out_pt
],
[
x
.
data
],
[
x
.
data
],
numba_mode
=
numba_mode
.
including
(
"specialize"
),
numba_mode
=
numba_mode
.
including
(
"specialize"
),
)
)
...
@@ -195,19 +191,16 @@ def test_AdvancedSubtensor(x, indices, objmode_needed):
...
@@ -195,19 +191,16 @@ def test_AdvancedSubtensor(x, indices, objmode_needed):
def
test_IncSubtensor
(
x
,
y
,
indices
):
def
test_IncSubtensor
(
x
,
y
,
indices
):
out_pt
=
set_subtensor
(
x
[
indices
],
y
)
out_pt
=
set_subtensor
(
x
[
indices
],
y
)
assert
isinstance
(
out_pt
.
owner
.
op
,
IncSubtensor
)
assert
isinstance
(
out_pt
.
owner
.
op
,
IncSubtensor
)
out_fg
=
FunctionGraph
([],
[
out_pt
])
compare_numba_and_py
([],
[
out_pt
],
[])
compare_numba_and_py
(
out_fg
,
[])
out_pt
=
inc_subtensor
(
x
[
indices
],
y
)
out_pt
=
inc_subtensor
(
x
[
indices
],
y
)
assert
isinstance
(
out_pt
.
owner
.
op
,
IncSubtensor
)
assert
isinstance
(
out_pt
.
owner
.
op
,
IncSubtensor
)
out_fg
=
FunctionGraph
([],
[
out_pt
])
compare_numba_and_py
([],
[
out_pt
],
[])
compare_numba_and_py
(
out_fg
,
[])
x_pt
=
x
.
type
()
x_pt
=
x
.
type
()
out_pt
=
set_subtensor
(
x_pt
[
indices
],
y
,
inplace
=
True
)
out_pt
=
set_subtensor
(
x_pt
[
indices
],
y
,
inplace
=
True
)
assert
isinstance
(
out_pt
.
owner
.
op
,
IncSubtensor
)
assert
isinstance
(
out_pt
.
owner
.
op
,
IncSubtensor
)
out_fg
=
FunctionGraph
([
x_pt
],
[
out_pt
])
compare_numba_and_py
([
x_pt
],
[
out_pt
],
[
x
.
data
])
compare_numba_and_py
(
out_fg
,
[
x
.
data
])
@pytest.mark.parametrize
(
@pytest.mark.parametrize
(
...
@@ -249,13 +242,11 @@ def test_IncSubtensor(x, y, indices):
...
@@ -249,13 +242,11 @@ def test_IncSubtensor(x, y, indices):
def
test_AdvancedIncSubtensor1
(
x
,
y
,
indices
):
def
test_AdvancedIncSubtensor1
(
x
,
y
,
indices
):
out_pt
=
advanced_set_subtensor1
(
x
,
y
,
*
indices
)
out_pt
=
advanced_set_subtensor1
(
x
,
y
,
*
indices
)
assert
isinstance
(
out_pt
.
owner
.
op
,
AdvancedIncSubtensor1
)
assert
isinstance
(
out_pt
.
owner
.
op
,
AdvancedIncSubtensor1
)
out_fg
=
FunctionGraph
([],
[
out_pt
])
compare_numba_and_py
([],
[
out_pt
],
[])
compare_numba_and_py
(
out_fg
,
[])
out_pt
=
advanced_inc_subtensor1
(
x
,
y
,
*
indices
)
out_pt
=
advanced_inc_subtensor1
(
x
,
y
,
*
indices
)
assert
isinstance
(
out_pt
.
owner
.
op
,
AdvancedIncSubtensor1
)
assert
isinstance
(
out_pt
.
owner
.
op
,
AdvancedIncSubtensor1
)
out_fg
=
FunctionGraph
([],
[
out_pt
])
compare_numba_and_py
([],
[
out_pt
],
[])
compare_numba_and_py
(
out_fg
,
[])
# With symbolic inputs
# With symbolic inputs
x_pt
=
x
.
type
()
x_pt
=
x
.
type
()
...
@@ -263,15 +254,13 @@ def test_AdvancedIncSubtensor1(x, y, indices):
...
@@ -263,15 +254,13 @@ def test_AdvancedIncSubtensor1(x, y, indices):
out_pt
=
AdvancedIncSubtensor1
(
inplace
=
True
)(
x_pt
,
y_pt
,
*
indices
)
out_pt
=
AdvancedIncSubtensor1
(
inplace
=
True
)(
x_pt
,
y_pt
,
*
indices
)
assert
isinstance
(
out_pt
.
owner
.
op
,
AdvancedIncSubtensor1
)
assert
isinstance
(
out_pt
.
owner
.
op
,
AdvancedIncSubtensor1
)
out_fg
=
FunctionGraph
([
x_pt
,
y_pt
],
[
out_pt
])
compare_numba_and_py
([
x_pt
,
y_pt
],
[
out_pt
],
[
x
.
data
,
y
.
data
])
compare_numba_and_py
(
out_fg
,
[
x
.
data
,
y
.
data
])
out_pt
=
AdvancedIncSubtensor1
(
set_instead_of_inc
=
True
,
inplace
=
True
)(
out_pt
=
AdvancedIncSubtensor1
(
set_instead_of_inc
=
True
,
inplace
=
True
)(
x_pt
,
y_pt
,
*
indices
x_pt
,
y_pt
,
*
indices
)
)
assert
isinstance
(
out_pt
.
owner
.
op
,
AdvancedIncSubtensor1
)
assert
isinstance
(
out_pt
.
owner
.
op
,
AdvancedIncSubtensor1
)
out_fg
=
FunctionGraph
([
x_pt
,
y_pt
],
[
out_pt
])
compare_numba_and_py
([
x_pt
,
y_pt
],
[
out_pt
],
[
x
.
data
,
y
.
data
])
compare_numba_and_py
(
out_fg
,
[
x
.
data
,
y
.
data
])
@pytest.mark.parametrize
(
@pytest.mark.parametrize
(
...
@@ -454,7 +443,7 @@ def test_AdvancedIncSubtensor(
...
@@ -454,7 +443,7 @@ def test_AdvancedIncSubtensor(
if
set_requires_objmode
if
set_requires_objmode
else
contextlib
.
nullcontext
()
else
contextlib
.
nullcontext
()
):
):
fn
,
_
=
compare_numba_and_py
(
([
x_pt
,
y_pt
],
[
out_pt
])
,
[
x
,
y
],
numba_mode
=
mode
)
fn
,
_
=
compare_numba_and_py
(
[
x_pt
,
y_pt
],
out_pt
,
[
x
,
y
],
numba_mode
=
mode
)
if
inplace
:
if
inplace
:
# Test updates inplace
# Test updates inplace
...
@@ -474,7 +463,7 @@ def test_AdvancedIncSubtensor(
...
@@ -474,7 +463,7 @@ def test_AdvancedIncSubtensor(
if
inc_requires_objmode
if
inc_requires_objmode
else
contextlib
.
nullcontext
()
else
contextlib
.
nullcontext
()
):
):
fn
,
_
=
compare_numba_and_py
(
([
x_pt
,
y_pt
],
[
out_pt
])
,
[
x
,
y
],
numba_mode
=
mode
)
fn
,
_
=
compare_numba_and_py
(
[
x_pt
,
y_pt
],
out_pt
,
[
x
,
y
],
numba_mode
=
mode
)
if
inplace
:
if
inplace
:
# Test updates inplace
# Test updates inplace
x_orig
=
x
.
copy
()
x_orig
=
x
.
copy
()
...
...
tests/link/numba/test_tensor_basic.py
浏览文件 @
cc8c4992
...
@@ -6,15 +6,11 @@ import pytensor.tensor as pt
...
@@ -6,15 +6,11 @@ import pytensor.tensor as pt
import
pytensor.tensor.basic
as
ptb
import
pytensor.tensor.basic
as
ptb
from
pytensor
import
config
,
function
from
pytensor
import
config
,
function
from
pytensor.compile
import
get_mode
from
pytensor.compile
import
get_mode
from
pytensor.compile.sharedvalue
import
SharedVariable
from
pytensor.graph.basic
import
Constant
from
pytensor.graph.fg
import
FunctionGraph
from
pytensor.scalar
import
Add
from
pytensor.scalar
import
Add
from
pytensor.tensor.shape
import
Unbroadcast
from
pytensor.tensor.shape
import
Unbroadcast
from
tests.link.numba.test_basic
import
(
from
tests.link.numba.test_basic
import
(
compare_numba_and_py
,
compare_numba_and_py
,
compare_shape_dtype
,
compare_shape_dtype
,
set_test_value
,
)
)
from
tests.tensor.test_basic
import
check_alloc_runtime_broadcast
from
tests.tensor.test_basic
import
check_alloc_runtime_broadcast
...
@@ -31,21 +27,18 @@ rng = np.random.default_rng(42849)
...
@@ -31,21 +27,18 @@ rng = np.random.default_rng(42849)
[
[
(
0.0
,
(
2
,
3
)),
(
0.0
,
(
2
,
3
)),
(
1.1
,
(
2
,
3
)),
(
1.1
,
(
2
,
3
)),
(
set_test_value
(
pt
.
scalar
(
"a"
),
np
.
array
(
10.0
,
dtype
=
config
.
floatX
)),
(
20
,)),
((
pt
.
scalar
(
"a"
),
np
.
array
(
10.0
,
dtype
=
config
.
floatX
)),
(
20
,)),
(
set_test_value
(
pt
.
vector
(
"a"
),
np
.
ones
(
10
,
dtype
=
config
.
floatX
)),
(
20
,
10
)),
((
pt
.
vector
(
"a"
),
np
.
ones
(
10
,
dtype
=
config
.
floatX
)),
(
20
,
10
)),
],
],
)
)
def
test_Alloc
(
v
,
shape
):
def
test_Alloc
(
v
,
shape
):
v
,
v_test
=
v
if
isinstance
(
v
,
tuple
)
else
(
v
,
None
)
g
=
pt
.
alloc
(
v
,
*
shape
)
g
=
pt
.
alloc
(
v
,
*
shape
)
g_fg
=
FunctionGraph
(
outputs
=
[
g
])
_
,
(
numba_res
,)
=
compare_numba_and_py
(
_
,
(
numba_res
,)
=
compare_numba_and_py
(
g_fg
,
[
v
]
if
v_test
is
not
None
else
[],
[
[
g
],
i
.
tag
.
test_value
[
v_test
]
if
v_test
is
not
None
else
[],
for
i
in
g_fg
.
inputs
if
not
isinstance
(
i
,
SharedVariable
|
Constant
)
],
)
)
assert
numba_res
.
shape
==
shape
assert
numba_res
.
shape
==
shape
...
@@ -57,58 +50,38 @@ def test_alloc_runtime_broadcast():
...
@@ -57,58 +50,38 @@ def test_alloc_runtime_broadcast():
def
test_AllocEmpty
():
def
test_AllocEmpty
():
x
=
pt
.
empty
((
2
,
3
),
dtype
=
"float32"
)
x
=
pt
.
empty
((
2
,
3
),
dtype
=
"float32"
)
x_fg
=
FunctionGraph
([],
[
x
])
# We cannot compare the values in the arrays, only the shapes and dtypes
# We cannot compare the values in the arrays, only the shapes and dtypes
compare_numba_and_py
(
x_fg
,
[],
assert_fn
=
compare_shape_dtype
)
compare_numba_and_py
(
[],
x
,
[],
assert_fn
=
compare_shape_dtype
)
@pytest.mark.parametrize
(
def
test_TensorFromScalar
():
"v"
,
[
set_test_value
(
ps
.
float64
(),
np
.
array
(
1.0
,
dtype
=
"float64"
))]
v
,
v_test
=
ps
.
float64
(),
np
.
array
(
1.0
,
dtype
=
"float64"
)
)
def
test_TensorFromScalar
(
v
):
g
=
ptb
.
TensorFromScalar
()(
v
)
g
=
ptb
.
TensorFromScalar
()(
v
)
g_fg
=
FunctionGraph
(
outputs
=
[
g
])
compare_numba_and_py
(
compare_numba_and_py
(
g_fg
,
[
v
],
[
g
,
i
.
tag
.
test_value
[
v_test
],
for
i
in
g_fg
.
inputs
if
not
isinstance
(
i
,
SharedVariable
|
Constant
)
],
)
)
@pytest.mark.parametrize
(
def
test_ScalarFromTensor
():
"v"
,
v
,
v_test
=
pt
.
scalar
(),
np
.
array
(
1.0
,
dtype
=
config
.
floatX
)
[
set_test_value
(
pt
.
scalar
(),
np
.
array
(
1.0
,
dtype
=
config
.
floatX
)),
],
)
def
test_ScalarFromTensor
(
v
):
g
=
ptb
.
ScalarFromTensor
()(
v
)
g
=
ptb
.
ScalarFromTensor
()(
v
)
g_fg
=
FunctionGraph
(
outputs
=
[
g
])
compare_numba_and_py
(
compare_numba_and_py
(
g_fg
,
[
v
],
[
g
,
i
.
tag
.
test_value
[
v_test
],
for
i
in
g_fg
.
inputs
if
not
isinstance
(
i
,
SharedVariable
|
Constant
)
],
)
)
def
test_Unbroadcast
():
def
test_Unbroadcast
():
v
=
set_test_value
(
pt
.
row
(),
np
.
array
([[
1.0
,
2.0
]],
dtype
=
config
.
floatX
)
)
v
,
v_test
=
pt
.
row
(),
np
.
array
([[
1.0
,
2.0
]],
dtype
=
config
.
floatX
)
g
=
Unbroadcast
(
0
)(
v
)
g
=
Unbroadcast
(
0
)(
v
)
g_fg
=
FunctionGraph
(
outputs
=
[
g
])
compare_numba_and_py
(
compare_numba_and_py
(
g_fg
,
[
v
],
[
g
,
i
.
tag
.
test_value
[
v_test
],
for
i
in
g_fg
.
inputs
if
not
isinstance
(
i
,
SharedVariable
|
Constant
)
],
)
)
...
@@ -117,65 +90,52 @@ def test_Unbroadcast():
...
@@ -117,65 +90,52 @@ def test_Unbroadcast():
[
[
(
(
(
(
set_test_value
(
pt
.
scalar
(),
np
.
array
(
1
,
dtype
=
config
.
floatX
)),
(
pt
.
scalar
(),
np
.
array
(
1
,
dtype
=
config
.
floatX
)),
set_test_value
(
pt
.
scalar
(),
np
.
array
(
2
,
dtype
=
config
.
floatX
)),
(
pt
.
scalar
(),
np
.
array
(
2
,
dtype
=
config
.
floatX
)),
set_test_value
(
pt
.
scalar
(),
np
.
array
(
3
,
dtype
=
config
.
floatX
)),
(
pt
.
scalar
(),
np
.
array
(
3
,
dtype
=
config
.
floatX
)),
),
),
config
.
floatX
,
config
.
floatX
,
),
),
(
(
(
(
set_test_value
(
pt
.
dscalar
(),
np
.
array
(
1
,
dtype
=
np
.
float64
)),
(
pt
.
dscalar
(),
np
.
array
(
1
,
dtype
=
np
.
float64
)),
set_test_value
(
pt
.
lscalar
(),
np
.
array
(
3
,
dtype
=
np
.
int32
)),
(
pt
.
lscalar
(),
np
.
array
(
3
,
dtype
=
np
.
int32
)),
),
),
"float64"
,
"float64"
,
),
),
(
(
(
set_test_value
(
pt
.
iscalar
(),
np
.
array
(
1
,
dtype
=
np
.
int32
)),),
((
pt
.
iscalar
(),
np
.
array
(
1
,
dtype
=
np
.
int32
)),),
"float64"
,
"float64"
,
),
),
(
(
(
set_test_value
(
pt
.
scalar
(
dtype
=
bool
),
True
),),
((
pt
.
scalar
(
dtype
=
bool
),
True
),),
bool
,
bool
,
),
),
],
],
)
)
def
test_MakeVector
(
vals
,
dtype
):
def
test_MakeVector
(
vals
,
dtype
):
vals
,
vals_test
=
zip
(
*
vals
,
strict
=
True
)
g
=
ptb
.
MakeVector
(
dtype
)(
*
vals
)
g
=
ptb
.
MakeVector
(
dtype
)(
*
vals
)
g_fg
=
FunctionGraph
(
outputs
=
[
g
])
compare_numba_and_py
(
compare_numba_and_py
(
g_fg
,
vals
,
[
[
g
],
i
.
tag
.
test_value
vals_test
,
for
i
in
g_fg
.
inputs
if
not
isinstance
(
i
,
SharedVariable
|
Constant
)
],
)
)
@pytest.mark.parametrize
(
def
test_ARange
():
"start, stop, step, dtype"
,
start
,
start_test
=
pt
.
lscalar
(),
np
.
array
(
1
)
[
stop
,
stop_tset
=
pt
.
lscalar
(),
np
.
array
(
10
)
(
step
,
step_test
=
pt
.
lscalar
(),
np
.
array
(
3
)
set_test_value
(
pt
.
lscalar
(),
np
.
array
(
1
)),
dtype
=
config
.
floatX
set_test_value
(
pt
.
lscalar
(),
np
.
array
(
10
)),
set_test_value
(
pt
.
lscalar
(),
np
.
array
(
3
)),
config
.
floatX
,
),
],
)
def
test_ARange
(
start
,
stop
,
step
,
dtype
):
g
=
ptb
.
ARange
(
dtype
)(
start
,
stop
,
step
)
g
=
ptb
.
ARange
(
dtype
)(
start
,
stop
,
step
)
g_fg
=
FunctionGraph
(
outputs
=
[
g
])
compare_numba_and_py
(
compare_numba_and_py
(
g_fg
,
[
start
,
stop
,
step
],
[
g
,
i
.
tag
.
test_value
[
start_test
,
stop_tset
,
step_test
],
for
i
in
g_fg
.
inputs
if
not
isinstance
(
i
,
SharedVariable
|
Constant
)
],
)
)
...
@@ -184,80 +144,60 @@ def test_ARange(start, stop, step, dtype):
...
@@ -184,80 +144,60 @@ def test_ARange(start, stop, step, dtype):
[
[
(
(
(
(
set_test_value
(
(
pt
.
matrix
(),
rng
.
normal
(
size
=
(
1
,
2
))
.
astype
(
config
.
floatX
)),
pt
.
matrix
(),
rng
.
normal
(
size
=
(
1
,
2
))
.
astype
(
config
.
floatX
)
(
pt
.
matrix
(),
rng
.
normal
(
size
=
(
1
,
2
))
.
astype
(
config
.
floatX
)),
),
set_test_value
(
pt
.
matrix
(),
rng
.
normal
(
size
=
(
1
,
2
))
.
astype
(
config
.
floatX
)
),
),
),
0
,
0
,
),
),
(
(
(
(
set_test_value
(
(
pt
.
matrix
(),
rng
.
normal
(
size
=
(
2
,
1
))
.
astype
(
config
.
floatX
)),
pt
.
matrix
(),
rng
.
normal
(
size
=
(
2
,
1
))
.
astype
(
config
.
floatX
)
(
pt
.
matrix
(),
rng
.
normal
(
size
=
(
3
,
1
))
.
astype
(
config
.
floatX
)),
),
set_test_value
(
pt
.
matrix
(),
rng
.
normal
(
size
=
(
3
,
1
))
.
astype
(
config
.
floatX
)
),
),
),
0
,
0
,
),
),
(
(
(
(
set_test_value
(
(
pt
.
matrix
(),
rng
.
normal
(
size
=
(
1
,
2
))
.
astype
(
config
.
floatX
)),
pt
.
matrix
(),
rng
.
normal
(
size
=
(
1
,
2
))
.
astype
(
config
.
floatX
)
(
pt
.
matrix
(),
rng
.
normal
(
size
=
(
1
,
2
))
.
astype
(
config
.
floatX
)),
),
set_test_value
(
pt
.
matrix
(),
rng
.
normal
(
size
=
(
1
,
2
))
.
astype
(
config
.
floatX
)
),
),
),
1
,
1
,
),
),
(
(
(
(
set_test_value
(
(
pt
.
matrix
(),
rng
.
normal
(
size
=
(
2
,
2
))
.
astype
(
config
.
floatX
)),
pt
.
matrix
(),
rng
.
normal
(
size
=
(
2
,
2
))
.
astype
(
config
.
floatX
)
(
pt
.
matrix
(),
rng
.
normal
(
size
=
(
2
,
1
))
.
astype
(
config
.
floatX
)),
),
set_test_value
(
pt
.
matrix
(),
rng
.
normal
(
size
=
(
2
,
1
))
.
astype
(
config
.
floatX
)
),
),
),
1
,
1
,
),
),
],
],
)
)
def
test_Join
(
vals
,
axis
):
def
test_Join
(
vals
,
axis
):
vals
,
vals_test
=
zip
(
*
vals
,
strict
=
True
)
g
=
pt
.
join
(
axis
,
*
vals
)
g
=
pt
.
join
(
axis
,
*
vals
)
g_fg
=
FunctionGraph
(
outputs
=
[
g
])
compare_numba_and_py
(
compare_numba_and_py
(
g_fg
,
vals
,
[
g
,
i
.
tag
.
test_value
vals_test
,
for
i
in
g_fg
.
inputs
if
not
isinstance
(
i
,
SharedVariable
|
Constant
)
],
)
)
def
test_Join_view
():
def
test_Join_view
():
vals
=
(
vals
,
vals_test
=
zip
(
set_test_value
(
pt
.
matrix
(),
rng
.
normal
(
size
=
(
2
,
2
))
.
astype
(
config
.
floatX
)),
*
(
set_test_value
(
pt
.
matrix
(),
rng
.
normal
(
size
=
(
2
,
2
))
.
astype
(
config
.
floatX
)),
(
pt
.
matrix
(),
rng
.
normal
(
size
=
(
2
,
2
))
.
astype
(
config
.
floatX
)),
(
pt
.
matrix
(),
rng
.
normal
(
size
=
(
2
,
2
))
.
astype
(
config
.
floatX
)),
),
strict
=
True
,
)
)
g
=
ptb
.
Join
(
view
=
1
)(
1
,
*
vals
)
g
=
ptb
.
Join
(
view
=
1
)(
1
,
*
vals
)
g_fg
=
FunctionGraph
(
outputs
=
[
g
])
with
pytest
.
raises
(
NotImplementedError
):
with
pytest
.
raises
(
NotImplementedError
):
compare_numba_and_py
(
compare_numba_and_py
(
g_fg
,
vals
,
[
g
,
i
.
tag
.
test_value
vals_test
,
for
i
in
g_fg
.
inputs
if
not
isinstance
(
i
,
SharedVariable
|
Constant
)
],
)
)
...
@@ -267,57 +207,47 @@ def test_Join_view():
...
@@ -267,57 +207,47 @@ def test_Join_view():
(
(
0
,
0
,
0
,
0
,
set_test_value
(
pt
.
vector
(),
rng
.
normal
(
size
=
20
)
.
astype
(
config
.
floatX
)),
(
pt
.
vector
(),
rng
.
normal
(
size
=
20
)
.
astype
(
config
.
floatX
)),
set_test_value
(
pt
.
vector
(
dtype
=
"int64"
),
[]),
(
pt
.
vector
(
dtype
=
"int64"
),
[]),
),
),
(
(
5
,
5
,
0
,
0
,
set_test_value
(
pt
.
vector
(),
rng
.
normal
(
size
=
5
)
.
astype
(
config
.
floatX
)),
(
pt
.
vector
(),
rng
.
normal
(
size
=
5
)
.
astype
(
config
.
floatX
)),
set_test_value
(
(
pt
.
vector
(
dtype
=
"int64"
),
rng
.
multinomial
(
5
,
np
.
ones
(
5
)
/
5
)),
pt
.
vector
(
dtype
=
"int64"
),
rng
.
multinomial
(
5
,
np
.
ones
(
5
)
/
5
)
),
),
),
(
(
5
,
5
,
0
,
0
,
set_test_value
(
pt
.
vector
(),
rng
.
normal
(
size
=
10
)
.
astype
(
config
.
floatX
)),
(
pt
.
vector
(),
rng
.
normal
(
size
=
10
)
.
astype
(
config
.
floatX
)),
set_test_value
(
(
pt
.
vector
(
dtype
=
"int64"
),
rng
.
multinomial
(
10
,
np
.
ones
(
5
)
/
5
)),
pt
.
vector
(
dtype
=
"int64"
),
rng
.
multinomial
(
10
,
np
.
ones
(
5
)
/
5
)
),
),
),
(
(
5
,
5
,
-
1
,
-
1
,
set_test_value
(
pt
.
matrix
(),
rng
.
normal
(
size
=
(
11
,
7
))
.
astype
(
config
.
floatX
)),
(
pt
.
matrix
(),
rng
.
normal
(
size
=
(
11
,
7
))
.
astype
(
config
.
floatX
)),
set_test_value
(
(
pt
.
vector
(
dtype
=
"int64"
),
rng
.
multinomial
(
7
,
np
.
ones
(
5
)
/
5
)),
pt
.
vector
(
dtype
=
"int64"
),
rng
.
multinomial
(
7
,
np
.
ones
(
5
)
/
5
)
),
),
),
(
(
5
,
5
,
-
2
,
-
2
,
set_test_value
(
pt
.
matrix
(),
rng
.
normal
(
size
=
(
11
,
7
))
.
astype
(
config
.
floatX
)),
(
pt
.
matrix
(),
rng
.
normal
(
size
=
(
11
,
7
))
.
astype
(
config
.
floatX
)),
set_test_value
(
(
pt
.
vector
(
dtype
=
"int64"
),
rng
.
multinomial
(
11
,
np
.
ones
(
5
)
/
5
)),
pt
.
vector
(
dtype
=
"int64"
),
rng
.
multinomial
(
11
,
np
.
ones
(
5
)
/
5
)
),
),
),
],
],
)
)
def
test_Split
(
n_splits
,
axis
,
values
,
sizes
):
def
test_Split
(
n_splits
,
axis
,
values
,
sizes
):
values
,
values_test
=
values
sizes
,
sizes_test
=
sizes
g
=
pt
.
split
(
values
,
sizes
,
n_splits
,
axis
=
axis
)
g
=
pt
.
split
(
values
,
sizes
,
n_splits
,
axis
=
axis
)
assert
len
(
g
)
==
n_splits
assert
len
(
g
)
==
n_splits
if
n_splits
==
0
:
if
n_splits
==
0
:
return
return
g_fg
=
FunctionGraph
(
outputs
=
[
g
]
if
n_splits
==
1
else
g
)
compare_numba_and_py
(
compare_numba_and_py
(
g_fg
,
[
values
,
sizes
],
[
g
,
i
.
tag
.
test_value
[
values_test
,
sizes_test
],
for
i
in
g_fg
.
inputs
if
not
isinstance
(
i
,
SharedVariable
|
Constant
)
],
)
)
...
@@ -349,34 +279,27 @@ def test_Split_view():
...
@@ -349,34 +279,27 @@ def test_Split_view():
"val, offset"
,
"val, offset"
,
[
[
(
(
set_test_value
(
(
pt
.
matrix
(),
np
.
arange
(
10
*
10
,
dtype
=
config
.
floatX
)
.
reshape
((
10
,
10
))),
pt
.
matrix
(),
np
.
arange
(
10
*
10
,
dtype
=
config
.
floatX
)
.
reshape
((
10
,
10
))
),
0
,
0
,
),
),
(
(
set_test_value
(
(
pt
.
matrix
(),
np
.
arange
(
10
*
10
,
dtype
=
config
.
floatX
)
.
reshape
((
10
,
10
))),
pt
.
matrix
(),
np
.
arange
(
10
*
10
,
dtype
=
config
.
floatX
)
.
reshape
((
10
,
10
))
),
-
1
,
-
1
,
),
),
(
(
set_test_value
(
pt
.
vector
(),
np
.
arange
(
10
,
dtype
=
config
.
floatX
)),
(
pt
.
vector
(),
np
.
arange
(
10
,
dtype
=
config
.
floatX
)),
0
,
0
,
),
),
],
],
)
)
def
test_ExtractDiag
(
val
,
offset
):
def
test_ExtractDiag
(
val
,
offset
):
val
,
val_test
=
val
g
=
pt
.
diag
(
val
,
offset
)
g
=
pt
.
diag
(
val
,
offset
)
g_fg
=
FunctionGraph
(
outputs
=
[
g
])
compare_numba_and_py
(
compare_numba_and_py
(
g_fg
,
[
val
],
[
g
,
i
.
tag
.
test_value
[
val_test
],
for
i
in
g_fg
.
inputs
if
not
isinstance
(
i
,
SharedVariable
|
Constant
)
],
)
)
...
@@ -407,30 +330,28 @@ def test_ExtractDiag_exhaustive(k, axis1, axis2, reverse_axis):
...
@@ -407,30 +330,28 @@ def test_ExtractDiag_exhaustive(k, axis1, axis2, reverse_axis):
@pytest.mark.parametrize
(
@pytest.mark.parametrize
(
"n, m, k, dtype"
,
"n, m, k, dtype"
,
[
[
(
set_test_value
(
pt
.
lscalar
(),
np
.
array
(
1
,
dtype
=
np
.
int64
)),
None
,
0
,
None
),
((
pt
.
lscalar
(),
np
.
array
(
1
,
dtype
=
np
.
int64
)),
None
,
0
,
None
),
(
(
set_test_value
(
pt
.
lscalar
(),
np
.
array
(
1
,
dtype
=
np
.
int64
)),
(
pt
.
lscalar
(),
np
.
array
(
1
,
dtype
=
np
.
int64
)),
set_test_value
(
pt
.
lscalar
(),
np
.
array
(
2
,
dtype
=
np
.
int64
)),
(
pt
.
lscalar
(),
np
.
array
(
2
,
dtype
=
np
.
int64
)),
0
,
0
,
"float32"
,
"float32"
,
),
),
(
(
set_test_value
(
pt
.
lscalar
(),
np
.
array
(
1
,
dtype
=
np
.
int64
)),
(
pt
.
lscalar
(),
np
.
array
(
1
,
dtype
=
np
.
int64
)),
set_test_value
(
pt
.
lscalar
(),
np
.
array
(
2
,
dtype
=
np
.
int64
)),
(
pt
.
lscalar
(),
np
.
array
(
2
,
dtype
=
np
.
int64
)),
1
,
1
,
"int64"
,
"int64"
,
),
),
],
],
)
)
def
test_Eye
(
n
,
m
,
k
,
dtype
):
def
test_Eye
(
n
,
m
,
k
,
dtype
):
n
,
n_test
=
n
m
,
m_test
=
m
if
m
is
not
None
else
(
None
,
None
)
g
=
pt
.
eye
(
n
,
m
,
k
,
dtype
=
dtype
)
g
=
pt
.
eye
(
n
,
m
,
k
,
dtype
=
dtype
)
g_fg
=
FunctionGraph
(
outputs
=
[
g
])
compare_numba_and_py
(
compare_numba_and_py
(
g_fg
,
[
n
,
m
]
if
m
is
not
None
else
[
n
],
[
g
,
i
.
tag
.
test_value
[
n_test
,
m_test
]
if
m
is
not
None
else
[
n_test
],
for
i
in
g_fg
.
inputs
if
not
isinstance
(
i
,
SharedVariable
|
Constant
)
],
)
)
tests/link/pytorch/test_basic.py
浏览文件 @
cc8c4992
...
@@ -9,10 +9,10 @@ import pytensor.tensor.basic as ptb
...
@@ -9,10 +9,10 @@ import pytensor.tensor.basic as ptb
from
pytensor.compile.builders
import
OpFromGraph
from
pytensor.compile.builders
import
OpFromGraph
from
pytensor.compile.function
import
function
from
pytensor.compile.function
import
function
from
pytensor.compile.mode
import
PYTORCH
,
Mode
from
pytensor.compile.mode
import
PYTORCH
,
Mode
from
pytensor.compile.sharedvalue
import
SharedVariable
,
shared
from
pytensor.compile.sharedvalue
import
shared
from
pytensor.configdefaults
import
config
from
pytensor.configdefaults
import
config
from
pytensor.graph
import
RewriteDatabaseQuery
from
pytensor.graph
import
RewriteDatabaseQuery
from
pytensor.graph.basic
import
Apply
from
pytensor.graph.basic
import
Apply
,
Variable
from
pytensor.graph.fg
import
FunctionGraph
from
pytensor.graph.fg
import
FunctionGraph
from
pytensor.graph.op
import
Op
from
pytensor.graph.op
import
Op
from
pytensor.ifelse
import
ifelse
from
pytensor.ifelse
import
ifelse
...
@@ -39,10 +39,10 @@ py_mode = Mode(linker="py", optimizer=None)
...
@@ -39,10 +39,10 @@ py_mode = Mode(linker="py", optimizer=None)
def
compare_pytorch_and_py
(
def
compare_pytorch_and_py
(
fgraph
:
FunctionGraph
,
graph_inputs
:
Iterable
[
Variable
],
graph_outputs
:
Variable
|
Iterable
[
Variable
],
test_inputs
:
Iterable
,
test_inputs
:
Iterable
,
assert_fn
:
Callable
|
None
=
None
,
assert_fn
:
Callable
|
None
=
None
,
must_be_device_array
:
bool
=
True
,
pytorch_mode
=
pytorch_mode
,
pytorch_mode
=
pytorch_mode
,
py_mode
=
py_mode
,
py_mode
=
py_mode
,
):
):
...
@@ -50,8 +50,10 @@ def compare_pytorch_and_py(
...
@@ -50,8 +50,10 @@ def compare_pytorch_and_py(
Parameters
Parameters
----------
----------
fgraph: FunctionGraph
graph_inputs
PyTensor function Graph object
Symbolic inputs to the graph
graph_outputs:
Symbolic outputs of the graph
test_inputs: iter
test_inputs: iter
Numerical inputs for testing the function graph
Numerical inputs for testing the function graph
assert_fn: func, opt
assert_fn: func, opt
...
@@ -63,24 +65,22 @@ def compare_pytorch_and_py(
...
@@ -63,24 +65,22 @@ def compare_pytorch_and_py(
if
assert_fn
is
None
:
if
assert_fn
is
None
:
assert_fn
=
partial
(
np
.
testing
.
assert_allclose
)
assert_fn
=
partial
(
np
.
testing
.
assert_allclose
)
fn_inputs
=
[
i
for
i
in
fgraph
.
inputs
if
not
isinstance
(
i
,
SharedVariable
)]
if
any
(
inp
.
owner
is
not
None
for
inp
in
graph_inputs
):
raise
ValueError
(
"Inputs must be root variables"
)
pytensor_torch_fn
=
function
(
fn_inputs
,
fgraph
.
outputs
,
mode
=
pytorch_mode
)
pytensor_torch_fn
=
function
(
graph_inputs
,
graph_
outputs
,
mode
=
pytorch_mode
)
pytorch_res
=
pytensor_torch_fn
(
*
test_inputs
)
pytorch_res
=
pytensor_torch_fn
(
*
test_inputs
)
if
isinstance
(
pytorch_res
,
list
):
pytensor_py_fn
=
function
(
graph_inputs
,
graph_outputs
,
mode
=
py_mode
)
assert
all
(
isinstance
(
res
,
np
.
ndarray
)
for
res
in
pytorch_res
)
else
:
assert
isinstance
(
pytorch_res
,
np
.
ndarray
)
pytensor_py_fn
=
function
(
fn_inputs
,
fgraph
.
outputs
,
mode
=
py_mode
)
py_res
=
pytensor_py_fn
(
*
test_inputs
)
py_res
=
pytensor_py_fn
(
*
test_inputs
)
if
len
(
fgraph
.
outputs
)
>
1
:
if
isinstance
(
graph_outputs
,
list
|
tuple
)
:
for
pytorch_res_i
,
py_res_i
in
zip
(
pytorch_res
,
py_res
,
strict
=
True
):
for
pytorch_res_i
,
py_res_i
in
zip
(
pytorch_res
,
py_res
,
strict
=
True
):
assert
not
isinstance
(
pytorch_res_i
,
torch
.
Tensor
)
assert_fn
(
pytorch_res_i
,
py_res_i
)
assert_fn
(
pytorch_res_i
,
py_res_i
)
else
:
else
:
assert_fn
(
pytorch_res
[
0
],
py_res
[
0
])
assert
not
isinstance
(
pytorch_res
,
torch
.
Tensor
)
assert_fn
(
pytorch_res
,
py_res
)
return
pytensor_torch_fn
,
pytorch_res
return
pytensor_torch_fn
,
pytorch_res
...
@@ -231,7 +231,8 @@ def test_alloc_and_empty():
...
@@ -231,7 +231,8 @@ def test_alloc_and_empty():
v
=
vector
(
"v"
,
shape
=
(
3
,),
dtype
=
"float64"
)
v
=
vector
(
"v"
,
shape
=
(
3
,),
dtype
=
"float64"
)
out
=
alloc
(
v
,
dim0
,
dim1
,
3
)
out
=
alloc
(
v
,
dim0
,
dim1
,
3
)
compare_pytorch_and_py
(
compare_pytorch_and_py
(
FunctionGraph
([
v
,
dim1
],
[
out
]),
[
v
,
dim1
],
[
out
],
[
np
.
array
([
1
,
2
,
3
]),
np
.
array
(
7
)],
[
np
.
array
([
1
,
2
,
3
]),
np
.
array
(
7
)],
)
)
...
@@ -244,7 +245,8 @@ def test_arange():
...
@@ -244,7 +245,8 @@ def test_arange():
out
=
arange
(
start
,
stop
,
step
,
dtype
=
"int16"
)
out
=
arange
(
start
,
stop
,
step
,
dtype
=
"int16"
)
compare_pytorch_and_py
(
compare_pytorch_and_py
(
FunctionGraph
([
start
,
stop
,
step
],
[
out
]),
[
start
,
stop
,
step
],
[
out
],
[
np
.
array
(
1
),
np
.
array
(
10
),
np
.
array
(
2
)],
[
np
.
array
(
1
),
np
.
array
(
10
),
np
.
array
(
2
)],
)
)
...
@@ -254,16 +256,18 @@ def test_pytorch_Join():
...
@@ -254,16 +256,18 @@ def test_pytorch_Join():
b
=
matrix
(
"b"
)
b
=
matrix
(
"b"
)
x
=
ptb
.
join
(
0
,
a
,
b
)
x
=
ptb
.
join
(
0
,
a
,
b
)
x_fg
=
FunctionGraph
([
a
,
b
],
[
x
])
compare_pytorch_and_py
(
compare_pytorch_and_py
(
x_fg
,
[
a
,
b
],
[
x
],
[
[
np
.
c_
[[
1.0
,
2.0
,
3.0
]]
.
astype
(
config
.
floatX
),
np
.
c_
[[
1.0
,
2.0
,
3.0
]]
.
astype
(
config
.
floatX
),
np
.
c_
[[
4.0
,
5.0
,
6.0
]]
.
astype
(
config
.
floatX
),
np
.
c_
[[
4.0
,
5.0
,
6.0
]]
.
astype
(
config
.
floatX
),
],
],
)
)
compare_pytorch_and_py
(
compare_pytorch_and_py
(
x_fg
,
[
a
,
b
],
[
x
],
[
[
np
.
c_
[[
1.0
,
2.0
,
3.0
]]
.
astype
(
config
.
floatX
),
np
.
c_
[[
1.0
,
2.0
,
3.0
]]
.
astype
(
config
.
floatX
),
np
.
c_
[[
4.0
,
5.0
]]
.
astype
(
config
.
floatX
),
np
.
c_
[[
4.0
,
5.0
]]
.
astype
(
config
.
floatX
),
...
@@ -271,16 +275,18 @@ def test_pytorch_Join():
...
@@ -271,16 +275,18 @@ def test_pytorch_Join():
)
)
x
=
ptb
.
join
(
1
,
a
,
b
)
x
=
ptb
.
join
(
1
,
a
,
b
)
x_fg
=
FunctionGraph
([
a
,
b
],
[
x
])
compare_pytorch_and_py
(
compare_pytorch_and_py
(
x_fg
,
[
a
,
b
],
[
x
],
[
[
np
.
c_
[[
1.0
,
2.0
,
3.0
]]
.
astype
(
config
.
floatX
),
np
.
c_
[[
1.0
,
2.0
,
3.0
]]
.
astype
(
config
.
floatX
),
np
.
c_
[[
4.0
,
5.0
,
6.0
]]
.
astype
(
config
.
floatX
),
np
.
c_
[[
4.0
,
5.0
,
6.0
]]
.
astype
(
config
.
floatX
),
],
],
)
)
compare_pytorch_and_py
(
compare_pytorch_and_py
(
x_fg
,
[
a
,
b
],
[
x
],
[
[
np
.
c_
[[
1.0
,
2.0
],
[
3.0
,
4.0
]]
.
astype
(
config
.
floatX
),
np
.
c_
[[
1.0
,
2.0
],
[
3.0
,
4.0
]]
.
astype
(
config
.
floatX
),
np
.
c_
[[
5.0
,
6.0
]]
.
astype
(
config
.
floatX
),
np
.
c_
[[
5.0
,
6.0
]]
.
astype
(
config
.
floatX
),
...
@@ -309,9 +315,8 @@ def test_eye(dtype):
...
@@ -309,9 +315,8 @@ def test_eye(dtype):
def
test_pytorch_MakeVector
():
def
test_pytorch_MakeVector
():
x
=
ptb
.
make_vector
(
1
,
2
,
3
)
x
=
ptb
.
make_vector
(
1
,
2
,
3
)
x_fg
=
FunctionGraph
([],
[
x
])
compare_pytorch_and_py
(
x_fg
,
[])
compare_pytorch_and_py
(
[],
[
x
]
,
[])
def
test_pytorch_ifelse
():
def
test_pytorch_ifelse
():
...
@@ -320,15 +325,13 @@ def test_pytorch_ifelse():
...
@@ -320,15 +325,13 @@ def test_pytorch_ifelse():
a
=
scalar
(
"a"
)
a
=
scalar
(
"a"
)
x
=
ifelse
(
a
<
0.5
,
tuple
(
np
.
r_
[
p1_vals
,
p2_vals
]),
tuple
(
np
.
r_
[
p2_vals
,
p1_vals
]))
x
=
ifelse
(
a
<
0.5
,
tuple
(
np
.
r_
[
p1_vals
,
p2_vals
]),
tuple
(
np
.
r_
[
p2_vals
,
p1_vals
]))
x_fg
=
FunctionGraph
([
a
],
x
)
compare_pytorch_and_py
(
x_fg
,
np
.
array
([
0.2
],
dtype
=
config
.
floatX
))
compare_pytorch_and_py
(
[
a
],
x
,
np
.
array
([
0.2
],
dtype
=
config
.
floatX
))
a
=
scalar
(
"a"
)
a
=
scalar
(
"a"
)
x
=
ifelse
(
a
<
0.4
,
tuple
(
np
.
r_
[
p1_vals
,
p2_vals
]),
tuple
(
np
.
r_
[
p2_vals
,
p1_vals
]))
x
=
ifelse
(
a
<
0.4
,
tuple
(
np
.
r_
[
p1_vals
,
p2_vals
]),
tuple
(
np
.
r_
[
p2_vals
,
p1_vals
]))
x_fg
=
FunctionGraph
([
a
],
x
)
compare_pytorch_and_py
(
x_fg
,
np
.
array
([
0.5
],
dtype
=
config
.
floatX
))
compare_pytorch_and_py
(
[
a
],
x
,
np
.
array
([
0.5
],
dtype
=
config
.
floatX
))
def
test_pytorch_OpFromGraph
():
def
test_pytorch_OpFromGraph
():
...
@@ -343,8 +346,7 @@ def test_pytorch_OpFromGraph():
...
@@ -343,8 +346,7 @@ def test_pytorch_OpFromGraph():
yv
=
np
.
ones
((
2
,
2
),
dtype
=
config
.
floatX
)
*
3
yv
=
np
.
ones
((
2
,
2
),
dtype
=
config
.
floatX
)
*
3
zv
=
np
.
ones
((
2
,
2
),
dtype
=
config
.
floatX
)
*
5
zv
=
np
.
ones
((
2
,
2
),
dtype
=
config
.
floatX
)
*
5
f
=
FunctionGraph
([
x
,
y
,
z
],
[
out
])
compare_pytorch_and_py
([
x
,
y
,
z
],
[
out
],
[
xv
,
yv
,
zv
])
compare_pytorch_and_py
(
f
,
[
xv
,
yv
,
zv
])
def
test_pytorch_link_references
():
def
test_pytorch_link_references
():
...
@@ -380,15 +382,13 @@ def test_pytorch_link_references():
...
@@ -380,15 +382,13 @@ def test_pytorch_link_references():
def
test_pytorch_scipy
():
def
test_pytorch_scipy
():
x
=
vector
(
"a"
,
shape
=
(
3
,))
x
=
vector
(
"a"
,
shape
=
(
3
,))
out
=
expit
(
x
)
out
=
expit
(
x
)
f
=
FunctionGraph
([
x
],
[
out
])
compare_pytorch_and_py
([
x
],
[
out
],
[
np
.
random
.
rand
(
3
)])
compare_pytorch_and_py
(
f
,
[
np
.
random
.
rand
(
3
)])
def
test_pytorch_softplus
():
def
test_pytorch_softplus
():
x
=
vector
(
"a"
,
shape
=
(
3
,))
x
=
vector
(
"a"
,
shape
=
(
3
,))
out
=
softplus
(
x
)
out
=
softplus
(
x
)
f
=
FunctionGraph
([
x
],
[
out
])
compare_pytorch_and_py
([
x
],
[
out
],
[
np
.
random
.
rand
(
3
)])
compare_pytorch_and_py
(
f
,
[
np
.
random
.
rand
(
3
)])
def
test_ScalarLoop
():
def
test_ScalarLoop
():
...
@@ -436,13 +436,15 @@ def test_ScalarLoop_Elemwise_single_carries():
...
@@ -436,13 +436,15 @@ def test_ScalarLoop_Elemwise_single_carries():
x0
=
pt
.
vector
(
"x0"
,
dtype
=
"float32"
)
x0
=
pt
.
vector
(
"x0"
,
dtype
=
"float32"
)
state
,
done
=
op
(
n_steps
,
x0
)
state
,
done
=
op
(
n_steps
,
x0
)
f
=
FunctionGraph
([
n_steps
,
x0
],
[
state
,
done
])
args
=
[
args
=
[
np
.
array
(
10
)
.
astype
(
"int32"
),
np
.
array
(
10
)
.
astype
(
"int32"
),
np
.
arange
(
0
,
5
)
.
astype
(
"float32"
),
np
.
arange
(
0
,
5
)
.
astype
(
"float32"
),
]
]
compare_pytorch_and_py
(
compare_pytorch_and_py
(
f
,
args
,
assert_fn
=
partial
(
np
.
testing
.
assert_allclose
,
rtol
=
1e-6
)
[
n_steps
,
x0
],
[
state
,
done
],
args
,
assert_fn
=
partial
(
np
.
testing
.
assert_allclose
,
rtol
=
1e-6
),
)
)
...
@@ -462,14 +464,16 @@ def test_ScalarLoop_Elemwise_multi_carries():
...
@@ -462,14 +464,16 @@ def test_ScalarLoop_Elemwise_multi_carries():
x1
=
pt
.
tensor
(
"c0"
,
dtype
=
"float32"
,
shape
=
(
7
,
3
,
1
))
x1
=
pt
.
tensor
(
"c0"
,
dtype
=
"float32"
,
shape
=
(
7
,
3
,
1
))
*
states
,
done
=
op
(
n_steps
,
x0
,
x1
)
*
states
,
done
=
op
(
n_steps
,
x0
,
x1
)
f
=
FunctionGraph
([
n_steps
,
x0
,
x1
],
[
*
states
,
done
])
args
=
[
args
=
[
np
.
array
(
10
)
.
astype
(
"int32"
),
np
.
array
(
10
)
.
astype
(
"int32"
),
np
.
arange
(
0
,
5
)
.
astype
(
"float32"
),
np
.
arange
(
0
,
5
)
.
astype
(
"float32"
),
np
.
random
.
rand
(
7
,
3
,
1
)
.
astype
(
"float32"
),
np
.
random
.
rand
(
7
,
3
,
1
)
.
astype
(
"float32"
),
]
]
compare_pytorch_and_py
(
compare_pytorch_and_py
(
f
,
args
,
assert_fn
=
partial
(
np
.
testing
.
assert_allclose
,
rtol
=
1e-6
)
[
n_steps
,
x0
,
x1
],
[
*
states
,
done
],
args
,
assert_fn
=
partial
(
np
.
testing
.
assert_allclose
,
rtol
=
1e-6
),
)
)
...
@@ -518,6 +522,5 @@ def test_Split(n_splits, axis, values, sizes):
...
@@ -518,6 +522,5 @@ def test_Split(n_splits, axis, values, sizes):
assert
len
(
g
)
==
n_splits
assert
len
(
g
)
==
n_splits
if
n_splits
==
0
:
if
n_splits
==
0
:
return
return
g_fg
=
FunctionGraph
(
inputs
=
[
i
,
s
],
outputs
=
[
g
]
if
n_splits
==
1
else
g
)
compare_pytorch_and_py
(
g_f
g
,
[
values
,
sizes
])
compare_pytorch_and_py
(
[
i
,
s
],
g
,
[
values
,
sizes
])
tests/link/pytorch/test_blas.py
浏览文件 @
cc8c4992
...
@@ -2,7 +2,6 @@ import numpy as np
...
@@ -2,7 +2,6 @@ import numpy as np
import
pytest
import
pytest
from
pytensor.configdefaults
import
config
from
pytensor.configdefaults
import
config
from
pytensor.graph.fg
import
FunctionGraph
from
pytensor.tensor
import
blas
as
pt_blas
from
pytensor.tensor
import
blas
as
pt_blas
from
pytensor.tensor.type
import
tensor3
from
pytensor.tensor.type
import
tensor3
from
tests.link.pytorch.test_basic
import
compare_pytorch_and_py
from
tests.link.pytorch.test_basic
import
compare_pytorch_and_py
...
@@ -15,8 +14,8 @@ def test_pytorch_BatchedDot():
...
@@ -15,8 +14,8 @@ def test_pytorch_BatchedDot():
b
=
tensor3
(
"b"
)
b
=
tensor3
(
"b"
)
b_test
=
np
.
linspace
(
1
,
-
1
,
10
*
3
*
2
)
.
astype
(
config
.
floatX
)
.
reshape
((
10
,
3
,
2
))
b_test
=
np
.
linspace
(
1
,
-
1
,
10
*
3
*
2
)
.
astype
(
config
.
floatX
)
.
reshape
((
10
,
3
,
2
))
out
=
pt_blas
.
BatchedDot
()(
a
,
b
)
out
=
pt_blas
.
BatchedDot
()(
a
,
b
)
fgraph
=
FunctionGraph
([
a
,
b
],
[
out
])
pytensor_pytorch_fn
,
_
=
compare_pytorch_and_py
(
fgraph
,
[
a_test
,
b_test
])
pytensor_pytorch_fn
,
_
=
compare_pytorch_and_py
(
[
a
,
b
],
[
out
]
,
[
a_test
,
b_test
])
# A dimension mismatch should raise a TypeError for compatibility
# A dimension mismatch should raise a TypeError for compatibility
inputs
=
[
a_test
[:
-
1
],
b_test
]
inputs
=
[
a_test
[:
-
1
],
b_test
]
...
...
tests/link/pytorch/test_elemwise.py
浏览文件 @
cc8c4992
...
@@ -5,7 +5,6 @@ import pytensor
...
@@ -5,7 +5,6 @@ import pytensor
import
pytensor.tensor
as
pt
import
pytensor.tensor
as
pt
import
pytensor.tensor.math
as
ptm
import
pytensor.tensor.math
as
ptm
from
pytensor.configdefaults
import
config
from
pytensor.configdefaults
import
config
from
pytensor.graph.fg
import
FunctionGraph
from
pytensor.scalar.basic
import
ScalarOp
,
get_scalar_type
from
pytensor.scalar.basic
import
ScalarOp
,
get_scalar_type
from
pytensor.tensor.elemwise
import
Elemwise
from
pytensor.tensor.elemwise
import
Elemwise
from
pytensor.tensor.special
import
SoftmaxGrad
,
log_softmax
,
softmax
from
pytensor.tensor.special
import
SoftmaxGrad
,
log_softmax
,
softmax
...
@@ -20,17 +19,23 @@ def test_pytorch_Dimshuffle():
...
@@ -20,17 +19,23 @@ def test_pytorch_Dimshuffle():
a_pt
=
matrix
(
"a"
)
a_pt
=
matrix
(
"a"
)
x
=
a_pt
.
T
x
=
a_pt
.
T
x_fg
=
FunctionGraph
([
a_pt
],
[
x
])
compare_pytorch_and_py
(
x_fg
,
[
np
.
c_
[[
1.0
,
2.0
],
[
3.0
,
4.0
]]
.
astype
(
config
.
floatX
)])
compare_pytorch_and_py
(
[
a_pt
],
[
x
],
[
np
.
c_
[[
1.0
,
2.0
],
[
3.0
,
4.0
]]
.
astype
(
config
.
floatX
)]
)
x
=
a_pt
.
dimshuffle
([
0
,
1
,
"x"
])
x
=
a_pt
.
dimshuffle
([
0
,
1
,
"x"
])
x_fg
=
FunctionGraph
([
a_pt
],
[
x
])
compare_pytorch_and_py
(
x_fg
,
[
np
.
c_
[[
1.0
,
2.0
],
[
3.0
,
4.0
]]
.
astype
(
config
.
floatX
)])
compare_pytorch_and_py
(
[
a_pt
],
[
x
],
[
np
.
c_
[[
1.0
,
2.0
],
[
3.0
,
4.0
]]
.
astype
(
config
.
floatX
)]
)
a_pt
=
tensor
(
dtype
=
config
.
floatX
,
shape
=
(
None
,
1
))
a_pt
=
tensor
(
dtype
=
config
.
floatX
,
shape
=
(
None
,
1
))
x
=
a_pt
.
dimshuffle
((
0
,))
x
=
a_pt
.
dimshuffle
((
0
,))
x_fg
=
FunctionGraph
([
a_pt
],
[
x
])
compare_pytorch_and_py
(
x_fg
,
[
np
.
c_
[[
1.0
,
2.0
,
3.0
,
4.0
]]
.
astype
(
config
.
floatX
)])
compare_pytorch_and_py
(
[
a_pt
],
[
x
],
[
np
.
c_
[[
1.0
,
2.0
,
3.0
,
4.0
]]
.
astype
(
config
.
floatX
)]
)
def
test_multiple_input_output
():
def
test_multiple_input_output
():
...
@@ -38,24 +43,21 @@ def test_multiple_input_output():
...
@@ -38,24 +43,21 @@ def test_multiple_input_output():
y
=
vector
(
"y"
)
y
=
vector
(
"y"
)
out
=
pt
.
mul
(
x
,
y
)
out
=
pt
.
mul
(
x
,
y
)
fg
=
FunctionGraph
(
outputs
=
[
out
],
clone
=
False
)
compare_pytorch_and_py
([
x
,
y
],
[
out
],
[[
1.5
],
[
2.5
]])
compare_pytorch_and_py
(
fg
,
[[
1.5
],
[
2.5
]])
x
=
vector
(
"x"
)
x
=
vector
(
"x"
)
y
=
vector
(
"y"
)
y
=
vector
(
"y"
)
div
=
pt
.
int_div
(
x
,
y
)
div
=
pt
.
int_div
(
x
,
y
)
pt_sum
=
pt
.
add
(
y
,
x
)
pt_sum
=
pt
.
add
(
y
,
x
)
fg
=
FunctionGraph
(
outputs
=
[
div
,
pt_sum
],
clone
=
False
)
compare_pytorch_and_py
([
x
,
y
],
[
div
,
pt_sum
],
[[
1.5
],
[
2.5
]])
compare_pytorch_and_py
(
fg
,
[[
1.5
],
[
2.5
]])
def
test_pytorch_elemwise
():
def
test_pytorch_elemwise
():
x
=
pt
.
vector
(
"x"
)
x
=
pt
.
vector
(
"x"
)
out
=
pt
.
log
(
1
-
x
)
out
=
pt
.
log
(
1
-
x
)
fg
=
FunctionGraph
([
x
],
[
out
])
compare_pytorch_and_py
([
x
],
[
out
],
[[
0.9
,
0.9
]])
compare_pytorch_and_py
(
fg
,
[[
0.9
,
0.9
]])
@pytest.mark.parametrize
(
"fn"
,
[
ptm
.
sum
,
ptm
.
prod
,
ptm
.
max
,
ptm
.
min
])
@pytest.mark.parametrize
(
"fn"
,
[
ptm
.
sum
,
ptm
.
prod
,
ptm
.
max
,
ptm
.
min
])
...
@@ -81,9 +83,8 @@ def test_pytorch_careduce(fn, axis):
...
@@ -81,9 +83,8 @@ def test_pytorch_careduce(fn, axis):
)
.
astype
(
config
.
floatX
)
)
.
astype
(
config
.
floatX
)
x
=
fn
(
a_pt
,
axis
=
axis
)
x
=
fn
(
a_pt
,
axis
=
axis
)
x_fg
=
FunctionGraph
([
a_pt
],
[
x
])
compare_pytorch_and_py
(
x_fg
,
[
test_value
])
compare_pytorch_and_py
(
[
a_pt
],
[
x
]
,
[
test_value
])
@pytest.mark.parametrize
(
"fn"
,
[
ptm
.
any
,
ptm
.
all
])
@pytest.mark.parametrize
(
"fn"
,
[
ptm
.
any
,
ptm
.
all
])
...
@@ -93,9 +94,8 @@ def test_pytorch_any_all(fn, axis):
...
@@ -93,9 +94,8 @@ def test_pytorch_any_all(fn, axis):
test_value
=
np
.
array
([[
True
,
False
,
True
],
[
False
,
True
,
True
]])
test_value
=
np
.
array
([[
True
,
False
,
True
],
[
False
,
True
,
True
]])
x
=
fn
(
a_pt
,
axis
=
axis
)
x
=
fn
(
a_pt
,
axis
=
axis
)
x_fg
=
FunctionGraph
([
a_pt
],
[
x
])
compare_pytorch_and_py
(
x_fg
,
[
test_value
])
compare_pytorch_and_py
(
[
a_pt
],
[
x
]
,
[
test_value
])
@pytest.mark.parametrize
(
"dtype"
,
[
"float64"
,
"int64"
])
@pytest.mark.parametrize
(
"dtype"
,
[
"float64"
,
"int64"
])
...
@@ -103,7 +103,6 @@ def test_pytorch_any_all(fn, axis):
...
@@ -103,7 +103,6 @@ def test_pytorch_any_all(fn, axis):
def
test_softmax
(
axis
,
dtype
):
def
test_softmax
(
axis
,
dtype
):
x
=
matrix
(
"x"
,
dtype
=
dtype
)
x
=
matrix
(
"x"
,
dtype
=
dtype
)
out
=
softmax
(
x
,
axis
=
axis
)
out
=
softmax
(
x
,
axis
=
axis
)
fgraph
=
FunctionGraph
([
x
],
[
out
])
test_input
=
np
.
arange
(
6
,
dtype
=
config
.
floatX
)
.
reshape
(
2
,
3
)
test_input
=
np
.
arange
(
6
,
dtype
=
config
.
floatX
)
.
reshape
(
2
,
3
)
if
dtype
==
"int64"
:
if
dtype
==
"int64"
:
...
@@ -111,9 +110,9 @@ def test_softmax(axis, dtype):
...
@@ -111,9 +110,9 @@ def test_softmax(axis, dtype):
NotImplementedError
,
NotImplementedError
,
match
=
"Pytorch Softmax is not currently implemented for non-float types."
,
match
=
"Pytorch Softmax is not currently implemented for non-float types."
,
):
):
compare_pytorch_and_py
(
fgraph
,
[
test_input
])
compare_pytorch_and_py
(
[
x
],
[
out
]
,
[
test_input
])
else
:
else
:
compare_pytorch_and_py
(
fgraph
,
[
test_input
])
compare_pytorch_and_py
(
[
x
],
[
out
]
,
[
test_input
])
@pytest.mark.parametrize
(
"dtype"
,
[
"float64"
,
"int64"
])
@pytest.mark.parametrize
(
"dtype"
,
[
"float64"
,
"int64"
])
...
@@ -121,7 +120,6 @@ def test_softmax(axis, dtype):
...
@@ -121,7 +120,6 @@ def test_softmax(axis, dtype):
def
test_logsoftmax
(
axis
,
dtype
):
def
test_logsoftmax
(
axis
,
dtype
):
x
=
matrix
(
"x"
,
dtype
=
dtype
)
x
=
matrix
(
"x"
,
dtype
=
dtype
)
out
=
log_softmax
(
x
,
axis
=
axis
)
out
=
log_softmax
(
x
,
axis
=
axis
)
fgraph
=
FunctionGraph
([
x
],
[
out
])
test_input
=
np
.
arange
(
6
,
dtype
=
config
.
floatX
)
.
reshape
(
2
,
3
)
test_input
=
np
.
arange
(
6
,
dtype
=
config
.
floatX
)
.
reshape
(
2
,
3
)
if
dtype
==
"int64"
:
if
dtype
==
"int64"
:
...
@@ -129,9 +127,9 @@ def test_logsoftmax(axis, dtype):
...
@@ -129,9 +127,9 @@ def test_logsoftmax(axis, dtype):
NotImplementedError
,
NotImplementedError
,
match
=
"Pytorch LogSoftmax is not currently implemented for non-float types."
,
match
=
"Pytorch LogSoftmax is not currently implemented for non-float types."
,
):
):
compare_pytorch_and_py
(
fgraph
,
[
test_input
])
compare_pytorch_and_py
(
[
x
],
[
out
]
,
[
test_input
])
else
:
else
:
compare_pytorch_and_py
(
fgraph
,
[
test_input
])
compare_pytorch_and_py
(
[
x
],
[
out
]
,
[
test_input
])
@pytest.mark.parametrize
(
"axis"
,
[
None
,
0
,
1
])
@pytest.mark.parametrize
(
"axis"
,
[
None
,
0
,
1
])
...
@@ -141,16 +139,14 @@ def test_softmax_grad(axis):
...
@@ -141,16 +139,14 @@ def test_softmax_grad(axis):
sm
=
matrix
(
"sm"
)
sm
=
matrix
(
"sm"
)
sm_value
=
np
.
arange
(
6
,
dtype
=
config
.
floatX
)
.
reshape
(
2
,
3
)
sm_value
=
np
.
arange
(
6
,
dtype
=
config
.
floatX
)
.
reshape
(
2
,
3
)
out
=
SoftmaxGrad
(
axis
=
axis
)(
dy
,
sm
)
out
=
SoftmaxGrad
(
axis
=
axis
)(
dy
,
sm
)
fgraph
=
FunctionGraph
([
dy
,
sm
],
[
out
])
compare_pytorch_and_py
([
dy
,
sm
],
[
out
],
[
dy_value
,
sm_value
])
compare_pytorch_and_py
(
fgraph
,
[
dy_value
,
sm_value
])
def
test_cast
():
def
test_cast
():
x
=
matrix
(
"x"
,
dtype
=
"float32"
)
x
=
matrix
(
"x"
,
dtype
=
"float32"
)
out
=
pt
.
cast
(
x
,
"int32"
)
out
=
pt
.
cast
(
x
,
"int32"
)
fgraph
=
FunctionGraph
([
x
],
[
out
])
_
,
[
res
]
=
compare_pytorch_and_py
(
_
,
[
res
]
=
compare_pytorch_and_py
(
fgraph
,
[
np
.
arange
(
6
,
dtype
=
"float32"
)
.
reshape
(
2
,
3
)]
[
x
],
[
out
]
,
[
np
.
arange
(
6
,
dtype
=
"float32"
)
.
reshape
(
2
,
3
)]
)
)
assert
res
.
dtype
==
np
.
int32
assert
res
.
dtype
==
np
.
int32
...
...
tests/link/pytorch/test_extra_ops.py
浏览文件 @
cc8c4992
...
@@ -2,7 +2,6 @@ import numpy as np
...
@@ -2,7 +2,6 @@ import numpy as np
import
pytest
import
pytest
import
pytensor.tensor
as
pt
import
pytensor.tensor
as
pt
from
pytensor.graph
import
FunctionGraph
from
tests.link.pytorch.test_basic
import
compare_pytorch_and_py
from
tests.link.pytorch.test_basic
import
compare_pytorch_and_py
...
@@ -31,16 +30,14 @@ def test_pytorch_CumOp(axis, dtype):
...
@@ -31,16 +30,14 @@ def test_pytorch_CumOp(axis, dtype):
out
=
pt
.
cumprod
(
a
,
axis
=
axis
)
out
=
pt
.
cumprod
(
a
,
axis
=
axis
)
else
:
else
:
out
=
pt
.
cumsum
(
a
,
axis
=
axis
)
out
=
pt
.
cumsum
(
a
,
axis
=
axis
)
# Create a PyTensor `FunctionGraph`
fgraph
=
FunctionGraph
([
a
],
[
out
])
# Pass the
graph and in
puts to the testing function
# Pass the
inputs and out
puts to the testing function
compare_pytorch_and_py
(
fgraph
,
[
test_value
])
compare_pytorch_and_py
(
[
a
],
[
out
]
,
[
test_value
])
# For the second mode of CumOp
# For the second mode of CumOp
out
=
pt
.
cumprod
(
a
,
axis
=
axis
)
out
=
pt
.
cumprod
(
a
,
axis
=
axis
)
fgraph
=
FunctionGraph
([
a
],
[
out
])
compare_pytorch_and_py
(
fgraph
,
[
test_value
])
compare_pytorch_and_py
(
[
a
],
[
out
]
,
[
test_value
])
@pytest.mark.parametrize
(
"axis, repeats"
,
[(
0
,
(
1
,
2
,
3
)),
(
1
,
(
3
,
3
)),
(
None
,
3
)])
@pytest.mark.parametrize
(
"axis, repeats"
,
[(
0
,
(
1
,
2
,
3
)),
(
1
,
(
3
,
3
)),
(
None
,
3
)])
...
@@ -50,8 +47,8 @@ def test_pytorch_Repeat(axis, repeats):
...
@@ -50,8 +47,8 @@ def test_pytorch_Repeat(axis, repeats):
test_value
=
np
.
arange
(
6
,
dtype
=
"float64"
)
.
reshape
((
3
,
2
))
test_value
=
np
.
arange
(
6
,
dtype
=
"float64"
)
.
reshape
((
3
,
2
))
out
=
pt
.
repeat
(
a
,
repeats
,
axis
=
axis
)
out
=
pt
.
repeat
(
a
,
repeats
,
axis
=
axis
)
fgraph
=
FunctionGraph
([
a
],
[
out
])
compare_pytorch_and_py
(
fgraph
,
[
test_value
])
compare_pytorch_and_py
(
[
a
],
[
out
]
,
[
test_value
])
@pytest.mark.parametrize
(
"axis"
,
[
None
,
0
,
1
])
@pytest.mark.parametrize
(
"axis"
,
[
None
,
0
,
1
])
...
@@ -63,8 +60,8 @@ def test_pytorch_Unique_axis(axis):
...
@@ -63,8 +60,8 @@ def test_pytorch_Unique_axis(axis):
)
)
out
=
pt
.
unique
(
a
,
axis
=
axis
)
out
=
pt
.
unique
(
a
,
axis
=
axis
)
fgraph
=
FunctionGraph
([
a
],
[
out
])
compare_pytorch_and_py
(
fgraph
,
[
test_value
])
compare_pytorch_and_py
(
[
a
],
[
out
]
,
[
test_value
])
@pytest.mark.parametrize
(
"return_inverse"
,
[
False
,
True
])
@pytest.mark.parametrize
(
"return_inverse"
,
[
False
,
True
])
...
@@ -86,5 +83,7 @@ def test_pytorch_Unique_params(return_index, return_inverse, return_counts):
...
@@ -86,5 +83,7 @@ def test_pytorch_Unique_params(return_index, return_inverse, return_counts):
return_counts
=
return_counts
,
return_counts
=
return_counts
,
axis
=
0
,
axis
=
0
,
)
)
fgraph
=
FunctionGraph
([
a
],
[
out
[
0
]
if
isinstance
(
out
,
list
)
else
out
])
compare_pytorch_and_py
(
fgraph
,
[
test_value
])
compare_pytorch_and_py
(
[
a
],
[
out
[
0
]
if
isinstance
(
out
,
list
)
else
out
],
[
test_value
]
)
tests/link/pytorch/test_math.py
浏览文件 @
cc8c4992
import
numpy
as
np
import
numpy
as
np
from
pytensor.configdefaults
import
config
from
pytensor.configdefaults
import
config
from
pytensor.graph.fg
import
FunctionGraph
from
pytensor.tensor.type
import
matrix
,
scalar
,
vector
from
pytensor.tensor.type
import
matrix
,
scalar
,
vector
from
tests.link.pytorch.test_basic
import
compare_pytorch_and_py
from
tests.link.pytorch.test_basic
import
compare_pytorch_and_py
...
@@ -20,10 +19,12 @@ def test_pytorch_dot():
...
@@ -20,10 +19,12 @@ def test_pytorch_dot():
# 2D * 2D
# 2D * 2D
out
=
A
.
dot
(
A
*
alpha
)
+
beta
*
A
out
=
A
.
dot
(
A
*
alpha
)
+
beta
*
A
fgraph
=
FunctionGraph
([
A
,
alpha
,
beta
],
[
out
])
compare_pytorch_and_py
(
fgraph
,
[
A_test
,
alpha_test
,
beta_test
])
compare_pytorch_and_py
(
[
A
,
alpha
,
beta
],
[
out
]
,
[
A_test
,
alpha_test
,
beta_test
])
# 1D * 2D and 1D * 1D
# 1D * 2D and 1D * 1D
out
=
y
.
dot
(
alpha
*
A
)
.
dot
(
x
)
+
beta
*
y
out
=
y
.
dot
(
alpha
*
A
)
.
dot
(
x
)
+
beta
*
y
fgraph
=
FunctionGraph
([
y
,
x
,
A
,
alpha
,
beta
],
[
out
])
compare_pytorch_and_py
(
fgraph
,
[
y_test
,
x_test
,
A_test
,
alpha_test
,
beta_test
])
compare_pytorch_and_py
(
[
y
,
x
,
A
,
alpha
,
beta
],
[
out
],
[
y_test
,
x_test
,
A_test
,
alpha_test
,
beta_test
]
)
tests/link/pytorch/test_nlinalg.py
浏览文件 @
cc8c4992
from
collections.abc
import
Sequence
import
numpy
as
np
import
numpy
as
np
import
pytest
import
pytest
from
pytensor.compile.function
import
function
from
pytensor.compile.function
import
function
from
pytensor.configdefaults
import
config
from
pytensor.configdefaults
import
config
from
pytensor.graph.fg
import
FunctionGraph
from
pytensor.tensor
import
nlinalg
as
pt_nla
from
pytensor.tensor
import
nlinalg
as
pt_nla
from
pytensor.tensor.type
import
matrix
from
pytensor.tensor.type
import
matrix
from
tests.link.pytorch.test_basic
import
compare_pytorch_and_py
from
tests.link.pytorch.test_basic
import
compare_pytorch_and_py
...
@@ -29,13 +26,12 @@ def matrix_test():
...
@@ -29,13 +26,12 @@ def matrix_test():
def
test_lin_alg_no_params
(
func
,
matrix_test
):
def
test_lin_alg_no_params
(
func
,
matrix_test
):
x
,
test_value
=
matrix_test
x
,
test_value
=
matrix_test
out
=
func
(
x
)
outs
=
func
(
x
)
out_fg
=
FunctionGraph
([
x
],
out
if
isinstance
(
out
,
Sequence
)
else
[
out
])
def
assert_fn
(
x
,
y
):
def
assert_fn
(
x
,
y
):
np
.
testing
.
assert_allclose
(
x
,
y
,
rtol
=
1e-3
)
np
.
testing
.
assert_allclose
(
x
,
y
,
rtol
=
1e-3
)
compare_pytorch_and_py
(
out_fg
,
[
test_value
],
assert_fn
=
assert_fn
)
compare_pytorch_and_py
(
[
x
],
outs
,
[
test_value
],
assert_fn
=
assert_fn
)
@pytest.mark.parametrize
(
@pytest.mark.parametrize
(
...
@@ -50,8 +46,8 @@ def test_lin_alg_no_params(func, matrix_test):
...
@@ -50,8 +46,8 @@ def test_lin_alg_no_params(func, matrix_test):
def
test_qr
(
mode
,
matrix_test
):
def
test_qr
(
mode
,
matrix_test
):
x
,
test_value
=
matrix_test
x
,
test_value
=
matrix_test
outs
=
pt_nla
.
qr
(
x
,
mode
=
mode
)
outs
=
pt_nla
.
qr
(
x
,
mode
=
mode
)
out_fg
=
FunctionGraph
([
x
],
outs
if
isinstance
(
outs
,
list
)
else
[
outs
])
compare_pytorch_and_py
(
out_fg
,
[
test_value
])
compare_pytorch_and_py
(
[
x
],
outs
,
[
test_value
])
@pytest.mark.parametrize
(
"compute_uv"
,
[
True
,
False
])
@pytest.mark.parametrize
(
"compute_uv"
,
[
True
,
False
])
...
@@ -60,18 +56,16 @@ def test_svd(compute_uv, full_matrices, matrix_test):
...
@@ -60,18 +56,16 @@ def test_svd(compute_uv, full_matrices, matrix_test):
x
,
test_value
=
matrix_test
x
,
test_value
=
matrix_test
out
=
pt_nla
.
svd
(
x
,
full_matrices
=
full_matrices
,
compute_uv
=
compute_uv
)
out
=
pt_nla
.
svd
(
x
,
full_matrices
=
full_matrices
,
compute_uv
=
compute_uv
)
out_fg
=
FunctionGraph
([
x
],
out
if
isinstance
(
out
,
list
)
else
[
out
])
compare_pytorch_and_py
(
out_fg
,
[
test_value
])
compare_pytorch_and_py
(
[
x
],
out
,
[
test_value
])
def
test_pinv
():
def
test_pinv
():
x
=
matrix
(
"x"
)
x
=
matrix
(
"x"
)
x_inv
=
pt_nla
.
pinv
(
x
)
x_inv
=
pt_nla
.
pinv
(
x
)
fgraph
=
FunctionGraph
([
x
],
[
x_inv
])
x_np
=
np
.
array
([[
1.0
,
2.0
],
[
3.0
,
4.0
]],
dtype
=
config
.
floatX
)
x_np
=
np
.
array
([[
1.0
,
2.0
],
[
3.0
,
4.0
]],
dtype
=
config
.
floatX
)
compare_pytorch_and_py
(
fgraph
,
[
x_np
])
compare_pytorch_and_py
(
[
x
],
[
x_inv
]
,
[
x_np
])
@pytest.mark.parametrize
(
"hermitian"
,
[
False
,
True
])
@pytest.mark.parametrize
(
"hermitian"
,
[
False
,
True
])
...
@@ -106,8 +100,7 @@ def test_kron():
...
@@ -106,8 +100,7 @@ def test_kron():
y
=
matrix
(
"y"
)
y
=
matrix
(
"y"
)
z
=
pt_nla
.
kron
(
x
,
y
)
z
=
pt_nla
.
kron
(
x
,
y
)
fgraph
=
FunctionGraph
([
x
,
y
],
[
z
])
x_np
=
np
.
array
([[
1.0
,
2.0
],
[
3.0
,
4.0
]],
dtype
=
config
.
floatX
)
x_np
=
np
.
array
([[
1.0
,
2.0
],
[
3.0
,
4.0
]],
dtype
=
config
.
floatX
)
y_np
=
np
.
array
([[
1.0
,
2.0
],
[
3.0
,
4.0
]],
dtype
=
config
.
floatX
)
y_np
=
np
.
array
([[
1.0
,
2.0
],
[
3.0
,
4.0
]],
dtype
=
config
.
floatX
)
compare_pytorch_and_py
(
fgraph
,
[
x_np
,
y_np
])
compare_pytorch_and_py
(
[
x
,
y
],
[
z
]
,
[
x_np
,
y_np
])
tests/link/pytorch/test_shape.py
浏览文件 @
cc8c4992
...
@@ -2,7 +2,6 @@ import numpy as np
...
@@ -2,7 +2,6 @@ import numpy as np
import
pytensor.tensor
as
pt
import
pytensor.tensor
as
pt
from
pytensor.configdefaults
import
config
from
pytensor.configdefaults
import
config
from
pytensor.graph.fg
import
FunctionGraph
from
pytensor.tensor.shape
import
Shape
,
Shape_i
,
Unbroadcast
,
reshape
from
pytensor.tensor.shape
import
Shape
,
Shape_i
,
Unbroadcast
,
reshape
from
pytensor.tensor.type
import
iscalar
,
vector
from
pytensor.tensor.type
import
iscalar
,
vector
from
tests.link.pytorch.test_basic
import
compare_pytorch_and_py
from
tests.link.pytorch.test_basic
import
compare_pytorch_and_py
...
@@ -11,29 +10,27 @@ from tests.link.pytorch.test_basic import compare_pytorch_and_py
...
@@ -11,29 +10,27 @@ from tests.link.pytorch.test_basic import compare_pytorch_and_py
def
test_pytorch_shape_ops
():
def
test_pytorch_shape_ops
():
x_np
=
np
.
zeros
((
20
,
3
))
x_np
=
np
.
zeros
((
20
,
3
))
x
=
Shape
()(
pt
.
as_tensor_variable
(
x_np
))
x
=
Shape
()(
pt
.
as_tensor_variable
(
x_np
))
x_fg
=
FunctionGraph
([],
[
x
])
compare_pytorch_and_py
(
x_fg
,
[],
must_be_device_array
=
False
)
compare_pytorch_and_py
(
[],
[
x
],
[]
)
x
=
Shape_i
(
1
)(
pt
.
as_tensor_variable
(
x_np
))
x
=
Shape_i
(
1
)(
pt
.
as_tensor_variable
(
x_np
))
x_fg
=
FunctionGraph
([],
[
x
])
compare_pytorch_and_py
(
x_fg
,
[],
must_be_device_array
=
False
)
compare_pytorch_and_py
(
[],
[
x
],
[]
)
def
test_pytorch_specify_shape
():
def
test_pytorch_specify_shape
():
in_pt
=
pt
.
matrix
(
"in"
)
in_pt
=
pt
.
matrix
(
"in"
)
x
=
pt
.
specify_shape
(
in_pt
,
(
4
,
None
))
x
=
pt
.
specify_shape
(
in_pt
,
(
4
,
None
))
x_fg
=
FunctionGraph
([
in_pt
],
[
x
])
compare_pytorch_and_py
([
in_pt
],
[
x
],
[
np
.
ones
((
4
,
5
))
.
astype
(
config
.
floatX
)])
compare_pytorch_and_py
(
x_fg
,
[
np
.
ones
((
4
,
5
))
.
astype
(
config
.
floatX
)])
# When used to assert two arrays have similar shapes
# When used to assert two arrays have similar shapes
in_pt
=
pt
.
matrix
(
"in"
)
in_pt
=
pt
.
matrix
(
"in"
)
shape_pt
=
pt
.
matrix
(
"shape"
)
shape_pt
=
pt
.
matrix
(
"shape"
)
x
=
pt
.
specify_shape
(
in_pt
,
shape_pt
.
shape
)
x
=
pt
.
specify_shape
(
in_pt
,
shape_pt
.
shape
)
x_fg
=
FunctionGraph
([
in_pt
,
shape_pt
],
[
x
])
compare_pytorch_and_py
(
compare_pytorch_and_py
(
x_fg
,
[
in_pt
,
shape_pt
],
[
x
],
[
np
.
ones
((
4
,
5
))
.
astype
(
config
.
floatX
),
np
.
ones
((
4
,
5
))
.
astype
(
config
.
floatX
)],
[
np
.
ones
((
4
,
5
))
.
astype
(
config
.
floatX
),
np
.
ones
((
4
,
5
))
.
astype
(
config
.
floatX
)],
)
)
...
@@ -41,21 +38,22 @@ def test_pytorch_specify_shape():
...
@@ -41,21 +38,22 @@ def test_pytorch_specify_shape():
def
test_pytorch_Reshape_constant
():
def
test_pytorch_Reshape_constant
():
a
=
vector
(
"a"
)
a
=
vector
(
"a"
)
x
=
reshape
(
a
,
(
2
,
2
))
x
=
reshape
(
a
,
(
2
,
2
))
x_fg
=
FunctionGraph
([
a
],
[
x
])
compare_pytorch_and_py
(
x_fg
,
[
np
.
r_
[
1.0
,
2.0
,
3.0
,
4.0
]
.
astype
(
config
.
floatX
)])
compare_pytorch_and_py
(
[
a
],
[
x
]
,
[
np
.
r_
[
1.0
,
2.0
,
3.0
,
4.0
]
.
astype
(
config
.
floatX
)])
def
test_pytorch_Reshape_dynamic
():
def
test_pytorch_Reshape_dynamic
():
a
=
vector
(
"a"
)
a
=
vector
(
"a"
)
shape_pt
=
iscalar
(
"b"
)
shape_pt
=
iscalar
(
"b"
)
x
=
reshape
(
a
,
(
shape_pt
,
shape_pt
))
x
=
reshape
(
a
,
(
shape_pt
,
shape_pt
))
x_fg
=
FunctionGraph
([
a
,
shape_pt
],
[
x
])
compare_pytorch_and_py
(
x_fg
,
[
np
.
r_
[
1.0
,
2.0
,
3.0
,
4.0
]
.
astype
(
config
.
floatX
),
2
])
compare_pytorch_and_py
(
[
a
,
shape_pt
],
[
x
],
[
np
.
r_
[
1.0
,
2.0
,
3.0
,
4.0
]
.
astype
(
config
.
floatX
),
2
]
)
def
test_pytorch_unbroadcast
():
def
test_pytorch_unbroadcast
():
x_np
=
np
.
zeros
((
20
,
1
,
1
))
x_np
=
np
.
zeros
((
20
,
1
,
1
))
x
=
Unbroadcast
(
0
,
2
)(
pt
.
as_tensor_variable
(
x_np
))
x
=
Unbroadcast
(
0
,
2
)(
pt
.
as_tensor_variable
(
x_np
))
x_fg
=
FunctionGraph
([],
[
x
])
compare_pytorch_and_py
(
x_fg
,
[])
compare_pytorch_and_py
(
[],
[
x
]
,
[])
tests/link/pytorch/test_sort.py
浏览文件 @
cc8c4992
import
numpy
as
np
import
numpy
as
np
import
pytest
import
pytest
from
pytensor.graph
import
FunctionGraph
from
pytensor.tensor
import
matrix
from
pytensor.tensor
import
matrix
from
pytensor.tensor.sort
import
argsort
,
sort
from
pytensor.tensor.sort
import
argsort
,
sort
from
tests.link.pytorch.test_basic
import
compare_pytorch_and_py
from
tests.link.pytorch.test_basic
import
compare_pytorch_and_py
...
@@ -12,6 +11,5 @@ from tests.link.pytorch.test_basic import compare_pytorch_and_py
...
@@ -12,6 +11,5 @@ from tests.link.pytorch.test_basic import compare_pytorch_and_py
def
test_sort
(
func
,
axis
):
def
test_sort
(
func
,
axis
):
x
=
matrix
(
"x"
,
shape
=
(
2
,
2
),
dtype
=
"float64"
)
x
=
matrix
(
"x"
,
shape
=
(
2
,
2
),
dtype
=
"float64"
)
out
=
func
(
x
,
axis
=
axis
)
out
=
func
(
x
,
axis
=
axis
)
fgraph
=
FunctionGraph
([
x
],
[
out
])
arr
=
np
.
array
([[
1.0
,
4.0
],
[
5.0
,
2.0
]])
arr
=
np
.
array
([[
1.0
,
4.0
],
[
5.0
,
2.0
]])
compare_pytorch_and_py
(
fgraph
,
[
arr
])
compare_pytorch_and_py
(
[
x
],
[
out
]
,
[
arr
])
tests/link/pytorch/test_subtensor.py
浏览文件 @
cc8c4992
...
@@ -6,7 +6,6 @@ import pytest
...
@@ -6,7 +6,6 @@ import pytest
import
pytensor.scalar
as
ps
import
pytensor.scalar
as
ps
import
pytensor.tensor
as
pt
import
pytensor.tensor
as
pt
from
pytensor.configdefaults
import
config
from
pytensor.configdefaults
import
config
from
pytensor.graph.fg
import
FunctionGraph
from
pytensor.tensor
import
inc_subtensor
,
set_subtensor
from
pytensor.tensor
import
inc_subtensor
,
set_subtensor
from
pytensor.tensor
import
subtensor
as
pt_subtensor
from
pytensor.tensor
import
subtensor
as
pt_subtensor
from
tests.link.pytorch.test_basic
import
compare_pytorch_and_py
from
tests.link.pytorch.test_basic
import
compare_pytorch_and_py
...
@@ -19,38 +18,33 @@ def test_pytorch_Subtensor():
...
@@ -19,38 +18,33 @@ def test_pytorch_Subtensor():
out_pt
=
x_pt
[
1
,
2
,
0
]
out_pt
=
x_pt
[
1
,
2
,
0
]
assert
isinstance
(
out_pt
.
owner
.
op
,
pt_subtensor
.
Subtensor
)
assert
isinstance
(
out_pt
.
owner
.
op
,
pt_subtensor
.
Subtensor
)
out_fg
=
FunctionGraph
([
x_pt
],
[
out_pt
])
compare_pytorch_and_py
(
out_fg
,
[
x_np
])
compare_pytorch_and_py
(
[
x_pt
],
[
out_pt
]
,
[
x_np
])
out_pt
=
x_pt
[
1
:,
1
,
:]
out_pt
=
x_pt
[
1
:,
1
,
:]
assert
isinstance
(
out_pt
.
owner
.
op
,
pt_subtensor
.
Subtensor
)
assert
isinstance
(
out_pt
.
owner
.
op
,
pt_subtensor
.
Subtensor
)
out_fg
=
FunctionGraph
([
x_pt
],
[
out_pt
])
compare_pytorch_and_py
([
x_pt
],
[
out_pt
],
[
x_np
])
compare_pytorch_and_py
(
out_fg
,
[
x_np
])
out_pt
=
x_pt
[:
2
,
1
,
:]
out_pt
=
x_pt
[:
2
,
1
,
:]
assert
isinstance
(
out_pt
.
owner
.
op
,
pt_subtensor
.
Subtensor
)
assert
isinstance
(
out_pt
.
owner
.
op
,
pt_subtensor
.
Subtensor
)
out_fg
=
FunctionGraph
([
x_pt
],
[
out_pt
])
compare_pytorch_and_py
([
x_pt
],
[
out_pt
],
[
x_np
])
compare_pytorch_and_py
(
out_fg
,
[
x_np
])
out_pt
=
x_pt
[
1
:
2
,
1
,
:]
out_pt
=
x_pt
[
1
:
2
,
1
,
:]
assert
isinstance
(
out_pt
.
owner
.
op
,
pt_subtensor
.
Subtensor
)
assert
isinstance
(
out_pt
.
owner
.
op
,
pt_subtensor
.
Subtensor
)
out_fg
=
FunctionGraph
([
x_pt
],
[
out_pt
])
compare_pytorch_and_py
([
x_pt
],
[
out_pt
],
[
x_np
])
compare_pytorch_and_py
(
out_fg
,
[
x_np
])
# symbolic index
# symbolic index
a_pt
=
ps
.
int64
(
"a"
)
a_pt
=
ps
.
int64
(
"a"
)
a_np
=
1
a_np
=
1
out_pt
=
x_pt
[
a_pt
,
2
,
a_pt
:
2
]
out_pt
=
x_pt
[
a_pt
,
2
,
a_pt
:
2
]
assert
isinstance
(
out_pt
.
owner
.
op
,
pt_subtensor
.
Subtensor
)
assert
isinstance
(
out_pt
.
owner
.
op
,
pt_subtensor
.
Subtensor
)
out_fg
=
FunctionGraph
([
x_pt
,
a_pt
],
[
out_pt
])
compare_pytorch_and_py
([
x_pt
,
a_pt
],
[
out_pt
],
[
x_np
,
a_np
])
compare_pytorch_and_py
(
out_fg
,
[
x_np
,
a_np
])
with
pytest
.
raises
(
with
pytest
.
raises
(
NotImplementedError
,
match
=
"Negative step sizes are not supported in Pytorch"
NotImplementedError
,
match
=
"Negative step sizes are not supported in Pytorch"
):
):
out_pt
=
x_pt
[::
-
1
]
out_pt
=
x_pt
[::
-
1
]
out_fg
=
FunctionGraph
([
x_pt
],
[
out_pt
])
compare_pytorch_and_py
([
x_pt
],
[
out_pt
],
[
x_np
])
compare_pytorch_and_py
(
out_fg
,
[
x_np
])
def
test_pytorch_AdvSubtensor
():
def
test_pytorch_AdvSubtensor
():
...
@@ -60,52 +54,43 @@ def test_pytorch_AdvSubtensor():
...
@@ -60,52 +54,43 @@ def test_pytorch_AdvSubtensor():
out_pt
=
pt_subtensor
.
advanced_subtensor1
(
x_pt
,
[
1
,
2
])
out_pt
=
pt_subtensor
.
advanced_subtensor1
(
x_pt
,
[
1
,
2
])
assert
isinstance
(
out_pt
.
owner
.
op
,
pt_subtensor
.
AdvancedSubtensor1
)
assert
isinstance
(
out_pt
.
owner
.
op
,
pt_subtensor
.
AdvancedSubtensor1
)
out_fg
=
FunctionGraph
([
x_pt
],
[
out_pt
])
compare_pytorch_and_py
([
x_pt
],
[
out_pt
],
[
x_np
])
compare_pytorch_and_py
(
out_fg
,
[
x_np
])
out_pt
=
x_pt
[[
1
,
2
],
[
2
,
3
]]
out_pt
=
x_pt
[[
1
,
2
],
[
2
,
3
]]
assert
isinstance
(
out_pt
.
owner
.
op
,
pt_subtensor
.
AdvancedSubtensor
)
assert
isinstance
(
out_pt
.
owner
.
op
,
pt_subtensor
.
AdvancedSubtensor
)
out_fg
=
FunctionGraph
([
x_pt
],
[
out_pt
])
compare_pytorch_and_py
([
x_pt
],
[
out_pt
],
[
x_np
])
compare_pytorch_and_py
(
out_fg
,
[
x_np
])
out_pt
=
x_pt
[[
1
,
2
],
1
:]
out_pt
=
x_pt
[[
1
,
2
],
1
:]
assert
isinstance
(
out_pt
.
owner
.
op
,
pt_subtensor
.
AdvancedSubtensor
)
assert
isinstance
(
out_pt
.
owner
.
op
,
pt_subtensor
.
AdvancedSubtensor
)
out_fg
=
FunctionGraph
([
x_pt
],
[
out_pt
])
compare_pytorch_and_py
([
x_pt
],
[
out_pt
],
[
x_np
])
compare_pytorch_and_py
(
out_fg
,
[
x_np
])
out_pt
=
x_pt
[[
1
,
2
],
:,
[
3
,
4
]]
out_pt
=
x_pt
[[
1
,
2
],
:,
[
3
,
4
]]
assert
isinstance
(
out_pt
.
owner
.
op
,
pt_subtensor
.
AdvancedSubtensor
)
assert
isinstance
(
out_pt
.
owner
.
op
,
pt_subtensor
.
AdvancedSubtensor
)
out_fg
=
FunctionGraph
([
x_pt
],
[
out_pt
])
compare_pytorch_and_py
([
x_pt
],
[
out_pt
],
[
x_np
])
compare_pytorch_and_py
(
out_fg
,
[
x_np
])
out_pt
=
x_pt
[[
1
,
2
],
None
]
out_pt
=
x_pt
[[
1
,
2
],
None
]
out_fg
=
FunctionGraph
([
x_pt
],
[
out_pt
])
compare_pytorch_and_py
([
x_pt
],
[
out_pt
],
[
x_np
])
compare_pytorch_and_py
(
out_fg
,
[
x_np
])
a_pt
=
ps
.
int64
(
"a"
)
a_pt
=
ps
.
int64
(
"a"
)
a_np
=
2
a_np
=
2
out_pt
=
x_pt
[[
1
,
a_pt
],
a_pt
]
out_pt
=
x_pt
[[
1
,
a_pt
],
a_pt
]
out_fg
=
FunctionGraph
([
x_pt
,
a_pt
],
[
out_pt
])
compare_pytorch_and_py
([
x_pt
,
a_pt
],
[
out_pt
],
[
x_np
,
a_np
])
compare_pytorch_and_py
(
out_fg
,
[
x_np
,
a_np
])
# boolean indices
# boolean indices
out_pt
=
x_pt
[
np
.
random
.
binomial
(
1
,
0.5
,
size
=
(
3
,
4
,
5
))
.
astype
(
bool
)]
out_pt
=
x_pt
[
np
.
random
.
binomial
(
1
,
0.5
,
size
=
(
3
,
4
,
5
))
.
astype
(
bool
)]
out_fg
=
FunctionGraph
([
x_pt
],
[
out_pt
])
compare_pytorch_and_py
([
x_pt
],
[
out_pt
],
[
x_np
])
compare_pytorch_and_py
(
out_fg
,
[
x_np
])
a_pt
=
pt
.
tensor3
(
"a"
,
dtype
=
"bool"
)
a_pt
=
pt
.
tensor3
(
"a"
,
dtype
=
"bool"
)
a_np
=
np
.
random
.
binomial
(
1
,
0.5
,
size
=
(
3
,
4
,
5
))
.
astype
(
bool
)
a_np
=
np
.
random
.
binomial
(
1
,
0.5
,
size
=
(
3
,
4
,
5
))
.
astype
(
bool
)
out_pt
=
x_pt
[
a_pt
]
out_pt
=
x_pt
[
a_pt
]
out_fg
=
FunctionGraph
([
x_pt
,
a_pt
],
[
out_pt
])
compare_pytorch_and_py
([
x_pt
,
a_pt
],
[
out_pt
],
[
x_np
,
a_np
])
compare_pytorch_and_py
(
out_fg
,
[
x_np
,
a_np
])
with
pytest
.
raises
(
with
pytest
.
raises
(
NotImplementedError
,
match
=
"Negative step sizes are not supported in Pytorch"
NotImplementedError
,
match
=
"Negative step sizes are not supported in Pytorch"
):
):
out_pt
=
x_pt
[[
1
,
2
],
::
-
1
]
out_pt
=
x_pt
[[
1
,
2
],
::
-
1
]
out_fg
=
FunctionGraph
([
x_pt
],
[
out_pt
])
assert
isinstance
(
out_pt
.
owner
.
op
,
pt_subtensor
.
AdvancedSubtensor
)
assert
isinstance
(
out_pt
.
owner
.
op
,
pt_subtensor
.
AdvancedSubtensor
)
compare_pytorch_and_py
(
out_fg
,
[
x_np
])
compare_pytorch_and_py
(
[
x_pt
],
[
out_pt
]
,
[
x_np
])
@pytest.mark.parametrize
(
"subtensor_op"
,
[
set_subtensor
,
inc_subtensor
])
@pytest.mark.parametrize
(
"subtensor_op"
,
[
set_subtensor
,
inc_subtensor
])
...
@@ -116,20 +101,17 @@ def test_pytorch_IncSubtensor(subtensor_op):
...
@@ -116,20 +101,17 @@ def test_pytorch_IncSubtensor(subtensor_op):
st_pt
=
pt
.
as_tensor_variable
(
np
.
array
(
-
10.0
,
dtype
=
config
.
floatX
))
st_pt
=
pt
.
as_tensor_variable
(
np
.
array
(
-
10.0
,
dtype
=
config
.
floatX
))
out_pt
=
subtensor_op
(
x_pt
[
1
,
2
,
3
],
st_pt
)
out_pt
=
subtensor_op
(
x_pt
[
1
,
2
,
3
],
st_pt
)
assert
isinstance
(
out_pt
.
owner
.
op
,
pt_subtensor
.
IncSubtensor
)
assert
isinstance
(
out_pt
.
owner
.
op
,
pt_subtensor
.
IncSubtensor
)
out_fg
=
FunctionGraph
([
x_pt
],
[
out_pt
])
compare_pytorch_and_py
([
x_pt
],
[
out_pt
],
[
x_test
])
compare_pytorch_and_py
(
out_fg
,
[
x_test
])
# Test different type update
# Test different type update
st_pt
=
pt
.
as_tensor_variable
(
np
.
r_
[
-
1.0
,
0.0
]
.
astype
(
"float32"
))
st_pt
=
pt
.
as_tensor_variable
(
np
.
r_
[
-
1.0
,
0.0
]
.
astype
(
"float32"
))
out_pt
=
subtensor_op
(
x_pt
[:
2
,
0
,
0
],
st_pt
)
out_pt
=
subtensor_op
(
x_pt
[:
2
,
0
,
0
],
st_pt
)
assert
isinstance
(
out_pt
.
owner
.
op
,
pt_subtensor
.
IncSubtensor
)
assert
isinstance
(
out_pt
.
owner
.
op
,
pt_subtensor
.
IncSubtensor
)
out_fg
=
FunctionGraph
([
x_pt
],
[
out_pt
])
compare_pytorch_and_py
([
x_pt
],
[
out_pt
],
[
x_test
])
compare_pytorch_and_py
(
out_fg
,
[
x_test
])
out_pt
=
subtensor_op
(
x_pt
[
0
,
1
:
3
,
0
],
st_pt
)
out_pt
=
subtensor_op
(
x_pt
[
0
,
1
:
3
,
0
],
st_pt
)
assert
isinstance
(
out_pt
.
owner
.
op
,
pt_subtensor
.
IncSubtensor
)
assert
isinstance
(
out_pt
.
owner
.
op
,
pt_subtensor
.
IncSubtensor
)
out_fg
=
FunctionGraph
([
x_pt
],
[
out_pt
])
compare_pytorch_and_py
([
x_pt
],
[
out_pt
],
[
x_test
])
compare_pytorch_and_py
(
out_fg
,
[
x_test
])
def
inc_subtensor_ignore_duplicates
(
x
,
y
):
def
inc_subtensor_ignore_duplicates
(
x
,
y
):
...
@@ -150,14 +132,12 @@ def test_pytorch_AvdancedIncSubtensor(advsubtensor_op):
...
@@ -150,14 +132,12 @@ def test_pytorch_AvdancedIncSubtensor(advsubtensor_op):
)
)
out_pt
=
advsubtensor_op
(
x_pt
[
np
.
r_
[
0
,
2
]],
st_pt
)
out_pt
=
advsubtensor_op
(
x_pt
[
np
.
r_
[
0
,
2
]],
st_pt
)
assert
isinstance
(
out_pt
.
owner
.
op
,
pt_subtensor
.
AdvancedIncSubtensor
)
assert
isinstance
(
out_pt
.
owner
.
op
,
pt_subtensor
.
AdvancedIncSubtensor
)
out_fg
=
FunctionGraph
([
x_pt
],
[
out_pt
])
compare_pytorch_and_py
([
x_pt
],
[
out_pt
],
[
x_test
])
compare_pytorch_and_py
(
out_fg
,
[
x_test
])
# Repeated indices
# Repeated indices
out_pt
=
advsubtensor_op
(
x_pt
[
np
.
r_
[
0
,
0
]],
st_pt
)
out_pt
=
advsubtensor_op
(
x_pt
[
np
.
r_
[
0
,
0
]],
st_pt
)
assert
isinstance
(
out_pt
.
owner
.
op
,
pt_subtensor
.
AdvancedIncSubtensor
)
assert
isinstance
(
out_pt
.
owner
.
op
,
pt_subtensor
.
AdvancedIncSubtensor
)
out_fg
=
FunctionGraph
([
x_pt
],
[
out_pt
])
compare_pytorch_and_py
([
x_pt
],
[
out_pt
],
[
x_test
])
compare_pytorch_and_py
(
out_fg
,
[
x_test
])
# Mixing advanced and basic indexing
# Mixing advanced and basic indexing
if
advsubtensor_op
is
inc_subtensor
:
if
advsubtensor_op
is
inc_subtensor
:
...
@@ -168,19 +148,16 @@ def test_pytorch_AvdancedIncSubtensor(advsubtensor_op):
...
@@ -168,19 +148,16 @@ def test_pytorch_AvdancedIncSubtensor(advsubtensor_op):
st_pt
=
pt
.
as_tensor_variable
(
x_test
[[
0
,
2
],
0
,
:
3
])
st_pt
=
pt
.
as_tensor_variable
(
x_test
[[
0
,
2
],
0
,
:
3
])
out_pt
=
advsubtensor_op
(
x_pt
[[
0
,
0
],
0
,
:
3
],
st_pt
)
out_pt
=
advsubtensor_op
(
x_pt
[[
0
,
0
],
0
,
:
3
],
st_pt
)
assert
isinstance
(
out_pt
.
owner
.
op
,
pt_subtensor
.
AdvancedIncSubtensor
)
assert
isinstance
(
out_pt
.
owner
.
op
,
pt_subtensor
.
AdvancedIncSubtensor
)
out_fg
=
FunctionGraph
([
x_pt
],
[
out_pt
])
with
expectation
:
with
expectation
:
compare_pytorch_and_py
(
out_fg
,
[
x_test
])
compare_pytorch_and_py
(
[
x_pt
],
[
out_pt
]
,
[
x_test
])
# Test different dtype update
# Test different dtype update
st_pt
=
pt
.
as_tensor_variable
(
np
.
r_
[
-
1.0
,
0.0
]
.
astype
(
"float32"
))
st_pt
=
pt
.
as_tensor_variable
(
np
.
r_
[
-
1.0
,
0.0
]
.
astype
(
"float32"
))
out_pt
=
advsubtensor_op
(
x_pt
[[
0
,
2
],
0
,
0
],
st_pt
)
out_pt
=
advsubtensor_op
(
x_pt
[[
0
,
2
],
0
,
0
],
st_pt
)
assert
isinstance
(
out_pt
.
owner
.
op
,
pt_subtensor
.
AdvancedIncSubtensor
)
assert
isinstance
(
out_pt
.
owner
.
op
,
pt_subtensor
.
AdvancedIncSubtensor
)
out_fg
=
FunctionGraph
([
x_pt
],
[
out_pt
])
compare_pytorch_and_py
([
x_pt
],
[
out_pt
],
[
x_test
])
compare_pytorch_and_py
(
out_fg
,
[
x_test
])
# Boolean indices
# Boolean indices
out_pt
=
advsubtensor_op
(
x_pt
[
x_pt
>
5
],
1.0
)
out_pt
=
advsubtensor_op
(
x_pt
[
x_pt
>
5
],
1.0
)
assert
isinstance
(
out_pt
.
owner
.
op
,
pt_subtensor
.
AdvancedIncSubtensor
)
assert
isinstance
(
out_pt
.
owner
.
op
,
pt_subtensor
.
AdvancedIncSubtensor
)
out_fg
=
FunctionGraph
([
x_pt
],
[
out_pt
])
compare_pytorch_and_py
([
x_pt
],
[
out_pt
],
[
x_test
])
compare_pytorch_and_py
(
out_fg
,
[
x_test
])
tests/tensor/test_extra_ops.py
浏览文件 @
cc8c4992
...
@@ -63,11 +63,6 @@ from pytensor.utils import LOCAL_BITWIDTH, PYTHON_INT_BITWIDTH
...
@@ -63,11 +63,6 @@ from pytensor.utils import LOCAL_BITWIDTH, PYTHON_INT_BITWIDTH
from
tests
import
unittest_tools
as
utt
from
tests
import
unittest_tools
as
utt
def
set_test_value
(
x
,
v
):
x
.
tag
.
test_value
=
v
return
x
def
test_cpu_contiguous
():
def
test_cpu_contiguous
():
a
=
fmatrix
(
"a"
)
a
=
fmatrix
(
"a"
)
i
=
iscalar
(
"i"
)
i
=
iscalar
(
"i"
)
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论