Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
a038c8ee
提交
a038c8ee
authored
3月 21, 2025
作者:
Ricardo Vieira
提交者:
Ricardo Vieira
3月 27, 2025
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Implement numba tridiagonal solve
上级
19023545
显示空白字符变更
内嵌
并排
正在显示
3 个修改的文件
包含
320 行增加
和
3 行删除
+320
-3
tridiagonal.py
pytensor/link/numba/dispatch/linalg/solve/tridiagonal.py
+299
-0
slinalg.py
pytensor/link/numba/dispatch/slinalg.py
+4
-1
test_slinalg.py
tests/link/numba/test_slinalg.py
+17
-2
没有找到文件。
pytensor/link/numba/dispatch/linalg/solve/tridiagonal.py
0 → 100644
浏览文件 @
a038c8ee
from
collections.abc
import
Callable
import
numpy
as
np
from
numba.core.extending
import
overload
from
numba.np.linalg
import
ensure_lapack
from
numpy
import
ndarray
from
scipy
import
linalg
from
pytensor.link.numba.dispatch.basic
import
numba_njit
from
pytensor.link.numba.dispatch.linalg._LAPACK
import
(
_LAPACK
,
_get_underlying_float
,
int_ptr_to_val
,
val_to_int_ptr
,
)
from
pytensor.link.numba.dispatch.linalg.solve.utils
import
_solve_check_input_shapes
from
pytensor.link.numba.dispatch.linalg.utils
import
(
_check_scipy_linalg_matrix
,
_copy_to_fortran_order_even_if_1d
,
_solve_check
,
_trans_char_to_int
,
)
@numba_njit
def tridiagonal_norm(du, d, dl):
    """Return the largest absolute column sum of a tridiagonal matrix.

    The matrix is described by its three diagonals: ``du`` is added to every
    entry but the first, ``dl`` to every entry but the last, and ``d`` to all
    of them (all in absolute value). The maximum of those sums is returned.
    None of the inputs is modified, because ``np.abs`` allocates a new array.
    """
    # Adapted from scipy _matrix_norm_tridiagonal:
    # https://github.com/scipy/scipy/blob/0f1fd4a7268b813fa2b844ca6038e4dfdf90084a/scipy/linalg/_basic.py#L356-L367
    column_sums = np.abs(d)
    # Keep this accumulation order: floating-point addition is not associative.
    column_sums[1:] += np.abs(du)
    column_sums[:-1] += np.abs(dl)
    return column_sums.max()
def _gttrf(
    dl: ndarray, d: ndarray, du: ndarray
) -> tuple[ndarray, ndarray, ndarray, ndarray, ndarray, int]:
    """Stub for the LU factorization of a tridiagonal matrix.

    This Python body does no work and returns ``None``. The function exists
    as a dispatch target: ``gttrf_impl`` is registered on it with
    ``@overload`` and provides the implementation used in compiled code.
    """
    return None  # type: ignore
@overload(_gttrf)
def gttrf_impl(
    dl: ndarray,
    d: ndarray,
    du: ndarray,
) -> Callable[
    [ndarray, ndarray, ndarray],
    tuple[ndarray, ndarray, ndarray, ndarray, ndarray, int],
]:
    """Numba overload of ``_gttrf``: LU-factorize a tridiagonal matrix.

    Validates the three diagonals, looks up the LAPACK ``xgttrf`` routine
    matching the dtype of ``d`` and returns the function Numba compiles.

    The compiled function returns ``(dl, d, du, du2, ipiv, info)``. ``dl``,
    ``d`` and ``du`` are the very arrays that were passed in: their buffers
    are handed to LAPACK without copying, so callers must not assume they
    still hold the original diagonals afterwards. ``du2`` and ``ipiv`` are
    freshly allocated outputs and ``info`` is the LAPACK status code.
    """
    ensure_lapack()
    _check_scipy_linalg_matrix(dl, "gttrf")
    _check_scipy_linalg_matrix(d, "gttrf")
    _check_scipy_linalg_matrix(du, "gttrf")
    dtype = d.dtype
    w_type = _get_underlying_float(dtype)
    numba_gttrf = _LAPACK().numba_xgttrf(dtype)

    def impl(
        dl: ndarray,
        d: ndarray,
        du: ndarray,
    ) -> tuple[ndarray, ndarray, ndarray, ndarray, ndarray, int]:
        n = np.int32(d.shape[-1])
        ipiv = np.empty(n, dtype=np.int32)
        # The second super-diagonal of U has n - 2 entries. Clamp at zero so
        # that 0x0 and 1x1 systems do not request a negative-length buffer.
        du2 = np.empty(max(n - 2, 0), dtype=dtype)
        info = val_to_int_ptr(0)

        numba_gttrf(
            val_to_int_ptr(n),
            dl.view(w_type).ctypes,
            d.view(w_type).ctypes,
            du.view(w_type).ctypes,
            du2.view(w_type).ctypes,
            ipiv.ctypes,
            info,
        )

        return dl, d, du, du2, ipiv, int_ptr_to_val(info)

    return impl
def _gttrs(
    dl: ndarray,
    d: ndarray,
    du: ndarray,
    du2: ndarray,
    ipiv: ndarray,
    b: ndarray,
    overwrite_b: bool,
    trans: bool,
) -> tuple[ndarray, int]:
    """Stub for solving an already LU-decomposed tridiagonal system.

    This Python body does no work and returns ``None``. The function exists
    as a dispatch target: ``gttrs_impl`` is registered on it with
    ``@overload`` and provides the implementation used in compiled code.
    """
    return None  # type: ignore
@overload(_gttrs)
def gttrs_impl(
    dl: ndarray,
    d: ndarray,
    du: ndarray,
    du2: ndarray,
    ipiv: ndarray,
    b: ndarray,
    overwrite_b: bool,
    trans: bool,
) -> Callable[
    [ndarray, ndarray, ndarray, ndarray, ndarray, ndarray, bool, bool],
    tuple[ndarray, int],
]:
    """Numba overload of ``_gttrs``: solve with a factorized tridiagonal matrix.

    Validates the array arguments, looks up the LAPACK ``xgttrs`` routine
    matching the dtype of ``d`` and returns the function Numba compiles.

    The compiled function returns ``(x, info)``. ``x`` is ``b`` itself when
    ``overwrite_b`` is set and ``b`` is already Fortran-contiguous; otherwise
    it is a Fortran-ordered copy of ``b``. ``info`` is the LAPACK status code.
    """
    ensure_lapack()
    # Same checks, in the same order, as validating each operand separately.
    for operand in (dl, d, du, du2, b):
        _check_scipy_linalg_matrix(operand, "gttrs")

    dtype = d.dtype
    w_type = _get_underlying_float(dtype)
    numba_gttrs = _LAPACK().numba_xgttrs(dtype)

    def impl(
        dl: ndarray,
        d: ndarray,
        du: ndarray,
        du2: ndarray,
        ipiv: ndarray,
        b: ndarray,
        overwrite_b: bool,
        trans: bool,
    ) -> tuple[ndarray, int]:
        order = np.int32(d.shape[-1])
        # A 1-D right-hand side counts as a single column.
        if b.ndim == 1:
            n_rhs = 1
        else:
            n_rhs = int(b.shape[-1])

        # Reuse the caller's buffer only when that is both permitted and
        # already laid out in Fortran order; otherwise work on a copy.
        if not (overwrite_b and b.flags.f_contiguous):
            rhs = _copy_to_fortran_order_even_if_1d(b)
        else:
            rhs = b

        status = val_to_int_ptr(0)
        trans_ptr = val_to_int_ptr(_trans_char_to_int(trans))
        order_ptr = val_to_int_ptr(order)
        n_rhs_ptr = val_to_int_ptr(n_rhs)
        # The leading dimension of the right-hand side equals the order.
        ldb_ptr = val_to_int_ptr(order)

        numba_gttrs(
            trans_ptr,
            order_ptr,
            n_rhs_ptr,
            dl.view(w_type).ctypes,
            d.view(w_type).ctypes,
            du.view(w_type).ctypes,
            du2.view(w_type).ctypes,
            ipiv.ctypes,
            rhs.view(w_type).ctypes,
            ldb_ptr,
            status,
        )

        return rhs, int_ptr_to_val(status)

    return impl
def _gtcon(
    dl: ndarray,
    d: ndarray,
    du: ndarray,
    du2: ndarray,
    ipiv: ndarray,
    anorm: float,
    norm: str,
) -> tuple[ndarray, int]:
    """Stub for estimating the conditioning of a tridiagonal system.

    This Python body does no work and returns ``None``. The function exists
    as a dispatch target: ``gtcon_impl`` is registered on it with
    ``@overload`` and provides the implementation used in compiled code.
    """
    return None  # type: ignore
@overload(_gtcon)
def gtcon_impl(
    dl: ndarray,
    d: ndarray,
    du: ndarray,
    du2: ndarray,
    ipiv: ndarray,
    anorm: float,
    norm: str,
) -> Callable[
    [ndarray, ndarray, ndarray, ndarray, ndarray, float, str], tuple[ndarray, int]
]:
    """Numba overload of ``_gtcon``: condition estimate for a factorized matrix.

    Validates the diagonal arrays, looks up the LAPACK ``xgtcon`` routine
    matching the dtype of ``d`` and returns the function Numba compiles.

    The compiled function returns ``(rcond, info)``, where ``rcond`` is a
    length-1 array written by LAPACK and ``info`` is the LAPACK status code.
    """
    ensure_lapack()
    # Same checks, in the same order, as validating each operand separately.
    for operand in (dl, d, du, du2):
        _check_scipy_linalg_matrix(operand, "gtcon")

    dtype = d.dtype
    w_type = _get_underlying_float(dtype)
    numba_gtcon = _LAPACK().numba_xgtcon(dtype)

    def impl(
        dl: ndarray,
        d: ndarray,
        du: ndarray,
        du2: ndarray,
        ipiv: ndarray,
        anorm: float,
        norm: str,
    ) -> tuple[ndarray, int]:
        order = np.int32(d.shape[-1])

        # Output slot and scratch buffers handed to LAPACK.
        rcond = np.empty(1, dtype=dtype)
        workspace = np.empty(2 * order, dtype=dtype)
        int_workspace = np.empty(order, dtype=np.int32)
        status = val_to_int_ptr(0)

        norm_ptr = val_to_int_ptr(ord(norm))
        order_ptr = val_to_int_ptr(order)
        # LAPACK takes the matrix norm by reference, so box it in an array.
        anorm_cell = np.array(anorm, dtype=dtype)

        numba_gtcon(
            norm_ptr,
            order_ptr,
            dl.view(w_type).ctypes,
            d.view(w_type).ctypes,
            du.view(w_type).ctypes,
            du2.view(w_type).ctypes,
            ipiv.ctypes,
            anorm_cell.view(w_type).ctypes,
            rcond.view(w_type).ctypes,
            workspace.view(w_type).ctypes,
            int_workspace.ctypes,
            status,
        )

        return rcond, int_ptr_to_val(status)

    return impl
def _solve_tridiagonal(
    a: ndarray,
    b: ndarray,
    lower: bool,
    overwrite_a: bool,
    overwrite_b: bool,
    check_finite: bool,
    transposed: bool,
):
    """
    Solve a linear system whose coefficient matrix is tridiagonal.

    Pure-Python reference implementation: every argument is forwarded
    unchanged to ``scipy.linalg.solve`` with ``assume_a="tridiagonal"``.
    ``_tridiagonal_solve_impl`` is registered on this function with
    ``@overload`` and is what runs in Numba-compiled code.
    """
    return linalg.solve(
        a=a,
        b=b,
        lower=lower,
        overwrite_a=overwrite_a,
        overwrite_b=overwrite_b,
        check_finite=check_finite,
        transposed=transposed,
        assume_a="tridiagonal",
    )
@overload(_solve_tridiagonal)
def _tridiagonal_solve_impl(
    A: ndarray,
    B: ndarray,
    lower: bool,
    overwrite_a: bool,
    overwrite_b: bool,
    check_finite: bool,
    transposed: bool,
) -> Callable[[ndarray, ndarray, bool, bool, bool, bool, bool], ndarray]:
    """Numba overload of ``_solve_tridiagonal``.

    Validates ``A`` and ``B`` and returns the function Numba compiles. That
    function extracts the three diagonals of ``A`` (of ``A.T`` when
    ``transposed`` is set), factorizes them with ``_gttrf``, solves with
    ``_gttrs`` and finally runs ``_gtcon`` so that ``_solve_check`` can
    inspect the conditioning estimate. The status code of every step is
    passed to ``_solve_check``.

    ``lower``, ``overwrite_a`` and ``check_finite`` are accepted for
    signature compatibility but are not used by this implementation.
    ``np.diag`` supplies the diagonals that the factorization works on.
    """
    ensure_lapack()
    _check_scipy_linalg_matrix(A, "solve")
    _check_scipy_linalg_matrix(B, "solve")

    def impl(
        A: ndarray,
        B: ndarray,
        lower: bool,
        overwrite_a: bool,
        overwrite_b: bool,
        check_finite: bool,
        transposed: bool,
    ) -> ndarray:
        n = np.int32(A.shape[-1])
        _solve_check_input_shapes(A, B)
        norm = "1"

        # The transpose is applied here, once, by swapping the roles of the
        # sub- and super-diagonal through A.T.
        if transposed:
            A = A.T
        dl, d, du = np.diag(A, -1), np.diag(A, 0), np.diag(A, 1)

        anorm = tridiagonal_norm(du, d, dl)

        dl, d, du, du2, IPIV, INFO = _gttrf(dl, d, du)
        _solve_check(n, INFO)

        # The factorization already describes the (possibly transposed)
        # matrix, so solve it as-is. Forwarding `transposed` as `trans` would
        # transpose a second time and solve A x = b instead of A^T x = b.
        X, INFO = _gttrs(
            dl, d, du, du2, IPIV, B, trans=False, overwrite_b=overwrite_b
        )
        _solve_check(n, INFO)

        RCOND, INFO = _gtcon(dl, d, du, du2, IPIV, anorm, norm)
        _solve_check(n, INFO, True, RCOND)

        return X

    return impl
pytensor/link/numba/dispatch/slinalg.py
浏览文件 @
a038c8ee
...
@@ -9,6 +9,7 @@ from pytensor.link.numba.dispatch.linalg.solve.general import _solve_gen
...
@@ -9,6 +9,7 @@ from pytensor.link.numba.dispatch.linalg.solve.general import _solve_gen
from
pytensor.link.numba.dispatch.linalg.solve.posdef
import
_solve_psd
from
pytensor.link.numba.dispatch.linalg.solve.posdef
import
_solve_psd
from
pytensor.link.numba.dispatch.linalg.solve.symmetric
import
_solve_symmetric
from
pytensor.link.numba.dispatch.linalg.solve.symmetric
import
_solve_symmetric
from
pytensor.link.numba.dispatch.linalg.solve.triangular
import
_solve_triangular
from
pytensor.link.numba.dispatch.linalg.solve.triangular
import
_solve_triangular
from
pytensor.link.numba.dispatch.linalg.solve.tridiagonal
import
_solve_tridiagonal
from
pytensor.tensor.slinalg
import
(
from
pytensor.tensor.slinalg
import
(
BlockDiagonal
,
BlockDiagonal
,
Cholesky
,
Cholesky
,
...
@@ -114,10 +115,12 @@ def numba_funcify_Solve(op, node, **kwargs):
...
@@ -114,10 +115,12 @@ def numba_funcify_Solve(op, node, **kwargs):
solve_fn
=
_solve_symmetric
solve_fn
=
_solve_symmetric
elif
assume_a
==
"pos"
:
elif
assume_a
==
"pos"
:
solve_fn
=
_solve_psd
solve_fn
=
_solve_psd
elif
assume_a
==
"tridiagonal"
:
solve_fn
=
_solve_tridiagonal
else
:
else
:
warnings
.
warn
(
warnings
.
warn
(
f
"Numba assume_a={assume_a} not implemented. Falling back to general solve.
\n
"
f
"Numba assume_a={assume_a} not implemented. Falling back to general solve.
\n
"
f
"If appropriate, you may want to set assume_a to one of 'sym', 'pos', 'her',
or 'triangular
' to improve performance."
,
f
"If appropriate, you may want to set assume_a to one of 'sym', 'pos', 'her',
'triangular' or 'tridiagonal
' to improve performance."
,
UserWarning
,
UserWarning
,
)
)
solve_fn
=
_solve_gen
solve_fn
=
_solve_gen
...
...
tests/link/numba/test_slinalg.py
浏览文件 @
a038c8ee
...
@@ -97,7 +97,7 @@ class TestSolves:
...
@@ -97,7 +97,7 @@ class TestSolves:
[(
5
,
1
),
(
5
,
5
),
(
5
,)],
[(
5
,
1
),
(
5
,
5
),
(
5
,)],
ids
=
[
"b_col_vec"
,
"b_matrix"
,
"b_vec"
],
ids
=
[
"b_col_vec"
,
"b_matrix"
,
"b_vec"
],
)
)
@pytest.mark.parametrize
(
"assume_a"
,
[
"gen"
,
"sym"
,
"pos"
],
ids
=
str
)
@pytest.mark.parametrize
(
"assume_a"
,
[
"gen"
,
"sym"
,
"pos"
,
"tridiagonal"
],
ids
=
str
)
def
test_solve
(
def
test_solve
(
self
,
self
,
b_shape
:
tuple
[
int
],
b_shape
:
tuple
[
int
],
...
@@ -106,7 +106,7 @@ class TestSolves:
...
@@ -106,7 +106,7 @@ class TestSolves:
overwrite_a
:
bool
,
overwrite_a
:
bool
,
overwrite_b
:
bool
,
overwrite_b
:
bool
,
):
):
if
assume_a
not
in
(
"sym"
,
"her"
,
"pos"
)
and
not
lower
:
if
assume_a
not
in
(
"sym"
,
"her"
,
"pos"
,
"tridiagonal"
)
and
not
lower
:
# Avoid redundant tests with lower=True and lower=False for non symmetric matrices
# Avoid redundant tests with lower=True and lower=False for non symmetric matrices
pytest
.
skip
(
"Skipping redundant test already covered by lower=True"
)
pytest
.
skip
(
"Skipping redundant test already covered by lower=True"
)
...
@@ -120,6 +120,14 @@ class TestSolves:
...
@@ -120,6 +120,14 @@ class TestSolves:
# We have to set the unused triangle to something other than zero
# We have to set the unused triangle to something other than zero
# to see lapack destroying it.
# to see lapack destroying it.
x
[
np
.
triu_indices
(
n
,
1
)
if
lower
else
np
.
tril_indices
(
n
,
1
)]
=
np
.
pi
x
[
np
.
triu_indices
(
n
,
1
)
if
lower
else
np
.
tril_indices
(
n
,
1
)]
=
np
.
pi
elif
assume_a
==
"tridiagonal"
:
_x
=
x
x
=
np
.
zeros_like
(
x
)
n
=
x
.
shape
[
-
1
]
arange_n
=
np
.
arange
(
n
)
x
[
arange_n
[
1
:],
arange_n
[:
-
1
]]
=
np
.
diag
(
_x
,
k
=-
1
)
x
[
arange_n
,
arange_n
]
=
np
.
diag
(
_x
,
k
=
0
)
x
[
arange_n
[:
-
1
],
arange_n
[
1
:]]
=
np
.
diag
(
_x
,
k
=
1
)
return
x
return
x
A
=
pt
.
matrix
(
"A"
,
dtype
=
floatX
)
A
=
pt
.
matrix
(
"A"
,
dtype
=
floatX
)
...
@@ -146,7 +154,14 @@ class TestSolves:
...
@@ -146,7 +154,14 @@ class TestSolves:
op
=
f
.
maker
.
fgraph
.
outputs
[
0
]
.
owner
.
op
op
=
f
.
maker
.
fgraph
.
outputs
[
0
]
.
owner
.
op
assert
isinstance
(
op
,
Solve
)
assert
isinstance
(
op
,
Solve
)
assert
op
.
assume_a
==
assume_a
destroy_map
=
op
.
destroy_map
destroy_map
=
op
.
destroy_map
if
overwrite_a
and
assume_a
==
"tridiagonal"
:
# Tridiagonal solve never destroys the A matrix
# Treat test from here as if overwrite_a is False
overwrite_a
=
False
if
overwrite_a
and
overwrite_b
:
if
overwrite_a
and
overwrite_b
:
raise
NotImplementedError
(
raise
NotImplementedError
(
"Test not implemented for simultaneous overwrite_a and overwrite_b, as that's not currently supported by PyTensor"
"Test not implemented for simultaneous overwrite_a and overwrite_b, as that's not currently supported by PyTensor"
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论