Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
e98cbbcf
提交
e98cbbcf
authored
3月 30, 2025
作者:
Jesse Grabowski
提交者:
Jesse Grabowski
4月 19, 2025
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Numba dispatch for LU ops
上级
679b2f71
隐藏空白字符变更
内嵌
并排
正在显示
7 个修改的文件
包含
754 行增加
和
114 行删除
+754
-114
basic.py
pytensor/link/numba/dispatch/basic.py
+1
-1
lu.py
pytensor/link/numba/dispatch/linalg/decomposition/lu.py
+206
-0
lu_factor.py
...sor/link/numba/dispatch/linalg/decomposition/lu_factor.py
+86
-0
general.py
pytensor/link/numba/dispatch/linalg/solve/general.py
+2
-112
lu_solve.py
pytensor/link/numba/dispatch/linalg/solve/lu_solve.py
+132
-0
slinalg.py
pytensor/link/numba/dispatch/slinalg.py
+100
-0
test_slinalg.py
tests/link/numba/test_slinalg.py
+227
-1
没有找到文件。
pytensor/link/numba/dispatch/basic.py
浏览文件 @
e98cbbcf
...
...
@@ -76,7 +76,7 @@ def numba_njit(*args, fastmath=None, **kwargs):
message
=
(
"(
\x1b\\
[1m)*"
# ansi escape code for bold text
"Cannot cache compiled function "
'"(numba_funcified_fgraph|store_core_outputs|cholesky|solve|solve_triangular|cho_solve)" '
'"(numba_funcified_fgraph|store_core_outputs|cholesky|solve|solve_triangular|cho_solve
|lu_factor
)" '
"as it uses dynamic globals"
),
category
=
NumbaWarning
,
...
...
pytensor/link/numba/dispatch/linalg/decomposition/lu.py
0 → 100644
浏览文件 @
e98cbbcf
from
collections.abc
import
Callable
from
typing
import
cast
as
typing_cast
import
numpy
as
np
from
numba
import
njit
as
numba_njit
from
numba.core.extending
import
overload
from
numba.np.linalg
import
ensure_lapack
from
scipy
import
linalg
from
pytensor.link.numba.dispatch.linalg.decomposition.lu_factor
import
_getrf
from
pytensor.link.numba.dispatch.linalg.utils
import
_check_scipy_linalg_matrix
@numba_njit
def _pivot_to_permutation(p, dtype):
    """Convert a LAPACK-style pivot vector into an inverse permutation vector.

    Each entry ``p[i]`` names the row that was swapped with row ``i`` during
    factorization; replaying those swaps against an identity permutation
    produces the inverse permutation.
    """
    n = len(p)
    p_inv = np.arange(n).astype(dtype)
    for idx in range(n):
        swap_with = p[idx]
        tmp = p_inv[idx]
        p_inv[idx] = p_inv[swap_with]
        p_inv[swap_with] = tmp
    return p_inv
@numba_njit
def _lu_factor_to_lu(a, dtype, overwrite_a):
    """Expand a packed getrf factorization of ``a`` into ``(perm, L, U)``.

    ``L`` carries a unit diagonal; ``perm`` is the row permutation such that
    ``L[perm] @ U`` reconstructs ``a``.
    """
    A_copy, IPIV, INFO = _getrf(a, overwrite_a=overwrite_a)

    # Unit lower-triangular factor: identity plus the strictly-lower part
    # of the packed LU matrix.
    L = np.eye(A_copy.shape[-1], dtype=dtype) + np.tril(A_copy, k=-1)
    U = np.triu(A_copy)

    # Fortran is 1-indexed, so shift the pivot vector to 0-based before
    # converting it into a permutation.
    p_inv = _pivot_to_permutation(IPIV - 1, dtype=dtype)
    perm = np.argsort(p_inv)

    return perm, L, U
def
_lu_1
(
a
:
np
.
ndarray
,
permute_l
:
bool
,
check_finite
:
bool
,
p_indices
:
bool
,
overwrite_a
:
bool
,
)
->
tuple
[
np
.
ndarray
,
np
.
ndarray
,
np
.
ndarray
]:
"""
Thin wrapper around scipy.linalg.lu. Used as an overload target to avoid side-effects on users to import Pytensor.
Called when permute_l is True and p_indices is False, and returns a tuple of (perm, L, U), where perm an integer
array of row swaps, such that L[perm] @ U = A.
"""
return
typing_cast
(
tuple
[
np
.
ndarray
,
np
.
ndarray
,
np
.
ndarray
],
linalg
.
lu
(
a
,
permute_l
=
permute_l
,
check_finite
=
check_finite
,
p_indices
=
p_indices
,
overwrite_a
=
overwrite_a
,
),
)
def
_lu_2
(
a
:
np
.
ndarray
,
permute_l
:
bool
,
check_finite
:
bool
,
p_indices
:
bool
,
overwrite_a
:
bool
,
)
->
tuple
[
np
.
ndarray
,
np
.
ndarray
]:
"""
Thin wrapper around scipy.linalg.lu. Used as an overload target to avoid side-effects on users to import Pytensor.
Called when permute_l is False and p_indices is True, and returns a tuple of (PL, U), where PL is the
permuted L matrix, PL = P @ L.
"""
return
typing_cast
(
tuple
[
np
.
ndarray
,
np
.
ndarray
],
linalg
.
lu
(
a
,
permute_l
=
permute_l
,
check_finite
=
check_finite
,
p_indices
=
p_indices
,
overwrite_a
=
overwrite_a
,
),
)
def
_lu_3
(
a
:
np
.
ndarray
,
permute_l
:
bool
,
check_finite
:
bool
,
p_indices
:
bool
,
overwrite_a
:
bool
,
)
->
tuple
[
np
.
ndarray
,
np
.
ndarray
,
np
.
ndarray
]:
"""
Thin wrapper around scipy.linalg.lu. Used as an overload target to avoid side-effects on users to import Pytensor.
Called when permute_l is False and p_indices is False, and returns a tuple of (P, L, U), where P is the permutation
matrix, P @ L @ U = A.
"""
return
typing_cast
(
tuple
[
np
.
ndarray
,
np
.
ndarray
,
np
.
ndarray
],
linalg
.
lu
(
a
,
permute_l
=
permute_l
,
check_finite
=
check_finite
,
p_indices
=
p_indices
,
overwrite_a
=
overwrite_a
,
),
)
@overload(_lu_1)
def lu_impl_1(
    a: np.ndarray,
    permute_l: bool,
    check_finite: bool,
    p_indices: bool,
    overwrite_a: bool,
) -> Callable[
    [np.ndarray, bool, bool, bool, bool],
    tuple[np.ndarray, np.ndarray, np.ndarray],
]:
    """
    Overload scipy.linalg.lu with a numba function. This function is called when p_indices is True and permute_l is
    False. Returns a tuple of (perm, L, U), where perm is an integer array of row swaps, such that L[perm] @ U = A.
    """
    # NOTE(review): docstring previously said "permute_l is True and p_indices
    # is False", contradicting the dispatcher in slinalg.py which selects _lu_1
    # when p_indices is True; corrected here.
    ensure_lapack()
    _check_scipy_linalg_matrix(a, "lu")
    # Capture the dtype at overload time so the njit closure is specialized.
    dtype = a.dtype

    def impl(
        a: np.ndarray,
        permute_l: bool,
        check_finite: bool,
        p_indices: bool,
        overwrite_a: bool,
    ) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
        # check_finite/permute_l/p_indices are ignored here: the caller has
        # already routed to the correct variant and performed finiteness checks.
        perm, L, U = _lu_factor_to_lu(a, dtype, overwrite_a)
        return perm, L, U

    return impl
@overload(_lu_2)
def lu_impl_2(
    a: np.ndarray,
    permute_l: bool,
    check_finite: bool,
    p_indices: bool,
    overwrite_a: bool,
) -> Callable[
    [np.ndarray, bool, bool, bool, bool], tuple[np.ndarray, np.ndarray]
]:
    """
    Overload scipy.linalg.lu with a numba function. This function is called when permute_l is True and p_indices is
    False. Returns a tuple of (PL, U), where PL is the permuted L matrix, PL = P @ L.
    """
    # NOTE(review): docstring previously had the flags swapped ("permute_l is
    # False and p_indices is True"); corrected to match the dispatcher.
    ensure_lapack()
    _check_scipy_linalg_matrix(a, "lu")
    dtype = a.dtype

    def impl(
        a: np.ndarray,
        permute_l: bool,
        check_finite: bool,
        p_indices: bool,
        overwrite_a: bool,
    ) -> tuple[np.ndarray, np.ndarray]:
        perm, L, U = _lu_factor_to_lu(a, dtype, overwrite_a)
        # Fold the row permutation into L so that PL @ U reconstructs a.
        PL = L[perm]

        return PL, U

    return impl
@overload(_lu_3)
def lu_impl_3(
    a: np.ndarray,
    permute_l: bool,
    check_finite: bool,
    p_indices: bool,
    overwrite_a: bool,
) -> Callable[
    [np.ndarray, bool, bool, bool, bool],
    tuple[np.ndarray, np.ndarray, np.ndarray],
]:
    """
    Overload scipy.linalg.lu with a numba function. This function is called when permute_l is False and p_indices is
    False. Returns a tuple of (P, L, U), such that P @ L @ U = A.
    """
    # NOTE(review): docstring previously said "permute_l is True"; the
    # dispatcher only reaches _lu_3 when both flags are False.
    ensure_lapack()
    _check_scipy_linalg_matrix(a, "lu")
    dtype = a.dtype

    def impl(
        a: np.ndarray,
        permute_l: bool,
        check_finite: bool,
        p_indices: bool,
        overwrite_a: bool,
    ) -> tuple[np.ndarray, np.ndarray, np.ndarray]:
        perm, L, U = _lu_factor_to_lu(a, dtype, overwrite_a)
        # Materialize the permutation indices as an explicit permutation matrix.
        P = np.eye(a.shape[-1], dtype=dtype)[perm]

        return P, L, U

    return impl
pytensor/link/numba/dispatch/linalg/decomposition/lu_factor.py
0 → 100644
浏览文件 @
e98cbbcf
from
collections.abc
import
Callable
import
numpy
as
np
from
numba.core.extending
import
overload
from
numba.np.linalg
import
_copy_to_fortran_order
,
ensure_lapack
from
scipy
import
linalg
from
pytensor.link.numba.dispatch.linalg._LAPACK
import
(
_LAPACK
,
_get_underlying_float
,
int_ptr_to_val
,
val_to_int_ptr
,
)
from
pytensor.link.numba.dispatch.linalg.utils
import
(
_check_scipy_linalg_matrix
,
)
def
_getrf
(
A
,
overwrite_a
=
False
)
->
tuple
[
np
.
ndarray
,
np
.
ndarray
,
int
]:
"""
Underlying LAPACK function used for LU factorization. Compared to scipy.linalg.lu_factorize, this function also
returns an info code with diagnostic information.
"""
(
getrf
,)
=
linalg
.
get_lapack_funcs
(
"getrf"
,
(
A
,))
A_copy
,
ipiv
,
info
=
getrf
(
A
,
overwrite_a
=
overwrite_a
)
return
A_copy
,
ipiv
,
info
@overload(_getrf)
def getrf_impl(
    A: np.ndarray, overwrite_a: bool = False
) -> Callable[[np.ndarray, bool], tuple[np.ndarray, np.ndarray, int]]:
    """Numba overload of _getrf: call xGETRF directly through the LAPACK cython bindings."""
    ensure_lapack()
    _check_scipy_linalg_matrix(A, "getrf")
    dtype = A.dtype
    # w_type is the scalar type the buffer is viewed as when handing the
    # pointer to LAPACK (handles complex dtypes).
    w_type = _get_underlying_float(dtype)
    numba_getrf = _LAPACK().numba_xgetrf(dtype)

    def impl(
        A: np.ndarray, overwrite_a: bool = False
    ) -> tuple[np.ndarray, np.ndarray, int]:
        _M, _N = np.int32(A.shape[-2:])  # type: ignore

        # LAPACK requires Fortran-ordered input; only reuse A's buffer when
        # the caller allows overwriting and the layout is already correct.
        if overwrite_a and A.flags.f_contiguous:
            A_copy = A
        else:
            A_copy = _copy_to_fortran_order(A)

        M = val_to_int_ptr(_M)  # type: ignore
        N = val_to_int_ptr(_N)  # type: ignore
        LDA = val_to_int_ptr(_M)  # type: ignore
        # Pivot indices returned by LAPACK are 1-based.
        IPIV = np.empty(_N, dtype=np.int32)  # type: ignore
        INFO = val_to_int_ptr(0)

        numba_getrf(M, N, A_copy.view(w_type).ctypes, LDA, IPIV.ctypes, INFO)

        return A_copy, IPIV, int_ptr_to_val(INFO)

    return impl
def
_lu_factor
(
A
:
np
.
ndarray
,
overwrite_a
:
bool
=
False
):
"""
Thin wrapper around scipy.linalg.lu_factor. Used as an overload target to avoid side-effects on users who import
Pytensor.
"""
return
linalg
.
lu_factor
(
A
,
overwrite_a
=
overwrite_a
)
@overload(_lu_factor)
def lu_factor_impl(
    A: np.ndarray, overwrite_a: bool = False
) -> Callable[[np.ndarray, bool], tuple[np.ndarray, np.ndarray]]:
    """Numba overload of _lu_factor, built on the _getrf overload.

    Mirrors scipy.linalg.lu_factor's (LU, piv) return, with 0-based pivots.
    """
    ensure_lapack()
    _check_scipy_linalg_matrix(A, "lu_factor")

    def impl(
        A: np.ndarray, overwrite_a: bool = False
    ) -> tuple[np.ndarray, np.ndarray]:
        A_copy, IPIV, INFO = _getrf(A, overwrite_a=overwrite_a)
        IPIV -= 1  # LAPACK uses 1-based indexing, convert to 0-based

        # Nonzero INFO signals a failed/singular factorization.
        if INFO != 0:
            raise np.linalg.LinAlgError("LU decomposition failed")
        return A_copy, IPIV

    return impl
pytensor/link/numba/dispatch/linalg/solve/general.py
浏览文件 @
e98cbbcf
...
...
@@ -11,13 +11,13 @@ from pytensor.link.numba.dispatch.linalg._LAPACK import (
int_ptr_to_val
,
val_to_int_ptr
,
)
from
pytensor.link.numba.dispatch.linalg.decomposition.lu_factor
import
_getrf
from
pytensor.link.numba.dispatch.linalg.solve.lu_solve
import
_getrs
from
pytensor.link.numba.dispatch.linalg.solve.norm
import
_xlange
from
pytensor.link.numba.dispatch.linalg.solve.utils
import
_solve_check_input_shapes
from
pytensor.link.numba.dispatch.linalg.utils
import
(
_check_scipy_linalg_matrix
,
_copy_to_fortran_order_even_if_1d
,
_solve_check
,
_trans_char_to_int
,
)
...
...
@@ -72,116 +72,6 @@ def xgecon_impl(
return
impl
def
_getrf
(
A
,
overwrite_a
=
False
)
->
tuple
[
np
.
ndarray
,
np
.
ndarray
,
int
]:
"""
Placeholder for LU factorization; used by linalg.solve.
# TODO: Implement an LU_factor Op, then dispatch to this function in numba mode.
"""
return
# type: ignore
@overload
(
_getrf
)
def
getrf_impl
(
A
:
np
.
ndarray
,
overwrite_a
:
bool
=
False
)
->
Callable
[[
np
.
ndarray
,
bool
],
tuple
[
np
.
ndarray
,
np
.
ndarray
,
int
]]:
ensure_lapack
()
_check_scipy_linalg_matrix
(
A
,
"getrf"
)
dtype
=
A
.
dtype
w_type
=
_get_underlying_float
(
dtype
)
numba_getrf
=
_LAPACK
()
.
numba_xgetrf
(
dtype
)
def
impl
(
A
:
np
.
ndarray
,
overwrite_a
:
bool
=
False
)
->
tuple
[
np
.
ndarray
,
np
.
ndarray
,
int
]:
_M
,
_N
=
np
.
int32
(
A
.
shape
[
-
2
:])
# type: ignore
if
overwrite_a
and
A
.
flags
.
f_contiguous
:
A_copy
=
A
else
:
A_copy
=
_copy_to_fortran_order
(
A
)
M
=
val_to_int_ptr
(
_M
)
# type: ignore
N
=
val_to_int_ptr
(
_N
)
# type: ignore
LDA
=
val_to_int_ptr
(
_M
)
# type: ignore
IPIV
=
np
.
empty
(
_N
,
dtype
=
np
.
int32
)
# type: ignore
INFO
=
val_to_int_ptr
(
0
)
numba_getrf
(
M
,
N
,
A_copy
.
view
(
w_type
)
.
ctypes
,
LDA
,
IPIV
.
ctypes
,
INFO
)
return
A_copy
,
IPIV
,
int_ptr_to_val
(
INFO
)
return
impl
def
_getrs
(
LU
:
np
.
ndarray
,
B
:
np
.
ndarray
,
IPIV
:
np
.
ndarray
,
trans
:
int
,
overwrite_b
:
bool
)
->
tuple
[
np
.
ndarray
,
int
]:
"""
Placeholder for solving a linear system with a matrix that has been LU-factored; used by linalg.solve.
# TODO: Implement an LU_solve Op, then dispatch to this function in numba mode.
"""
return
# type: ignore
@overload
(
_getrs
)
def
getrs_impl
(
LU
:
np
.
ndarray
,
B
:
np
.
ndarray
,
IPIV
:
np
.
ndarray
,
trans
:
int
,
overwrite_b
:
bool
)
->
Callable
[[
np
.
ndarray
,
np
.
ndarray
,
np
.
ndarray
,
int
,
bool
],
tuple
[
np
.
ndarray
,
int
]]:
ensure_lapack
()
_check_scipy_linalg_matrix
(
LU
,
"getrs"
)
_check_scipy_linalg_matrix
(
B
,
"getrs"
)
dtype
=
LU
.
dtype
w_type
=
_get_underlying_float
(
dtype
)
numba_getrs
=
_LAPACK
()
.
numba_xgetrs
(
dtype
)
def
impl
(
LU
:
np
.
ndarray
,
B
:
np
.
ndarray
,
IPIV
:
np
.
ndarray
,
trans
:
int
,
overwrite_b
:
bool
)
->
tuple
[
np
.
ndarray
,
int
]:
_N
=
np
.
int32
(
LU
.
shape
[
-
1
])
_solve_check_input_shapes
(
LU
,
B
)
B_is_1d
=
B
.
ndim
==
1
if
overwrite_b
and
B
.
flags
.
f_contiguous
:
B_copy
=
B
else
:
B_copy
=
_copy_to_fortran_order_even_if_1d
(
B
)
if
B_is_1d
:
B_copy
=
np
.
expand_dims
(
B_copy
,
-
1
)
NRHS
=
1
if
B_is_1d
else
int
(
B_copy
.
shape
[
-
1
])
TRANS
=
val_to_int_ptr
(
_trans_char_to_int
(
trans
))
N
=
val_to_int_ptr
(
_N
)
NRHS
=
val_to_int_ptr
(
NRHS
)
LDA
=
val_to_int_ptr
(
_N
)
LDB
=
val_to_int_ptr
(
_N
)
IPIV
=
_copy_to_fortran_order
(
IPIV
)
INFO
=
val_to_int_ptr
(
0
)
numba_getrs
(
TRANS
,
N
,
NRHS
,
LU
.
view
(
w_type
)
.
ctypes
,
LDA
,
IPIV
.
ctypes
,
B_copy
.
view
(
w_type
)
.
ctypes
,
LDB
,
INFO
,
)
if
B_is_1d
:
B_copy
=
B_copy
[
...
,
0
]
return
B_copy
,
int_ptr_to_val
(
INFO
)
return
impl
def
_solve_gen
(
A
:
np
.
ndarray
,
B
:
np
.
ndarray
,
...
...
pytensor/link/numba/dispatch/linalg/solve/lu_solve.py
0 → 100644
浏览文件 @
e98cbbcf
from
collections.abc
import
Callable
import
numpy
as
np
from
numba.core.extending
import
overload
from
numba.np.linalg
import
_copy_to_fortran_order
,
ensure_lapack
from
scipy
import
linalg
from
pytensor.link.numba.dispatch.linalg._LAPACK
import
(
_LAPACK
,
_get_underlying_float
,
int_ptr_to_val
,
val_to_int_ptr
,
)
from
pytensor.link.numba.dispatch.linalg.solve.utils
import
_solve_check_input_shapes
from
pytensor.link.numba.dispatch.linalg.utils
import
(
_check_scipy_linalg_matrix
,
_copy_to_fortran_order_even_if_1d
,
_solve_check
,
_trans_char_to_int
,
)
def
_getrs
(
LU
:
np
.
ndarray
,
B
:
np
.
ndarray
,
IPIV
:
np
.
ndarray
,
trans
:
int
,
overwrite_b
:
bool
)
->
tuple
[
np
.
ndarray
,
int
]:
"""
Placeholder for solving a linear system with a matrix that has been LU-factored. Used by linalg.lu_solve.
"""
return
# type: ignore
@overload(_getrs)
def getrs_impl(
    LU: np.ndarray, B: np.ndarray, IPIV: np.ndarray, trans: int, overwrite_b: bool
) -> Callable[
    [np.ndarray, np.ndarray, np.ndarray, int, bool], tuple[np.ndarray, int]
]:
    """Numba overload of _getrs: solve op(A) @ X = B from a getrf factorization via xGETRS."""
    ensure_lapack()
    _check_scipy_linalg_matrix(LU, "getrs")
    _check_scipy_linalg_matrix(B, "getrs")
    dtype = LU.dtype
    w_type = _get_underlying_float(dtype)
    numba_getrs = _LAPACK().numba_xgetrs(dtype)

    def impl(
        LU: np.ndarray, B: np.ndarray, IPIV: np.ndarray, trans: int, overwrite_b: bool
    ) -> tuple[np.ndarray, int]:
        _N = np.int32(LU.shape[-1])
        _solve_check_input_shapes(LU, B)

        B_is_1d = B.ndim == 1

        # Reuse B's buffer only when overwriting is allowed and the layout is
        # already Fortran-ordered; otherwise copy.
        if overwrite_b and B.flags.f_contiguous:
            B_copy = B
        else:
            B_copy = _copy_to_fortran_order_even_if_1d(B)

        # LAPACK expects a 2D right-hand side; promote vectors to one column.
        if B_is_1d:
            B_copy = np.expand_dims(B_copy, -1)

        NRHS = 1 if B_is_1d else int(B_copy.shape[-1])

        TRANS = val_to_int_ptr(_trans_char_to_int(trans))
        N = val_to_int_ptr(_N)
        NRHS = val_to_int_ptr(NRHS)
        LDA = val_to_int_ptr(_N)
        LDB = val_to_int_ptr(_N)
        # NOTE(review): LU itself is not copied to Fortran order here —
        # presumably it always arrives F-contiguous from _getrf; confirm
        # against callers.
        IPIV = _copy_to_fortran_order(IPIV)
        INFO = val_to_int_ptr(0)

        numba_getrs(
            TRANS,
            N,
            NRHS,
            LU.view(w_type).ctypes,
            LDA,
            IPIV.ctypes,
            B_copy.view(w_type).ctypes,
            LDB,
            INFO,
        )

        # Undo the vector -> column-matrix promotion for 1D inputs.
        if B_is_1d:
            B_copy = B_copy[..., 0]

        return B_copy, int_ptr_to_val(INFO)

    return impl
def
_lu_solve
(
lu_and_piv
:
tuple
[
np
.
ndarray
,
np
.
ndarray
],
b
:
np
.
ndarray
,
trans
:
int
,
overwrite_b
:
bool
,
check_finite
:
bool
,
):
"""
Thin wrapper around scipy.lu_solve, used to avoid side effects from numba overloads on users who import Pytensor.
"""
return
linalg
.
lu_solve
(
lu_and_piv
,
b
,
trans
=
trans
,
overwrite_b
=
overwrite_b
,
check_finite
=
check_finite
)
@overload(_lu_solve)
def lu_solve_impl(
    lu_and_piv: tuple[np.ndarray, np.ndarray],
    b: np.ndarray,
    trans: int,
    overwrite_b: bool,
    check_finite: bool,
) -> Callable[[np.ndarray, np.ndarray, np.ndarray, bool, bool, bool], np.ndarray]:
    """Numba overload of _lu_solve, delegating to the _getrs overload."""
    ensure_lapack()
    _check_scipy_linalg_matrix(lu_and_piv[0], "lu_solve")
    _check_scipy_linalg_matrix(b, "lu_solve")

    # NOTE(review): the returned impl takes (lu, piv) as separate arguments
    # while the overload target takes a single tuple — presumably the tuple is
    # unpacked before the call at the dispatch site; confirm against callers.
    def impl(
        lu: np.ndarray,
        piv: np.ndarray,
        b: np.ndarray,
        trans: int,
        overwrite_b: bool,
        check_finite: bool,
    ) -> np.ndarray:
        n = np.int32(lu.shape[0])

        X, INFO = _getrs(LU=lu, B=b, IPIV=piv, trans=trans, overwrite_b=overwrite_b)
        # Raise if LAPACK reported a problem (bad argument / singular system).
        _solve_check(n, INFO)

        return X

    return impl
pytensor/link/numba/dispatch/slinalg.py
浏览文件 @
e98cbbcf
...
...
@@ -4,6 +4,13 @@ import numpy as np
from
pytensor.link.numba.dispatch.basic
import
numba_funcify
,
numba_njit
from
pytensor.link.numba.dispatch.linalg.decomposition.cholesky
import
_cholesky
from
pytensor.link.numba.dispatch.linalg.decomposition.lu
import
(
_lu_1
,
_lu_2
,
_lu_3
,
_pivot_to_permutation
,
)
from
pytensor.link.numba.dispatch.linalg.decomposition.lu_factor
import
_lu_factor
from
pytensor.link.numba.dispatch.linalg.solve.cholesky
import
_cho_solve
from
pytensor.link.numba.dispatch.linalg.solve.general
import
_solve_gen
from
pytensor.link.numba.dispatch.linalg.solve.posdef
import
_solve_psd
...
...
@@ -11,9 +18,12 @@ from pytensor.link.numba.dispatch.linalg.solve.symmetric import _solve_symmetric
from
pytensor.link.numba.dispatch.linalg.solve.triangular
import
_solve_triangular
from
pytensor.link.numba.dispatch.linalg.solve.tridiagonal
import
_solve_tridiagonal
from
pytensor.tensor.slinalg
import
(
LU
,
BlockDiagonal
,
Cholesky
,
CholeskySolve
,
LUFactor
,
PivotToPermutations
,
Solve
,
SolveTriangular
,
)
...
...
@@ -70,6 +80,96 @@ def numba_funcify_Cholesky(op, node, **kwargs):
return
cholesky
@numba_funcify.register(PivotToPermutations)
def pivot_to_permutation(op, node, **kwargs):
    """Numba dispatch for the PivotToPermutations Op."""
    return_inverse = op.inverse
    in_dtype = node.inputs[0].dtype

    @numba_njit
    def numba_pivot_to_permutation(piv):
        inv_perm = _pivot_to_permutation(piv, in_dtype)
        if not return_inverse:
            return np.argsort(inv_perm)
        return inv_perm

    return numba_pivot_to_permutation
@numba_funcify.register(LU)
def numba_funcify_LU(op, node, **kwargs):
    """Numba dispatch for the LU Op.

    Routes to one of three scipy.linalg.lu wrapper variants, since each flag
    combination yields a different output signature.
    """
    permute_l = op.permute_l
    check_finite = op.check_finite
    p_indices = op.p_indices
    overwrite_a = op.overwrite_a
    dtype = node.inputs[0].dtype

    if dtype in complex_dtypes:
        # Bug fix: the exception was previously constructed but never raised,
        # silently allowing unsupported complex inputs through.
        raise NotImplementedError(_COMPLEX_DTYPE_NOT_SUPPORTED_MSG.format(op=op))

    @numba_njit(inline="always")
    def lu(a):
        if check_finite:
            if np.any(np.bitwise_or(np.isinf(a), np.isnan(a))):
                raise np.linalg.LinAlgError(
                    "Non-numeric values (nan or inf) found in input to lu"
                )

        if p_indices:
            # (perm, L, U) with integer row-permutation indices
            res = _lu_1(
                a,
                permute_l=permute_l,
                check_finite=check_finite,
                p_indices=p_indices,
                overwrite_a=overwrite_a,
            )
        elif permute_l:
            # (PL, U) with the permutation folded into L
            res = _lu_2(
                a,
                permute_l=permute_l,
                check_finite=check_finite,
                p_indices=p_indices,
                overwrite_a=overwrite_a,
            )
        else:
            # (P, L, U) with an explicit permutation matrix
            res = _lu_3(
                a,
                permute_l=permute_l,
                check_finite=check_finite,
                p_indices=p_indices,
                overwrite_a=overwrite_a,
            )

        return res

    return lu
@numba_funcify.register(LUFactor)
def numba_funcify_LUFactor(op, node, **kwargs):
    """Numba dispatch for the LUFactor Op."""
    dtype = node.inputs[0].dtype
    check_finite = op.check_finite
    overwrite_a = op.overwrite_a

    if dtype in complex_dtypes:
        # Bug fix: the exception was previously constructed but never raised.
        raise NotImplementedError(_COMPLEX_DTYPE_NOT_SUPPORTED_MSG.format(op=op))

    @numba_njit
    def lu_factor(a):
        if check_finite:
            if np.any(np.bitwise_or(np.isinf(a), np.isnan(a))):
                # Bug fix: message previously named "cholesky" (copy-paste
                # from the Cholesky dispatcher); corrected to lu_factor.
                raise np.linalg.LinAlgError(
                    "Non-numeric values (nan or inf) found in input to lu_factor"
                )

        LU, piv = _lu_factor(a, overwrite_a)
        return LU, piv

    return lu_factor
@numba_funcify.register
(
BlockDiagonal
)
def
numba_funcify_BlockDiagonal
(
op
,
node
,
**
kwargs
):
dtype
=
node
.
outputs
[
0
]
.
dtype
...
...
tests/link/numba/test_slinalg.py
浏览文件 @
e98cbbcf
...
...
@@ -8,7 +8,14 @@ import scipy
import
pytensor
import
pytensor.tensor
as
pt
from
pytensor
import
In
,
config
from
pytensor.tensor.slinalg
import
Cholesky
,
CholeskySolve
,
Solve
,
SolveTriangular
from
pytensor.tensor.slinalg
import
(
LU
,
Cholesky
,
CholeskySolve
,
LUFactor
,
Solve
,
SolveTriangular
,
)
from
tests.link.numba.test_basic
import
compare_numba_and_py
,
numba_inplace_mode
...
...
@@ -494,3 +501,222 @@ def test_block_diag():
C_val
=
np
.
random
.
normal
(
size
=
(
2
,
2
))
.
astype
(
floatX
)
D_val
=
np
.
random
.
normal
(
size
=
(
4
,
4
))
.
astype
(
floatX
)
compare_numba_and_py
([
A
,
B
,
C
,
D
],
[
X
],
[
A_val
,
B_val
,
C_val
,
D_val
])
@pytest.mark.parametrize("inverse", [True, False], ids=["p_inv", "p"])
def test_pivot_to_permutation(inverse):
    """Check the NUMBA-mode pivot_to_permutation against a hand-rolled reference."""
    from pytensor.tensor.slinalg import pivot_to_permutation

    rng = np.random.default_rng(123)
    A = rng.normal(size=(5, 5)).astype(floatX)

    perm_pt = pt.vector("p", dtype="int32")
    piv_pt = pivot_to_permutation(perm_pt, inverse=inverse)
    f = pytensor.function([perm_pt], piv_pt, mode="NUMBA")

    # Real pivot vector produced by scipy's LU factorization.
    _, piv = scipy.linalg.lu_factor(A)

    if inverse:
        # Reference inverse permutation: replay the LAPACK row swaps.
        p = np.arange(len(piv))
        for i in range(len(piv)):
            p[i], p[piv[i]] = p[piv[i]], p[i]

        np.testing.assert_allclose(f(piv), p)
    else:
        # Forward permutation as reported by scipy.linalg.lu itself.
        p, *_ = scipy.linalg.lu(A, p_indices=True)
        np.testing.assert_allclose(f(piv), p)
@pytest.mark.parametrize(
    "permute_l, p_indices",
    [(True, False), (False, True), (False, False)],
    ids=["PL", "p_indices", "P"],
)
@pytest.mark.parametrize(
    "overwrite_a", [True, False], ids=["overwrite_a", "no_overwrite"]
)
def test_lu(permute_l, p_indices, overwrite_a):
    """Compare numba and python LU outputs and verify in-place (destroy_map) semantics."""
    shape = (5, 5)
    rng = np.random.default_rng()
    A = pt.tensor(
        "A",
        shape=shape,
        dtype=config.floatX,
    )
    A_val = rng.normal(size=shape).astype(config.floatX)

    lu_outputs = pt.linalg.lu(A, permute_l=permute_l, p_indices=p_indices)

    fn, res = compare_numba_and_py(
        [In(A, mutable=overwrite_a)],
        lu_outputs,
        [A_val],
        numba_mode=numba_inplace_mode,
        inplace=True,
    )

    op = fn.maker.fgraph.outputs[0].owner.op
    assert isinstance(op, LU)

    destroy_map = op.destroy_map

    # With permute_l the first output (PL) reuses A's buffer; otherwise the
    # second output does. No overwrite -> nothing is destroyed.
    if overwrite_a and permute_l:
        assert destroy_map == {0: [0]}
    elif overwrite_a:
        assert destroy_map == {1: [0]}
    else:
        assert destroy_map == {}

    # Test F-contiguous input
    val_f_contig = np.copy(A_val, order="F")
    res_f_contig = fn(val_f_contig)

    for x, x_f_contig in zip(res, res_f_contig, strict=True):
        np.testing.assert_allclose(x, x_f_contig)

    # Should always be destroyable
    assert (A_val == val_f_contig).all() == (not overwrite_a)

    # Test C-contiguous input
    val_c_contig = np.copy(A_val, order="C")
    res_c_contig = fn(val_c_contig)
    for x, x_c_contig in zip(res, res_c_contig, strict=True):
        np.testing.assert_allclose(x, x_c_contig)

    # Cannot destroy C-contiguous input
    np.testing.assert_allclose(val_c_contig, A_val)

    # Test non-contiguous input
    val_not_contig = np.repeat(A_val, 2, axis=0)[::2]
    res_not_contig = fn(val_not_contig)
    for x, x_not_contig in zip(res, res_not_contig, strict=True):
        np.testing.assert_allclose(x, x_not_contig)

    # Cannot destroy non-contiguous input
    np.testing.assert_allclose(val_not_contig, A_val)
@pytest.mark.parametrize(
    "overwrite_a", [True, False], ids=["overwrite_a", "no_overwrite"]
)
def test_lu_factor(overwrite_a):
    """Compare numba and python lu_factor outputs and verify in-place semantics."""
    shape = (5, 5)
    rng = np.random.default_rng()

    A = pt.tensor("A", shape=shape, dtype=config.floatX)
    A_val = rng.normal(size=shape).astype(config.floatX)

    LU, piv = pt.linalg.lu_factor(A)

    fn, res = compare_numba_and_py(
        [In(A, mutable=overwrite_a)],
        [LU, piv],
        [A_val],
        numba_mode=numba_inplace_mode,
        inplace=True,
    )

    op = fn.maker.fgraph.outputs[0].owner.op
    assert isinstance(op, LUFactor)

    # The packed LU output (input index 0) may reuse A's buffer.
    if overwrite_a:
        assert op.destroy_map == {1: [0]}

    # Test F-contiguous input
    val_f_contig = np.copy(A_val, order="F")
    res_f_contig = fn(val_f_contig)

    for x, x_f_contig in zip(res, res_f_contig, strict=True):
        np.testing.assert_allclose(x, x_f_contig)

    # Should always be destroyable
    assert (A_val == val_f_contig).all() == (not overwrite_a)

    # Test C-contiguous input
    val_c_contig = np.copy(A_val, order="C")
    res_c_contig = fn(val_c_contig)
    for x, x_c_contig in zip(res, res_c_contig, strict=True):
        np.testing.assert_allclose(x, x_c_contig)

    # Cannot destroy C-contiguous input
    np.testing.assert_allclose(val_c_contig, A_val)

    # Test non-contiguous input
    val_not_contig = np.repeat(A_val, 2, axis=0)[::2]
    res_not_contig = fn(val_not_contig)
    for x, x_not_contig in zip(res, res_not_contig, strict=True):
        np.testing.assert_allclose(x, x_not_contig)

    # Cannot destroy non-contiguous input
    np.testing.assert_allclose(val_not_contig, A_val)
@pytest.mark.parametrize("trans", [True, False], ids=lambda x: f"trans = {x}")
@pytest.mark.parametrize(
    "overwrite_b", [False, True], ids=["no_overwrite", "overwrite_b"]
)
@pytest.mark.parametrize(
    "b_func, b_shape",
    [(pt.matrix, (5, 1)), (pt.matrix, (5, 5)), (pt.vector, (5,))],
    ids=["b_col_vec", "b_matrix", "b_vec"],
)
def test_lu_solve(b_func, b_shape: tuple[int, ...], trans: bool, overwrite_b: bool):
    """Compare numba and python lu_solve and verify when b may be overwritten in place."""
    A = pt.matrix("A", dtype=floatX)
    b = pt.tensor("b", shape=b_shape, dtype=floatX)

    rng = np.random.default_rng(418)
    A_val = rng.normal(size=(5, 5)).astype(floatX)
    b_val = rng.normal(size=b_shape).astype(floatX)

    lu_and_piv = pt.linalg.lu_factor(A)
    X = pt.linalg.lu_solve(
        lu_and_piv,
        b,
        b_ndim=len(b_shape),
        trans=trans,
    )

    f, res = compare_numba_and_py(
        [A, In(b, mutable=overwrite_b)],
        X,
        test_inputs=[A_val, b_val],
        inplace=True,
        numba_mode=numba_inplace_mode,
        eval_obj_mode=False,
    )

    # Test with F_contiguous inputs
    A_val_f_contig = np.copy(A_val, order="F")
    b_val_f_contig = np.copy(b_val, order="F")
    res_f_contig = f(A_val_f_contig, b_val_f_contig)
    np.testing.assert_allclose(res_f_contig, res)

    all_equal = (b_val == b_val_f_contig).all()
    # NOTE(review): b is only destroyed in the trans case here — presumably
    # the non-trans path makes an internal copy; confirm against the Op.
    should_destroy = overwrite_b and trans

    if should_destroy:
        assert not all_equal
    else:
        assert all_equal

    # Test with C_contiguous inputs
    A_val_c_contig = np.copy(A_val, order="C")
    b_val_c_contig = np.copy(b_val, order="C")
    res_c_contig = f(A_val_c_contig, b_val_c_contig)
    np.testing.assert_allclose(res_c_contig, res)
    np.testing.assert_allclose(A_val_c_contig, A_val)

    # b c_contiguous vectors are also f_contiguous and destroyable
    assert not (
        should_destroy and b_val_c_contig.flags.f_contiguous
    ) == np.allclose(b_val_c_contig, b_val)

    # Test with non-contiguous inputs
    A_val_not_contig = np.repeat(A_val, 2, axis=0)[::2]
    b_val_not_contig = np.repeat(b_val, 2, axis=0)[::2]
    res_not_contig = f(A_val_not_contig, b_val_not_contig)
    np.testing.assert_allclose(res_not_contig, res)
    np.testing.assert_allclose(A_val_not_contig, A_val)

    # Can never destroy non-contiguous inputs
    np.testing.assert_allclose(b_val_not_contig, b_val)
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论