Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
d8943500
Unverified
提交
d8943500
authored
9月 05, 2025
作者:
Joren Hammudoglu
提交者:
GitHub
9月 05, 2025
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Add `scipy-stubs` as development depedency (#1598)
* add `scipy-stubs` as development depedency * fix `scipy-stubs` squigglies
上级
f33ea357
显示空白字符变更
内嵌
并排
正在显示
7 个修改的文件
包含
87 行增加
和
47 行删除
+87
-47
environment-osx-arm64.yml
environment-osx-arm64.yml
+1
-0
environment.yml
environment.yml
+1
-0
lu.py
pytensor/link/numba/dispatch/linalg/decomposition/lu.py
+10
-19
lu_factor.py
...sor/link/numba/dispatch/linalg/decomposition/lu_factor.py
+8
-2
qr.py
pytensor/link/numba/dispatch/linalg/decomposition/qr.py
+33
-16
lu_solve.py
pytensor/link/numba/dispatch/linalg/solve/lu_solve.py
+22
-8
abstract_conv.py
pytensor/tensor/conv/abstract_conv.py
+12
-2
没有找到文件。
environment-osx-arm64.yml
浏览文件 @
d8943500
...
@@ -26,6 +26,7 @@ dependencies:
...
@@ -26,6 +26,7 @@ dependencies:
-
diff-cover
-
diff-cover
-
mypy
-
mypy
-
types-setuptools
-
types-setuptools
-
scipy-stubs
-
pytest
-
pytest
-
pytest-cov
-
pytest-cov
-
pytest-xdist
-
pytest-xdist
...
...
environment.yml
浏览文件 @
d8943500
...
@@ -28,6 +28,7 @@ dependencies:
...
@@ -28,6 +28,7 @@ dependencies:
-
diff-cover
-
diff-cover
-
mypy
-
mypy
-
types-setuptools
-
types-setuptools
-
scipy-stubs
-
pytest
-
pytest
-
pytest-cov
-
pytest-cov
-
pytest-xdist
-
pytest-xdist
...
...
pytensor/link/numba/dispatch/linalg/decomposition/lu.py
浏览文件 @
d8943500
from
collections.abc
import
Callable
from
collections.abc
import
Callable
from
typing
import
cast
as
typing_cast
from
typing
import
Literal
import
numpy
as
np
import
numpy
as
np
from
numba
import
njit
as
numba_njit
from
numba
import
njit
as
numba_njit
...
@@ -37,9 +37,9 @@ def _lu_factor_to_lu(a, dtype, overwrite_a):
...
@@ -37,9 +37,9 @@ def _lu_factor_to_lu(a, dtype, overwrite_a):
def
_lu_1
(
def
_lu_1
(
a
:
np
.
ndarray
,
a
:
np
.
ndarray
,
permute_l
:
bool
,
permute_l
:
Literal
[
True
]
,
check_finite
:
bool
,
check_finite
:
bool
,
p_indices
:
bool
,
p_indices
:
Literal
[
False
]
,
overwrite_a
:
bool
,
overwrite_a
:
bool
,
)
->
tuple
[
np
.
ndarray
,
np
.
ndarray
,
np
.
ndarray
]:
)
->
tuple
[
np
.
ndarray
,
np
.
ndarray
,
np
.
ndarray
]:
"""
"""
...
@@ -48,23 +48,20 @@ def _lu_1(
...
@@ -48,23 +48,20 @@ def _lu_1(
Called when permute_l is True and p_indices is False, and returns a tuple of (perm, L, U), where perm an integer
Called when permute_l is True and p_indices is False, and returns a tuple of (perm, L, U), where perm an integer
array of row swaps, such that L[perm] @ U = A.
array of row swaps, such that L[perm] @ U = A.
"""
"""
return
typing_cast
(
return
linalg
.
lu
(
tuple
[
np
.
ndarray
,
np
.
ndarray
,
np
.
ndarray
],
linalg
.
lu
(
a
,
a
,
permute_l
=
permute_l
,
permute_l
=
permute_l
,
check_finite
=
check_finite
,
check_finite
=
check_finite
,
p_indices
=
p_indices
,
p_indices
=
p_indices
,
overwrite_a
=
overwrite_a
,
overwrite_a
=
overwrite_a
,
),
)
)
def
_lu_2
(
def
_lu_2
(
a
:
np
.
ndarray
,
a
:
np
.
ndarray
,
permute_l
:
bool
,
permute_l
:
Literal
[
False
]
,
check_finite
:
bool
,
check_finite
:
bool
,
p_indices
:
bool
,
p_indices
:
Literal
[
True
]
,
overwrite_a
:
bool
,
overwrite_a
:
bool
,
)
->
tuple
[
np
.
ndarray
,
np
.
ndarray
]:
)
->
tuple
[
np
.
ndarray
,
np
.
ndarray
]:
"""
"""
...
@@ -73,23 +70,20 @@ def _lu_2(
...
@@ -73,23 +70,20 @@ def _lu_2(
Called when permute_l is False and p_indices is True, and returns a tuple of (PL, U), where PL is the
Called when permute_l is False and p_indices is True, and returns a tuple of (PL, U), where PL is the
permuted L matrix, PL = P @ L.
permuted L matrix, PL = P @ L.
"""
"""
return
typing_cast
(
return
linalg
.
lu
(
tuple
[
np
.
ndarray
,
np
.
ndarray
],
linalg
.
lu
(
a
,
a
,
permute_l
=
permute_l
,
permute_l
=
permute_l
,
check_finite
=
check_finite
,
check_finite
=
check_finite
,
p_indices
=
p_indices
,
p_indices
=
p_indices
,
overwrite_a
=
overwrite_a
,
overwrite_a
=
overwrite_a
,
),
)
)
def
_lu_3
(
def
_lu_3
(
a
:
np
.
ndarray
,
a
:
np
.
ndarray
,
permute_l
:
bool
,
permute_l
:
Literal
[
False
]
,
check_finite
:
bool
,
check_finite
:
bool
,
p_indices
:
bool
,
p_indices
:
Literal
[
False
]
,
overwrite_a
:
bool
,
overwrite_a
:
bool
,
)
->
tuple
[
np
.
ndarray
,
np
.
ndarray
,
np
.
ndarray
]:
)
->
tuple
[
np
.
ndarray
,
np
.
ndarray
,
np
.
ndarray
]:
"""
"""
...
@@ -98,15 +92,12 @@ def _lu_3(
...
@@ -98,15 +92,12 @@ def _lu_3(
Called when permute_l is False and p_indices is False, and returns a tuple of (P, L, U), where P is the permutation
Called when permute_l is False and p_indices is False, and returns a tuple of (P, L, U), where P is the permutation
matrix, P @ L @ U = A.
matrix, P @ L @ U = A.
"""
"""
return
typing_cast
(
return
linalg
.
lu
(
tuple
[
np
.
ndarray
,
np
.
ndarray
,
np
.
ndarray
],
linalg
.
lu
(
a
,
a
,
permute_l
=
permute_l
,
permute_l
=
permute_l
,
check_finite
=
check_finite
,
check_finite
=
check_finite
,
p_indices
=
p_indices
,
p_indices
=
p_indices
,
overwrite_a
=
overwrite_a
,
overwrite_a
=
overwrite_a
,
),
)
)
...
...
pytensor/link/numba/dispatch/linalg/decomposition/lu_factor.py
浏览文件 @
d8943500
from
collections.abc
import
Callable
from
collections.abc
import
Callable
from
typing
import
cast
as
typing_cast
import
numpy
as
np
import
numpy
as
np
from
numba.core.extending
import
overload
from
numba.core.extending
import
overload
...
@@ -21,8 +22,13 @@ def _getrf(A, overwrite_a=False) -> tuple[np.ndarray, np.ndarray, int]:
...
@@ -21,8 +22,13 @@ def _getrf(A, overwrite_a=False) -> tuple[np.ndarray, np.ndarray, int]:
Underlying LAPACK function used for LU factorization. Compared to scipy.linalg.lu_factorize, this function also
Underlying LAPACK function used for LU factorization. Compared to scipy.linalg.lu_factorize, this function also
returns an info code with diagnostic information.
returns an info code with diagnostic information.
"""
"""
(
getrf
,)
=
linalg
.
get_lapack_funcs
(
"getrf"
,
(
A
,))
funcs
=
linalg
.
get_lapack_funcs
(
"getrf"
,
(
A
,))
A_copy
,
ipiv
,
info
=
getrf
(
A
,
overwrite_a
=
overwrite_a
)
assert
isinstance
(
funcs
,
list
)
# narrows `funcs: list[F] | F` to `funcs: list[F]`
getrf
=
funcs
[
0
]
A_copy
,
ipiv
,
info
=
typing_cast
(
tuple
[
np
.
ndarray
,
np
.
ndarray
,
int
],
getrf
(
A
,
overwrite_a
=
overwrite_a
)
)
return
A_copy
,
ipiv
,
info
return
A_copy
,
ipiv
,
info
...
...
pytensor/link/numba/dispatch/linalg/decomposition/qr.py
浏览文件 @
d8943500
from
typing
import
Literal
import
numpy
as
np
import
numpy
as
np
from
numba.core.extending
import
overload
from
numba.core.extending
import
overload
from
numba.np.linalg
import
_copy_to_fortran_order
,
ensure_lapack
from
numba.np.linalg
import
_copy_to_fortran_order
,
ensure_lapack
...
@@ -13,7 +15,13 @@ from pytensor.link.numba.dispatch.linalg._LAPACK import (
...
@@ -13,7 +15,13 @@ from pytensor.link.numba.dispatch.linalg._LAPACK import (
def
_xgeqrf
(
A
:
np
.
ndarray
,
overwrite_a
:
bool
,
lwork
:
int
):
def
_xgeqrf
(
A
:
np
.
ndarray
,
overwrite_a
:
bool
,
lwork
:
int
):
"""LAPACK geqrf: Computes a QR factorization of a general M-by-N matrix A."""
"""LAPACK geqrf: Computes a QR factorization of a general M-by-N matrix A."""
(
geqrf
,)
=
get_lapack_funcs
((
"geqrf"
,),
(
A
,))
# (geqrf,) = typing_cast(
# list[Callable[..., np.ndarray]], get_lapack_funcs(("geqrf",), (A,))
# )
funcs
=
get_lapack_funcs
((
"geqrf"
,),
(
A
,))
assert
isinstance
(
funcs
,
list
)
# narrows `funcs: list[F] | F` to `funcs: list[F]`
geqrf
=
funcs
[
0
]
return
geqrf
(
A
,
overwrite_a
=
overwrite_a
,
lwork
=
lwork
)
return
geqrf
(
A
,
overwrite_a
=
overwrite_a
,
lwork
=
lwork
)
...
@@ -61,7 +69,10 @@ def xgeqrf_impl(A, overwrite_a, lwork):
...
@@ -61,7 +69,10 @@ def xgeqrf_impl(A, overwrite_a, lwork):
def
_xgeqp3
(
A
:
np
.
ndarray
,
overwrite_a
:
bool
,
lwork
:
int
):
def
_xgeqp3
(
A
:
np
.
ndarray
,
overwrite_a
:
bool
,
lwork
:
int
):
"""LAPACK geqp3: Computes a QR factorization with column pivoting of a general M-by-N matrix A."""
"""LAPACK geqp3: Computes a QR factorization with column pivoting of a general M-by-N matrix A."""
(
geqp3
,)
=
get_lapack_funcs
((
"geqp3"
,),
(
A
,))
funcs
=
get_lapack_funcs
((
"geqp3"
,),
(
A
,))
assert
isinstance
(
funcs
,
list
)
# narrows `funcs: list[F] | F` to `funcs: list[F]`
geqp3
=
funcs
[
0
]
return
geqp3
(
A
,
overwrite_a
=
overwrite_a
,
lwork
=
lwork
)
return
geqp3
(
A
,
overwrite_a
=
overwrite_a
,
lwork
=
lwork
)
...
@@ -111,7 +122,10 @@ def xgeqp3_impl(A, overwrite_a, lwork):
...
@@ -111,7 +122,10 @@ def xgeqp3_impl(A, overwrite_a, lwork):
def
_xorgqr
(
A
:
np
.
ndarray
,
tau
:
np
.
ndarray
,
overwrite_a
:
bool
,
lwork
:
int
):
def
_xorgqr
(
A
:
np
.
ndarray
,
tau
:
np
.
ndarray
,
overwrite_a
:
bool
,
lwork
:
int
):
"""LAPACK orgqr: Generates the M-by-N matrix Q with orthonormal columns from a QR factorization (real types)."""
"""LAPACK orgqr: Generates the M-by-N matrix Q with orthonormal columns from a QR factorization (real types)."""
(
orgqr
,)
=
get_lapack_funcs
((
"orgqr"
,),
(
A
,))
funcs
=
get_lapack_funcs
((
"orgqr"
,),
(
A
,))
assert
isinstance
(
funcs
,
list
)
# narrows `funcs: list[F] | F` to `funcs: list[F]`
orgqr
=
funcs
[
0
]
return
orgqr
(
A
,
tau
,
overwrite_a
=
overwrite_a
,
lwork
=
lwork
)
return
orgqr
(
A
,
tau
,
overwrite_a
=
overwrite_a
,
lwork
=
lwork
)
...
@@ -160,7 +174,10 @@ def xorgqr_impl(A, tau, overwrite_a, lwork):
...
@@ -160,7 +174,10 @@ def xorgqr_impl(A, tau, overwrite_a, lwork):
def
_xungqr
(
A
:
np
.
ndarray
,
tau
:
np
.
ndarray
,
overwrite_a
:
bool
,
lwork
:
int
):
def
_xungqr
(
A
:
np
.
ndarray
,
tau
:
np
.
ndarray
,
overwrite_a
:
bool
,
lwork
:
int
):
"""LAPACK ungqr: Generates the M-by-N matrix Q with orthonormal columns from a QR factorization (complex types)."""
"""LAPACK ungqr: Generates the M-by-N matrix Q with orthonormal columns from a QR factorization (complex types)."""
(
ungqr
,)
=
get_lapack_funcs
((
"ungqr"
,),
(
A
,))
funcs
=
get_lapack_funcs
((
"ungqr"
,),
(
A
,))
assert
isinstance
(
funcs
,
list
)
# narrows `funcs: list[F] | F` to `funcs: list[F]`
ungqr
=
funcs
[
0
]
return
ungqr
(
A
,
tau
,
overwrite_a
=
overwrite_a
,
lwork
=
lwork
)
return
ungqr
(
A
,
tau
,
overwrite_a
=
overwrite_a
,
lwork
=
lwork
)
...
@@ -209,8 +226,8 @@ def xungqr_impl(A, tau, overwrite_a, lwork):
...
@@ -209,8 +226,8 @@ def xungqr_impl(A, tau, overwrite_a, lwork):
def
_qr_full_pivot
(
def
_qr_full_pivot
(
x
:
np
.
ndarray
,
x
:
np
.
ndarray
,
mode
:
str
=
"full"
,
mode
:
Literal
[
"full"
,
"economic"
]
=
"full"
,
pivoting
:
bool
=
True
,
pivoting
:
Literal
[
True
]
=
True
,
overwrite_a
:
bool
=
False
,
overwrite_a
:
bool
=
False
,
check_finite
:
bool
=
False
,
check_finite
:
bool
=
False
,
lwork
:
int
|
None
=
None
,
lwork
:
int
|
None
=
None
,
...
@@ -234,8 +251,8 @@ def _qr_full_pivot(
...
@@ -234,8 +251,8 @@ def _qr_full_pivot(
def
_qr_full_no_pivot
(
def
_qr_full_no_pivot
(
x
:
np
.
ndarray
,
x
:
np
.
ndarray
,
mode
:
str
=
"full"
,
mode
:
Literal
[
"full"
,
"economic"
]
=
"full"
,
pivoting
:
bool
=
False
,
pivoting
:
Literal
[
False
]
=
False
,
overwrite_a
:
bool
=
False
,
overwrite_a
:
bool
=
False
,
check_finite
:
bool
=
False
,
check_finite
:
bool
=
False
,
lwork
:
int
|
None
=
None
,
lwork
:
int
|
None
=
None
,
...
@@ -258,8 +275,8 @@ def _qr_full_no_pivot(
...
@@ -258,8 +275,8 @@ def _qr_full_no_pivot(
def
_qr_r_pivot
(
def
_qr_r_pivot
(
x
:
np
.
ndarray
,
x
:
np
.
ndarray
,
mode
:
str
=
"r"
,
mode
:
Literal
[
"r"
,
"raw"
]
=
"r"
,
pivoting
:
bool
=
True
,
pivoting
:
Literal
[
True
]
=
True
,
overwrite_a
:
bool
=
False
,
overwrite_a
:
bool
=
False
,
check_finite
:
bool
=
False
,
check_finite
:
bool
=
False
,
lwork
:
int
|
None
=
None
,
lwork
:
int
|
None
=
None
,
...
@@ -282,8 +299,8 @@ def _qr_r_pivot(
...
@@ -282,8 +299,8 @@ def _qr_r_pivot(
def
_qr_r_no_pivot
(
def
_qr_r_no_pivot
(
x
:
np
.
ndarray
,
x
:
np
.
ndarray
,
mode
:
str
=
"r"
,
mode
:
Literal
[
"r"
,
"raw"
]
=
"r"
,
pivoting
:
bool
=
False
,
pivoting
:
Literal
[
False
]
=
False
,
overwrite_a
:
bool
=
False
,
overwrite_a
:
bool
=
False
,
check_finite
:
bool
=
False
,
check_finite
:
bool
=
False
,
lwork
:
int
|
None
=
None
,
lwork
:
int
|
None
=
None
,
...
@@ -306,8 +323,8 @@ def _qr_r_no_pivot(
...
@@ -306,8 +323,8 @@ def _qr_r_no_pivot(
def
_qr_raw_no_pivot
(
def
_qr_raw_no_pivot
(
x
:
np
.
ndarray
,
x
:
np
.
ndarray
,
mode
:
str
=
"raw"
,
mode
:
Literal
[
"raw"
]
=
"raw"
,
pivoting
:
bool
=
False
,
pivoting
:
Literal
[
False
]
=
False
,
overwrite_a
:
bool
=
False
,
overwrite_a
:
bool
=
False
,
check_finite
:
bool
=
False
,
check_finite
:
bool
=
False
,
lwork
:
int
|
None
=
None
,
lwork
:
int
|
None
=
None
,
...
@@ -332,8 +349,8 @@ def _qr_raw_no_pivot(
...
@@ -332,8 +349,8 @@ def _qr_raw_no_pivot(
def
_qr_raw_pivot
(
def
_qr_raw_pivot
(
x
:
np
.
ndarray
,
x
:
np
.
ndarray
,
mode
:
str
=
"raw"
,
mode
:
Literal
[
"raw"
]
=
"raw"
,
pivoting
:
bool
=
True
,
pivoting
:
Literal
[
True
]
=
True
,
overwrite_a
:
bool
=
False
,
overwrite_a
:
bool
=
False
,
check_finite
:
bool
=
False
,
check_finite
:
bool
=
False
,
lwork
:
int
|
None
=
None
,
lwork
:
int
|
None
=
None
,
...
...
pytensor/link/numba/dispatch/linalg/solve/lu_solve.py
浏览文件 @
d8943500
from
collections.abc
import
Callable
from
collections.abc
import
Callable
from
typing
import
Literal
,
TypeAlias
import
numpy
as
np
import
numpy
as
np
from
numba.core.extending
import
overload
from
numba.core.extending
import
overload
...
@@ -20,8 +21,15 @@ from pytensor.link.numba.dispatch.linalg.utils import (
...
@@ -20,8 +21,15 @@ from pytensor.link.numba.dispatch.linalg.utils import (
)
)
_Trans
:
TypeAlias
=
Literal
[
0
,
1
,
2
]
def
_getrs
(
def
_getrs
(
LU
:
np
.
ndarray
,
B
:
np
.
ndarray
,
IPIV
:
np
.
ndarray
,
trans
:
int
,
overwrite_b
:
bool
LU
:
np
.
ndarray
,
B
:
np
.
ndarray
,
IPIV
:
np
.
ndarray
,
trans
:
_Trans
|
bool
,
# mypy does not realize that `bool <: Literal[0, 1]`
overwrite_b
:
bool
,
)
->
tuple
[
np
.
ndarray
,
int
]:
)
->
tuple
[
np
.
ndarray
,
int
]:
"""
"""
Placeholder for solving a linear system with a matrix that has been LU-factored. Used by linalg.lu_solve.
Placeholder for solving a linear system with a matrix that has been LU-factored. Used by linalg.lu_solve.
...
@@ -31,8 +39,10 @@ def _getrs(
...
@@ -31,8 +39,10 @@ def _getrs(
@overload
(
_getrs
)
@overload
(
_getrs
)
def
getrs_impl
(
def
getrs_impl
(
LU
:
np
.
ndarray
,
B
:
np
.
ndarray
,
IPIV
:
np
.
ndarray
,
trans
:
int
,
overwrite_b
:
bool
LU
:
np
.
ndarray
,
B
:
np
.
ndarray
,
IPIV
:
np
.
ndarray
,
trans
:
_Trans
,
overwrite_b
:
bool
)
->
Callable
[[
np
.
ndarray
,
np
.
ndarray
,
np
.
ndarray
,
int
,
bool
],
tuple
[
np
.
ndarray
,
int
]]:
)
->
Callable
[
[
np
.
ndarray
,
np
.
ndarray
,
np
.
ndarray
,
_Trans
,
bool
],
tuple
[
np
.
ndarray
,
int
]
]:
ensure_lapack
()
ensure_lapack
()
_check_scipy_linalg_matrix
(
LU
,
"getrs"
)
_check_scipy_linalg_matrix
(
LU
,
"getrs"
)
_check_scipy_linalg_matrix
(
B
,
"getrs"
)
_check_scipy_linalg_matrix
(
B
,
"getrs"
)
...
@@ -41,7 +51,11 @@ def getrs_impl(
...
@@ -41,7 +51,11 @@ def getrs_impl(
numba_getrs
=
_LAPACK
()
.
numba_xgetrs
(
dtype
)
numba_getrs
=
_LAPACK
()
.
numba_xgetrs
(
dtype
)
def
impl
(
def
impl
(
LU
:
np
.
ndarray
,
B
:
np
.
ndarray
,
IPIV
:
np
.
ndarray
,
trans
:
int
,
overwrite_b
:
bool
LU
:
np
.
ndarray
,
B
:
np
.
ndarray
,
IPIV
:
np
.
ndarray
,
trans
:
_Trans
,
overwrite_b
:
bool
,
)
->
tuple
[
np
.
ndarray
,
int
]:
)
->
tuple
[
np
.
ndarray
,
int
]:
_N
=
np
.
int32
(
LU
.
shape
[
-
1
])
_N
=
np
.
int32
(
LU
.
shape
[
-
1
])
_solve_check_input_shapes
(
LU
,
B
)
_solve_check_input_shapes
(
LU
,
B
)
...
@@ -89,7 +103,7 @@ def getrs_impl(
...
@@ -89,7 +103,7 @@ def getrs_impl(
def
_lu_solve
(
def
_lu_solve
(
lu_and_piv
:
tuple
[
np
.
ndarray
,
np
.
ndarray
],
lu_and_piv
:
tuple
[
np
.
ndarray
,
np
.
ndarray
],
b
:
np
.
ndarray
,
b
:
np
.
ndarray
,
trans
:
int
,
trans
:
_Trans
,
overwrite_b
:
bool
,
overwrite_b
:
bool
,
check_finite
:
bool
,
check_finite
:
bool
,
):
):
...
@@ -105,10 +119,10 @@ def _lu_solve(
...
@@ -105,10 +119,10 @@ def _lu_solve(
def
lu_solve_impl
(
def
lu_solve_impl
(
lu_and_piv
:
tuple
[
np
.
ndarray
,
np
.
ndarray
],
lu_and_piv
:
tuple
[
np
.
ndarray
,
np
.
ndarray
],
b
:
np
.
ndarray
,
b
:
np
.
ndarray
,
trans
:
int
,
trans
:
_Trans
,
overwrite_b
:
bool
,
overwrite_b
:
bool
,
check_finite
:
bool
,
check_finite
:
bool
,
)
->
Callable
[[
np
.
ndarray
,
np
.
ndarray
,
np
.
ndarray
,
bool
,
bool
,
bool
],
np
.
ndarray
]:
)
->
Callable
[[
np
.
ndarray
,
np
.
ndarray
,
np
.
ndarray
,
_Trans
,
bool
,
bool
],
np
.
ndarray
]:
ensure_lapack
()
ensure_lapack
()
_check_scipy_linalg_matrix
(
lu_and_piv
[
0
],
"lu_solve"
)
_check_scipy_linalg_matrix
(
lu_and_piv
[
0
],
"lu_solve"
)
_check_scipy_linalg_matrix
(
b
,
"lu_solve"
)
_check_scipy_linalg_matrix
(
b
,
"lu_solve"
)
...
@@ -117,7 +131,7 @@ def lu_solve_impl(
...
@@ -117,7 +131,7 @@ def lu_solve_impl(
lu
:
np
.
ndarray
,
lu
:
np
.
ndarray
,
piv
:
np
.
ndarray
,
piv
:
np
.
ndarray
,
b
:
np
.
ndarray
,
b
:
np
.
ndarray
,
trans
:
int
,
trans
:
_Trans
,
overwrite_b
:
bool
,
overwrite_b
:
bool
,
check_finite
:
bool
,
check_finite
:
bool
,
)
->
np
.
ndarray
:
)
->
np
.
ndarray
:
...
...
pytensor/tensor/conv/abstract_conv.py
浏览文件 @
d8943500
...
@@ -6,15 +6,25 @@ import logging
...
@@ -6,15 +6,25 @@ import logging
import
sys
import
sys
import
warnings
import
warnings
from
math
import
gcd
from
math
import
gcd
from
typing
import
TYPE_CHECKING
import
numpy
as
np
import
numpy
as
np
from
numpy.exceptions
import
ComplexWarning
from
numpy.exceptions
import
ComplexWarning
try
:
if
TYPE_CHECKING
:
# https://github.com/scipy/scipy-stubs/issues/851
from
scipy.signal._signaltools
import
(
# type: ignore[attr-defined]
_bvalfromboundary
,
_valfrommode
,
convolve
,
)
from
scipy.signal._sigtools
import
_convolve2d
else
:
try
:
from
scipy.signal.signaltools
import
_bvalfromboundary
,
_valfrommode
,
convolve
from
scipy.signal.signaltools
import
_bvalfromboundary
,
_valfrommode
,
convolve
from
scipy.signal.sigtools
import
_convolve2d
from
scipy.signal.sigtools
import
_convolve2d
except
ImportError
:
except
ImportError
:
from
scipy.signal._signaltools
import
_bvalfromboundary
,
_valfrommode
,
convolve
from
scipy.signal._signaltools
import
_bvalfromboundary
,
_valfrommode
,
convolve
from
scipy.signal._sigtools
import
_convolve2d
from
scipy.signal._sigtools
import
_convolve2d
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论