Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
bf628c97
Unverified
提交
bf628c97
authored
3月 04, 2025
作者:
Jesse Grabowski
提交者:
GitHub
3月 04, 2025
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Allow `transposed` argument in `linalg.solve` (#1231)
* Add transposed argument to `solve` and `solve_triangular` * Expand test coverage for `Solve` and `SolveTriangular`
上级
757a10cd
显示空白字符变更
内嵌
并排
正在显示
6 个修改的文件
包含
210 行增加
和
111 行删除
+210
-111
slinalg.py
pytensor/link/jax/dispatch/slinalg.py
+1
-2
slinalg.py
pytensor/link/numba/dispatch/slinalg.py
+1
-2
slinalg.py
pytensor/tensor/slinalg.py
+27
-15
test_slinalg.py
tests/link/jax/test_slinalg.py
+32
-18
test_slinalg.py
tests/link/numba/test_slinalg.py
+15
-17
test_slinalg.py
tests/tensor/test_slinalg.py
+134
-57
没有找到文件。
pytensor/link/jax/dispatch/slinalg.py
浏览文件 @
bf628c97
...
@@ -53,7 +53,6 @@ def jax_funcify_Solve(op, **kwargs):
...
@@ -53,7 +53,6 @@ def jax_funcify_Solve(op, **kwargs):
@jax_funcify.register
(
SolveTriangular
)
@jax_funcify.register
(
SolveTriangular
)
def
jax_funcify_SolveTriangular
(
op
,
**
kwargs
):
def
jax_funcify_SolveTriangular
(
op
,
**
kwargs
):
lower
=
op
.
lower
lower
=
op
.
lower
trans
=
op
.
trans
unit_diagonal
=
op
.
unit_diagonal
unit_diagonal
=
op
.
unit_diagonal
check_finite
=
op
.
check_finite
check_finite
=
op
.
check_finite
...
@@ -62,7 +61,7 @@ def jax_funcify_SolveTriangular(op, **kwargs):
...
@@ -62,7 +61,7 @@ def jax_funcify_SolveTriangular(op, **kwargs):
A
,
A
,
b
,
b
,
lower
=
lower
,
lower
=
lower
,
trans
=
trans
,
trans
=
0
,
# this is handled by explicitly transposing A, so it will always be 0 when we get to here.
unit_diagonal
=
unit_diagonal
,
unit_diagonal
=
unit_diagonal
,
check_finite
=
check_finite
,
check_finite
=
check_finite
,
)
)
...
...
pytensor/link/numba/dispatch/slinalg.py
浏览文件 @
bf628c97
...
@@ -180,7 +180,6 @@ def solve_triangular_impl(A, B, trans, lower, unit_diagonal, b_ndim, overwrite_b
...
@@ -180,7 +180,6 @@ def solve_triangular_impl(A, B, trans, lower, unit_diagonal, b_ndim, overwrite_b
@numba_funcify.register
(
SolveTriangular
)
@numba_funcify.register
(
SolveTriangular
)
def
numba_funcify_SolveTriangular
(
op
,
node
,
**
kwargs
):
def
numba_funcify_SolveTriangular
(
op
,
node
,
**
kwargs
):
trans
=
bool
(
op
.
trans
)
lower
=
op
.
lower
lower
=
op
.
lower
unit_diagonal
=
op
.
unit_diagonal
unit_diagonal
=
op
.
unit_diagonal
check_finite
=
op
.
check_finite
check_finite
=
op
.
check_finite
...
@@ -208,7 +207,7 @@ def numba_funcify_SolveTriangular(op, node, **kwargs):
...
@@ -208,7 +207,7 @@ def numba_funcify_SolveTriangular(op, node, **kwargs):
res
=
_solve_triangular
(
res
=
_solve_triangular
(
a
,
a
,
b
,
b
,
trans
=
trans
,
trans
=
0
,
# transposing is handled explicitly on the graph, so we never use this argument
lower
=
lower
,
lower
=
lower
,
unit_diagonal
=
unit_diagonal
,
unit_diagonal
=
unit_diagonal
,
overwrite_b
=
overwrite_b
,
overwrite_b
=
overwrite_b
,
...
...
pytensor/tensor/slinalg.py
浏览文件 @
bf628c97
...
@@ -296,13 +296,12 @@ class SolveBase(Op):
...
@@ -296,13 +296,12 @@ class SolveBase(Op):
# We need to return (dC/d[inv(A)], dC/db)
# We need to return (dC/d[inv(A)], dC/db)
c_bar
=
output_gradients
[
0
]
c_bar
=
output_gradients
[
0
]
trans_solve_op
=
type
(
self
)(
props_dict
=
self
.
_props_dict
()
**
{
props_dict
[
"lower"
]
=
not
self
.
lower
k
:
(
not
getattr
(
self
,
k
)
if
k
==
"lower"
else
getattr
(
self
,
k
))
for
k
in
self
.
__props__
solve_op
=
type
(
self
)(
**
props_dict
)
}
)
b_bar
=
solve_op
(
A
.
T
,
c_bar
)
b_bar
=
trans_solve_op
(
A
.
T
,
c_bar
)
# force outer product if vector second input
# force outer product if vector second input
A_bar
=
-
ptm
.
outer
(
b_bar
,
c
)
if
c
.
ndim
==
1
else
-
b_bar
.
dot
(
c
.
T
)
A_bar
=
-
ptm
.
outer
(
b_bar
,
c
)
if
c
.
ndim
==
1
else
-
b_bar
.
dot
(
c
.
T
)
...
@@ -385,7 +384,6 @@ class SolveTriangular(SolveBase):
...
@@ -385,7 +384,6 @@ class SolveTriangular(SolveBase):
"""Solve a system of linear equations."""
"""Solve a system of linear equations."""
__props__
=
(
__props__
=
(
"trans"
,
"unit_diagonal"
,
"unit_diagonal"
,
"lower"
,
"lower"
,
"check_finite"
,
"check_finite"
,
...
@@ -393,11 +391,10 @@ class SolveTriangular(SolveBase):
...
@@ -393,11 +391,10 @@ class SolveTriangular(SolveBase):
"overwrite_b"
,
"overwrite_b"
,
)
)
def
__init__
(
self
,
*
,
trans
=
0
,
unit_diagonal
=
False
,
**
kwargs
):
def
__init__
(
self
,
*
,
unit_diagonal
=
False
,
**
kwargs
):
if
kwargs
.
get
(
"overwrite_a"
,
False
):
if
kwargs
.
get
(
"overwrite_a"
,
False
):
raise
ValueError
(
"overwrite_a is not supported for SolverTriangulare"
)
raise
ValueError
(
"overwrite_a is not supported for SolverTriangulare"
)
super
()
.
__init__
(
**
kwargs
)
super
()
.
__init__
(
**
kwargs
)
self
.
trans
=
trans
self
.
unit_diagonal
=
unit_diagonal
self
.
unit_diagonal
=
unit_diagonal
def
perform
(
self
,
node
,
inputs
,
outputs
):
def
perform
(
self
,
node
,
inputs
,
outputs
):
...
@@ -406,7 +403,7 @@ class SolveTriangular(SolveBase):
...
@@ -406,7 +403,7 @@ class SolveTriangular(SolveBase):
A
,
A
,
b
,
b
,
lower
=
self
.
lower
,
lower
=
self
.
lower
,
trans
=
self
.
trans
,
trans
=
0
,
unit_diagonal
=
self
.
unit_diagonal
,
unit_diagonal
=
self
.
unit_diagonal
,
check_finite
=
self
.
check_finite
,
check_finite
=
self
.
check_finite
,
overwrite_b
=
self
.
overwrite_b
,
overwrite_b
=
self
.
overwrite_b
,
...
@@ -445,9 +442,9 @@ def solve_triangular(
...
@@ -445,9 +442,9 @@ def solve_triangular(
Parameters
Parameters
----------
----------
a
a
: TensorVariable
Square input data
Square input data
b
b
: TensorVariable
Input data for the right hand side.
Input data for the right hand side.
lower : bool, optional
lower : bool, optional
Use only data contained in the lower triangle of `a`. Default is to use upper triangle.
Use only data contained in the lower triangle of `a`. Default is to use upper triangle.
...
@@ -468,10 +465,17 @@ def solve_triangular(
...
@@ -468,10 +465,17 @@ def solve_triangular(
This will influence how batched dimensions are interpreted.
This will influence how batched dimensions are interpreted.
"""
"""
b_ndim
=
_default_b_ndim
(
b
,
b_ndim
)
b_ndim
=
_default_b_ndim
(
b
,
b_ndim
)
if
trans
in
[
1
,
"T"
,
True
]:
a
=
a
.
mT
lower
=
not
lower
if
trans
in
[
2
,
"C"
]:
a
=
a
.
conj
()
.
mT
lower
=
not
lower
ret
=
Blockwise
(
ret
=
Blockwise
(
SolveTriangular
(
SolveTriangular
(
lower
=
lower
,
lower
=
lower
,
trans
=
trans
,
unit_diagonal
=
unit_diagonal
,
unit_diagonal
=
unit_diagonal
,
check_finite
=
check_finite
,
check_finite
=
check_finite
,
b_ndim
=
b_ndim
,
b_ndim
=
b_ndim
,
...
@@ -534,6 +538,7 @@ def solve(
...
@@ -534,6 +538,7 @@ def solve(
*
,
*
,
assume_a
=
"gen"
,
assume_a
=
"gen"
,
lower
=
False
,
lower
=
False
,
transposed
=
False
,
check_finite
=
True
,
check_finite
=
True
,
b_ndim
:
int
|
None
=
None
,
b_ndim
:
int
|
None
=
None
,
):
):
...
@@ -564,8 +569,10 @@ def solve(
...
@@ -564,8 +569,10 @@ def solve(
b : (..., N, NRHS) array_like
b : (..., N, NRHS) array_like
Input data for the right hand side.
Input data for the right hand side.
lower : bool, optional
lower : bool, optional
If True, only the data contained in the lower triangle of `a`. Default
If True,
use
only the data contained in the lower triangle of `a`. Default
is to use upper triangle. (ignored for ``'gen'``)
is to use upper triangle. (ignored for ``'gen'``)
transposed: bool, optional
If True, solves the system A^T x = b. Default is False.
check_finite : bool, optional
check_finite : bool, optional
Whether to check that the input matrices contain only finite numbers.
Whether to check that the input matrices contain only finite numbers.
Disabling may give a performance gain, but may result in problems
Disabling may give a performance gain, but may result in problems
...
@@ -577,6 +584,11 @@ def solve(
...
@@ -577,6 +584,11 @@ def solve(
This will influence how batched dimensions are interpreted.
This will influence how batched dimensions are interpreted.
"""
"""
b_ndim
=
_default_b_ndim
(
b
,
b_ndim
)
b_ndim
=
_default_b_ndim
(
b
,
b_ndim
)
if
transposed
:
a
=
a
.
mT
lower
=
not
lower
return
Blockwise
(
return
Blockwise
(
Solve
(
Solve
(
lower
=
lower
,
lower
=
lower
,
...
...
tests/link/jax/test_slinalg.py
浏览文件 @
bf628c97
...
@@ -5,6 +5,7 @@ import numpy as np
...
@@ -5,6 +5,7 @@ import numpy as np
import
pytest
import
pytest
import
pytensor.tensor
as
pt
import
pytensor.tensor
as
pt
import
tests.unittest_tools
as
utt
from
pytensor.configdefaults
import
config
from
pytensor.configdefaults
import
config
from
pytensor.tensor
import
nlinalg
as
pt_nlinalg
from
pytensor.tensor
import
nlinalg
as
pt_nlinalg
from
pytensor.tensor
import
slinalg
as
pt_slinalg
from
pytensor.tensor
import
slinalg
as
pt_slinalg
...
@@ -103,28 +104,41 @@ def test_jax_basic():
...
@@ -103,28 +104,41 @@ def test_jax_basic():
)
)
@pytest.mark.parametrize
(
"check_finite"
,
[
False
,
True
])
def
test_jax_solve
():
@pytest.mark.parametrize
(
"lower"
,
[
False
,
True
])
rng
=
np
.
random
.
default_rng
(
utt
.
fetch_seed
())
@pytest.mark.parametrize
(
"trans"
,
[
0
,
1
,
2
])
def
test_jax_SolveTriangular
(
trans
,
lower
,
check_finite
):
A
=
pt
.
tensor
(
"A"
,
shape
=
(
5
,
5
))
x
=
matrix
(
"x"
)
b
=
pt
.
tensor
(
"B"
,
shape
=
(
5
,
5
))
b
=
vector
(
"b"
)
out
=
pt_slinalg
.
solve
(
A
,
b
,
lower
=
False
,
transposed
=
False
)
A_val
=
rng
.
normal
(
size
=
(
5
,
5
))
.
astype
(
config
.
floatX
)
b_val
=
rng
.
normal
(
size
=
(
5
,
5
))
.
astype
(
config
.
floatX
)
out
=
pt_slinalg
.
solve_triangular
(
x
,
b
,
trans
=
trans
,
lower
=
lower
,
check_finite
=
check_finite
,
)
compare_jax_and_py
(
compare_jax_and_py
(
[
x
,
b
],
[
A
,
b
],
[
out
],
[
out
],
[
[
A_val
,
b_val
],
np
.
eye
(
10
)
.
astype
(
config
.
floatX
),
)
np
.
arange
(
10
)
.
astype
(
config
.
floatX
),
],
def
test_jax_SolveTriangular
():
rng
=
np
.
random
.
default_rng
(
utt
.
fetch_seed
())
A
=
pt
.
tensor
(
"A"
,
shape
=
(
5
,
5
))
b
=
pt
.
tensor
(
"B"
,
shape
=
(
5
,
5
))
A_val
=
rng
.
normal
(
size
=
(
5
,
5
))
.
astype
(
config
.
floatX
)
b_val
=
rng
.
normal
(
size
=
(
5
,
5
))
.
astype
(
config
.
floatX
)
out
=
pt_slinalg
.
solve_triangular
(
A
,
b
,
trans
=
0
,
lower
=
True
,
unit_diagonal
=
False
,
)
)
compare_jax_and_py
([
A
,
b
],
[
out
],
[
A_val
,
b_val
])
def
test_jax_block_diag
():
def
test_jax_block_diag
():
...
...
tests/link/numba/test_slinalg.py
浏览文件 @
bf628c97
...
@@ -5,7 +5,6 @@ from typing import Literal
...
@@ -5,7 +5,6 @@ from typing import Literal
import
numpy
as
np
import
numpy
as
np
import
pytest
import
pytest
from
numpy.testing
import
assert_allclose
from
numpy.testing
import
assert_allclose
from
scipy
import
linalg
as
scipy_linalg
import
pytensor
import
pytensor
import
pytensor.tensor
as
pt
import
pytensor.tensor
as
pt
...
@@ -26,9 +25,9 @@ def transpose_func(x, trans):
...
@@ -26,9 +25,9 @@ def transpose_func(x, trans):
if
trans
==
0
:
if
trans
==
0
:
return
x
return
x
if
trans
==
1
:
if
trans
==
1
:
return
x
.
conj
()
.
T
if
trans
==
2
:
return
x
.
T
return
x
.
T
if
trans
==
2
:
return
x
.
conj
()
.
T
@pytest.mark.parametrize
(
@pytest.mark.parametrize
(
...
@@ -59,18 +58,18 @@ def test_solve_triangular(b_shape: tuple[int], lower, trans, unit_diag, is_compl
...
@@ -59,18 +58,18 @@ def test_solve_triangular(b_shape: tuple[int], lower, trans, unit_diag, is_compl
def
A_func
(
x
):
def
A_func
(
x
):
x
=
x
@
x
.
conj
()
.
T
x
=
x
@
x
.
conj
()
.
T
x_tri
=
scipy_
linalg
.
cholesky
(
x
,
lower
=
lower
)
.
astype
(
dtype
)
x_tri
=
pt
.
linalg
.
cholesky
(
x
,
lower
=
lower
)
.
astype
(
dtype
)
if
unit_diag
:
if
unit_diag
:
x_tri
[
np
.
diag_indices_from
(
x_tri
)]
=
1.0
x_tri
=
pt
.
fill_diagonal
(
x_tri
,
1.0
)
return
x_tri
.
astype
(
dtype
)
return
x_tri
solve_op
=
partial
(
solve_op
=
partial
(
pt
.
linalg
.
solve_triangular
,
lower
=
lower
,
trans
=
trans
,
unit_diagonal
=
unit_diag
pt
.
linalg
.
solve_triangular
,
lower
=
lower
,
trans
=
trans
,
unit_diagonal
=
unit_diag
)
)
X
=
solve_op
(
A
,
b
)
X
=
solve_op
(
A
_func
(
A
)
,
b
)
f
=
pytensor
.
function
([
A
,
b
],
X
,
mode
=
"NUMBA"
)
f
=
pytensor
.
function
([
A
,
b
],
X
,
mode
=
"NUMBA"
)
A_val
=
np
.
random
.
normal
(
size
=
(
5
,
5
))
A_val
=
np
.
random
.
normal
(
size
=
(
5
,
5
))
...
@@ -80,20 +79,20 @@ def test_solve_triangular(b_shape: tuple[int], lower, trans, unit_diag, is_compl
...
@@ -80,20 +79,20 @@ def test_solve_triangular(b_shape: tuple[int], lower, trans, unit_diag, is_compl
A_val
=
A_val
+
np
.
random
.
normal
(
size
=
(
5
,
5
))
*
1
j
A_val
=
A_val
+
np
.
random
.
normal
(
size
=
(
5
,
5
))
*
1
j
b_val
=
b_val
+
np
.
random
.
normal
(
size
=
b_shape
)
*
1
j
b_val
=
b_val
+
np
.
random
.
normal
(
size
=
b_shape
)
*
1
j
X_np
=
f
(
A_
func
(
A_val
),
b_val
)
X_np
=
f
(
A_
val
.
copy
(),
b_val
.
copy
()
)
A_val_transformed
=
transpose_func
(
A_func
(
A_val
),
trans
)
.
eval
()
test_input
=
transpose_func
(
A_func
(
A_val
),
trans
)
np
.
testing
.
assert_allclose
(
A_val_transformed
@
X_np
,
ATOL
=
1e-8
if
floatX
.
endswith
(
"64"
)
else
1e-4
b_val
,
RTOL
=
1e-8
if
floatX
.
endswith
(
"64"
)
else
1e-4
atol
=
1e-8
if
floatX
.
endswith
(
"64"
)
else
1e-4
,
rtol
=
1e-8
if
floatX
.
endswith
(
"64"
)
else
1e-4
,
np
.
testing
.
assert_allclose
(
test_input
@
X_np
,
b_val
,
atol
=
ATOL
,
rtol
=
RTOL
)
)
compiled_fgraph
=
f
.
maker
.
fgraph
compiled_fgraph
=
f
.
maker
.
fgraph
compare_numba_and_py
(
compare_numba_and_py
(
compiled_fgraph
.
inputs
,
compiled_fgraph
.
inputs
,
compiled_fgraph
.
outputs
,
compiled_fgraph
.
outputs
,
[
A_
func
(
A_val
)
,
b_val
],
[
A_
val
,
b_val
],
)
)
...
@@ -145,7 +144,6 @@ def test_solve_triangular_overwrite_b_correct(overwrite_b):
...
@@ -145,7 +144,6 @@ def test_solve_triangular_overwrite_b_correct(overwrite_b):
b_test_nb
=
b_test_py
.
copy
(
order
=
"F"
)
b_test_nb
=
b_test_py
.
copy
(
order
=
"F"
)
op
=
SolveTriangular
(
op
=
SolveTriangular
(
trans
=
0
,
unit_diagonal
=
False
,
unit_diagonal
=
False
,
lower
=
False
,
lower
=
False
,
check_finite
=
True
,
check_finite
=
True
,
...
...
tests/tensor/test_slinalg.py
浏览文件 @
bf628c97
...
@@ -214,7 +214,38 @@ def test_solve_raises_on_invalid_A():
...
@@ -214,7 +214,38 @@ def test_solve_raises_on_invalid_A():
Solve
(
assume_a
=
"test"
,
b_ndim
=
2
)
Solve
(
assume_a
=
"test"
,
b_ndim
=
2
)
solve_test_cases
=
[
(
"gen"
,
False
,
False
),
(
"gen"
,
False
,
True
),
(
"sym"
,
False
,
False
),
(
"sym"
,
True
,
False
),
(
"sym"
,
True
,
True
),
(
"pos"
,
False
,
False
),
(
"pos"
,
True
,
False
),
(
"pos"
,
True
,
True
),
]
solve_test_ids
=
[
f
'{assume_a}_{"lower" if lower else "upper"}_{"A^T" if transposed else "A"}'
for
assume_a
,
lower
,
transposed
in
solve_test_cases
]
class
TestSolve
(
utt
.
InferShapeTester
):
class
TestSolve
(
utt
.
InferShapeTester
):
@staticmethod
def
A_func
(
x
,
assume_a
):
if
assume_a
==
"pos"
:
return
x
@
x
.
T
elif
assume_a
==
"sym"
:
return
(
x
+
x
.
T
)
/
2
else
:
return
x
@staticmethod
def
T
(
x
,
transposed
):
if
transposed
:
return
x
.
T
return
x
@pytest.mark.parametrize
(
"b_shape"
,
[(
5
,
1
),
(
5
,)])
@pytest.mark.parametrize
(
"b_shape"
,
[(
5
,
1
),
(
5
,)])
def
test_infer_shape
(
self
,
b_shape
):
def
test_infer_shape
(
self
,
b_shape
):
rng
=
np
.
random
.
default_rng
(
utt
.
fetch_seed
())
rng
=
np
.
random
.
default_rng
(
utt
.
fetch_seed
())
...
@@ -235,8 +266,12 @@ class TestSolve(utt.InferShapeTester):
...
@@ -235,8 +266,12 @@ class TestSolve(utt.InferShapeTester):
@pytest.mark.parametrize
(
@pytest.mark.parametrize
(
"b_size"
,
[(
5
,
1
),
(
5
,
5
),
(
5
,)],
ids
=
[
"b_col_vec"
,
"b_matrix"
,
"b_vec"
]
"b_size"
,
[(
5
,
1
),
(
5
,
5
),
(
5
,)],
ids
=
[
"b_col_vec"
,
"b_matrix"
,
"b_vec"
]
)
)
@pytest.mark.parametrize
(
"assume_a"
,
[
"gen"
,
"sym"
,
"pos"
],
ids
=
str
)
@pytest.mark.parametrize
(
def
test_solve_correctness
(
self
,
b_size
:
tuple
[
int
],
assume_a
:
str
):
"assume_a, lower, transposed"
,
solve_test_cases
,
ids
=
solve_test_ids
)
def
test_solve_correctness
(
self
,
b_size
:
tuple
[
int
],
assume_a
:
str
,
lower
:
bool
,
transposed
:
bool
):
rng
=
np
.
random
.
default_rng
(
utt
.
fetch_seed
())
rng
=
np
.
random
.
default_rng
(
utt
.
fetch_seed
())
A
=
pt
.
tensor
(
"A"
,
shape
=
(
5
,
5
))
A
=
pt
.
tensor
(
"A"
,
shape
=
(
5
,
5
))
b
=
pt
.
tensor
(
"b"
,
shape
=
b_size
)
b
=
pt
.
tensor
(
"b"
,
shape
=
b_size
)
...
@@ -244,19 +279,18 @@ class TestSolve(utt.InferShapeTester):
...
@@ -244,19 +279,18 @@ class TestSolve(utt.InferShapeTester):
A_val
=
rng
.
normal
(
size
=
(
5
,
5
))
.
astype
(
config
.
floatX
)
A_val
=
rng
.
normal
(
size
=
(
5
,
5
))
.
astype
(
config
.
floatX
)
b_val
=
rng
.
normal
(
size
=
b_size
)
.
astype
(
config
.
floatX
)
b_val
=
rng
.
normal
(
size
=
b_size
)
.
astype
(
config
.
floatX
)
solve_op
=
functools
.
partial
(
solve
,
assume_a
=
assume_a
,
b_ndim
=
len
(
b_size
))
A_func
=
functools
.
partial
(
self
.
A_func
,
assume_a
=
assume_a
)
T
=
functools
.
partial
(
self
.
T
,
transposed
=
transposed
)
def
A_func
(
x
):
y
=
solve
(
if
assume_a
==
"pos"
:
A_func
(
A
),
return
x
@
x
.
T
b
,
elif
assume_a
==
"sym"
:
assume_a
=
assume_a
,
return
(
x
+
x
.
T
)
/
2
lower
=
lower
,
else
:
transposed
=
transposed
,
return
x
b_ndim
=
len
(
b_size
),
)
solve_input_val
=
A_func
(
A_val
)
y
=
solve_op
(
A_func
(
A
),
b
)
solve_func
=
pytensor
.
function
([
A
,
b
],
y
)
solve_func
=
pytensor
.
function
([
A
,
b
],
y
)
X_np
=
solve_func
(
A_val
.
copy
(),
b_val
.
copy
())
X_np
=
solve_func
(
A_val
.
copy
(),
b_val
.
copy
())
...
@@ -264,22 +298,34 @@ class TestSolve(utt.InferShapeTester):
...
@@ -264,22 +298,34 @@ class TestSolve(utt.InferShapeTester):
RTOL
=
1e-8
if
config
.
floatX
.
endswith
(
"64"
)
else
1e-4
RTOL
=
1e-8
if
config
.
floatX
.
endswith
(
"64"
)
else
1e-4
np
.
testing
.
assert_allclose
(
np
.
testing
.
assert_allclose
(
scipy
.
linalg
.
solve
(
solve_input_val
,
b_val
,
assume_a
=
assume_a
),
scipy
.
linalg
.
solve
(
A_func
(
A_val
),
b_val
,
assume_a
=
assume_a
,
transposed
=
transposed
,
lower
=
lower
,
),
X_np
,
X_np
,
atol
=
ATOL
,
atol
=
ATOL
,
rtol
=
RTOL
,
rtol
=
RTOL
,
)
)
np
.
testing
.
assert_allclose
(
A_func
(
A_val
)
@
X_np
,
b_val
,
atol
=
ATOL
,
rtol
=
RTOL
)
np
.
testing
.
assert_allclose
(
T
(
A_func
(
A_val
)
)
@
X_np
,
b_val
,
atol
=
ATOL
,
rtol
=
RTOL
)
@pytest.mark.parametrize
(
@pytest.mark.parametrize
(
"b_size"
,
[(
5
,
1
),
(
5
,
5
),
(
5
,)],
ids
=
[
"b_col_vec"
,
"b_matrix"
,
"b_vec"
]
"b_size"
,
[(
5
,
1
),
(
5
,
5
),
(
5
,)],
ids
=
[
"b_col_vec"
,
"b_matrix"
,
"b_vec"
]
)
)
@pytest.mark.parametrize
(
"assume_a"
,
[
"gen"
,
"sym"
,
"pos"
],
ids
=
str
)
@pytest.mark.parametrize
(
"assume_a, lower, transposed"
,
solve_test_cases
,
ids
=
solve_test_ids
,
)
@pytest.mark.skipif
(
@pytest.mark.skipif
(
config
.
floatX
==
"float32"
,
reason
=
"Gradients not numerically stable in float32"
config
.
floatX
==
"float32"
,
reason
=
"Gradients not numerically stable in float32"
)
)
def
test_solve_gradient
(
self
,
b_size
:
tuple
[
int
],
assume_a
:
str
):
def
test_solve_gradient
(
self
,
b_size
:
tuple
[
int
],
assume_a
:
str
,
lower
:
bool
,
transposed
:
bool
):
rng
=
np
.
random
.
default_rng
(
utt
.
fetch_seed
())
rng
=
np
.
random
.
default_rng
(
utt
.
fetch_seed
())
eps
=
2e-8
if
config
.
floatX
==
"float64"
else
None
eps
=
2e-8
if
config
.
floatX
==
"float64"
else
None
...
@@ -287,15 +333,8 @@ class TestSolve(utt.InferShapeTester):
...
@@ -287,15 +333,8 @@ class TestSolve(utt.InferShapeTester):
A_val
=
rng
.
normal
(
size
=
(
5
,
5
))
.
astype
(
config
.
floatX
)
A_val
=
rng
.
normal
(
size
=
(
5
,
5
))
.
astype
(
config
.
floatX
)
b_val
=
rng
.
normal
(
size
=
b_size
)
.
astype
(
config
.
floatX
)
b_val
=
rng
.
normal
(
size
=
b_size
)
.
astype
(
config
.
floatX
)
def
A_func
(
x
):
if
assume_a
==
"pos"
:
return
x
@
x
.
T
elif
assume_a
==
"sym"
:
return
(
x
+
x
.
T
)
/
2
else
:
return
x
solve_op
=
functools
.
partial
(
solve
,
assume_a
=
assume_a
,
b_ndim
=
len
(
b_size
))
solve_op
=
functools
.
partial
(
solve
,
assume_a
=
assume_a
,
b_ndim
=
len
(
b_size
))
A_func
=
functools
.
partial
(
self
.
A_func
,
assume_a
=
assume_a
)
# To correctly check the gradients, we need to include a transformation from the space of unconstrained matrices
# To correctly check the gradients, we need to include a transformation from the space of unconstrained matrices
# (A) to a valid input matrix for the given solver. This is done by the A_func function. If this isn't included,
# (A) to a valid input matrix for the given solver. This is done by the A_func function. If this isn't included,
...
@@ -307,11 +346,27 @@ class TestSolve(utt.InferShapeTester):
...
@@ -307,11 +346,27 @@ class TestSolve(utt.InferShapeTester):
class
TestSolveTriangular
(
utt
.
InferShapeTester
):
class
TestSolveTriangular
(
utt
.
InferShapeTester
):
@staticmethod
def
A_func
(
x
,
lower
,
unit_diagonal
):
x
=
x
@
x
.
T
x
=
pt
.
linalg
.
cholesky
(
x
,
lower
=
lower
)
if
unit_diagonal
:
x
=
pt
.
fill_diagonal
(
x
,
1
)
return
x
@staticmethod
def
T
(
x
,
trans
):
if
trans
==
1
:
return
x
.
T
elif
trans
==
2
:
return
x
.
conj
()
.
T
return
x
@pytest.mark.parametrize
(
"b_shape"
,
[(
5
,
1
),
(
5
,)])
@pytest.mark.parametrize
(
"b_shape"
,
[(
5
,
1
),
(
5
,)])
def
test_infer_shape
(
self
,
b_shape
):
def
test_infer_shape
(
self
,
b_shape
):
rng
=
np
.
random
.
default_rng
(
utt
.
fetch_seed
())
rng
=
np
.
random
.
default_rng
(
utt
.
fetch_seed
())
A
=
matrix
()
A
=
matrix
()
b_val
=
np
.
asarray
(
rng
.
random
(
b_shape
),
dtype
=
config
.
floatX
)
b_val
=
rng
.
random
(
b_shape
)
.
astype
(
config
.
floatX
)
b
=
pt
.
as_tensor_variable
(
b_val
)
.
type
()
b
=
pt
.
as_tensor_variable
(
b_val
)
.
type
()
self
.
_compile_and_check
(
self
.
_compile_and_check
(
[
A
,
b
],
[
A
,
b
],
...
@@ -324,56 +379,78 @@ class TestSolveTriangular(utt.InferShapeTester):
...
@@ -324,56 +379,78 @@ class TestSolveTriangular(utt.InferShapeTester):
warn
=
False
,
warn
=
False
,
)
)
@pytest.mark.parametrize
(
"b_shape"
,
[(
5
,
1
),
(
5
,),
(
5
,
5
)],
ids
=
[
"b_col_vec"
,
"b_vec"
,
"b_matrix"
]
)
@pytest.mark.parametrize
(
"lower"
,
[
True
,
False
])
@pytest.mark.parametrize
(
"lower"
,
[
True
,
False
])
def
test_correctness
(
self
,
lower
):
@pytest.mark.parametrize
(
"trans"
,
[
0
,
1
,
2
])
@pytest.mark.parametrize
(
"unit_diagonal"
,
[
True
,
False
])
def
test_correctness
(
self
,
b_shape
:
tuple
[
int
],
lower
,
trans
,
unit_diagonal
):
rng
=
np
.
random
.
default_rng
(
utt
.
fetch_seed
())
rng
=
np
.
random
.
default_rng
(
utt
.
fetch_seed
())
A
=
pt
.
tensor
(
"A"
,
shape
=
(
5
,
5
))
b
=
pt
.
tensor
(
"b"
,
shape
=
b_shape
)
b_val
=
np
.
asarray
(
rng
.
random
((
5
,
1
)),
dtype
=
config
.
floatX
)
A_val
=
rng
.
random
((
5
,
5
))
.
astype
(
config
.
floatX
)
b_val
=
rng
.
random
(
b_shape
)
.
astype
(
config
.
floatX
)
A_val
=
np
.
asarray
(
rng
.
random
((
5
,
5
)),
dtype
=
config
.
floatX
)
A_func
=
functools
.
partial
(
A_val
=
np
.
dot
(
A_val
.
transpose
(),
A_val
)
self
.
A_func
,
lower
=
lower
,
unit_diagonal
=
unit_diagonal
)
C_val
=
scipy
.
linalg
.
cholesky
(
A_val
,
lower
=
lower
)
x
=
solve_triangular
(
A_func
(
A
),
b
,
lower
=
lower
,
trans
=
trans
,
unit_diagonal
=
unit_diagonal
,
b_ndim
=
len
(
b_shape
),
)
A
=
matrix
()
f
=
pytensor
.
function
([
A
,
b
],
x
)
b
=
matrix
()
cholesky
=
Cholesky
(
lower
=
lower
)
x_pt
=
f
(
A_val
,
b_val
)
C
=
cholesky
(
A
)
x_sp
=
scipy
.
linalg
.
solve_triangular
(
y_lower
=
solve_triangular
(
C
,
b
,
lower
=
lower
)
A_func
(
A_val
)
.
eval
(),
lower_solve_func
=
pytensor
.
function
([
C
,
b
],
y_lower
)
b_val
,
lower
=
lower
,
trans
=
trans
,
unit_diagonal
=
unit_diagonal
,
)
assert
np
.
allclose
(
np
.
testing
.
assert_allclose
(
scipy
.
linalg
.
solve_triangular
(
C_val
,
b_val
,
lower
=
lower
),
x_pt
,
lower_solve_func
(
C_val
,
b_val
),
x_sp
,
atol
=
1e-8
if
config
.
floatX
==
"float64"
else
1e-4
,
rtol
=
1e-8
if
config
.
floatX
==
"float64"
else
1e-4
,
)
)
@pytest.mark.parametrize
(
@pytest.mark.parametrize
(
"m, n, lower"
,
"b_shape"
,
[(
5
,
1
),
(
5
,),
(
5
,
5
)],
ids
=
[
"b_col_vec"
,
"b_vec"
,
"b_matrix"
]
[
(
5
,
None
,
False
),
(
5
,
None
,
True
),
(
4
,
2
,
False
),
(
4
,
2
,
True
),
],
)
)
def
test_solve_grad
(
self
,
m
,
n
,
lower
):
@pytest.mark.parametrize
(
"lower"
,
[
True
,
False
])
rng
=
np
.
random
.
default_rng
(
utt
.
fetch_seed
())
@pytest.mark.parametrize
(
"trans"
,
[
0
,
1
])
@pytest.mark.parametrize
(
"unit_diagonal"
,
[
True
,
False
])
def
test_solve_triangular_grad
(
self
,
b_shape
,
lower
,
trans
,
unit_diagonal
):
if
config
.
floatX
==
"float32"
:
pytest
.
skip
(
reason
=
"Not enough precision in float32 to get a good gradient"
)
# Ensure diagonal elements of `A` are relatively large to avoid
rng
=
np
.
random
.
default_rng
(
utt
.
fetch_seed
())
# numerical precision issues
A_val
=
rng
.
normal
(
size
=
(
5
,
5
))
.
astype
(
config
.
floatX
)
A_val
=
(
rng
.
normal
(
size
=
(
m
,
m
))
*
0.5
+
np
.
eye
(
m
)
)
.
astype
(
config
.
floatX
)
b_val
=
rng
.
normal
(
size
=
b_shape
)
.
astype
(
config
.
floatX
)
if
n
is
None
:
A_func
=
functools
.
partial
(
b_val
=
rng
.
normal
(
size
=
m
)
.
astype
(
config
.
floatX
)
self
.
A_func
,
lower
=
lower
,
unit_diagonal
=
unit_diagonal
else
:
)
b_val
=
rng
.
normal
(
size
=
(
m
,
n
))
.
astype
(
config
.
floatX
)
eps
=
None
eps
=
None
if
config
.
floatX
==
"float64"
:
if
config
.
floatX
==
"float64"
:
eps
=
2e-8
eps
=
2e-8
solve_op
=
SolveTriangular
(
lower
=
lower
,
b_ndim
=
1
if
n
is
None
else
2
)
def
solve_op
(
A
,
b
):
return
solve_triangular
(
A_func
(
A
),
b
,
lower
=
lower
,
trans
=
trans
,
unit_diagonal
=
unit_diagonal
)
utt
.
verify_grad
(
solve_op
,
[
A_val
,
b_val
],
3
,
rng
,
eps
=
eps
)
utt
.
verify_grad
(
solve_op
,
[
A_val
,
b_val
],
3
,
rng
,
eps
=
eps
)
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论