提交 39704d10 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Add call for issue in not implemented complex lapack routines

上级 8c97bb20
...@@ -24,6 +24,13 @@ from pytensor.tensor.slinalg import ( ...@@ -24,6 +24,13 @@ from pytensor.tensor.slinalg import (
Solve, Solve,
SolveTriangular, SolveTriangular,
) )
from pytensor.tensor.type import complex_dtypes
_COMPLEX_DTYPE_NOT_SUPPORTED_MSG = (
"Complex dtype for {op} not supported in numba mode. "
"If you need this functionality, please open an issue at: https://github.com/pymc-devs/pytensor"
)
@numba_basic.numba_njit(inline="always") @numba_basic.numba_njit(inline="always")
...@@ -199,9 +206,9 @@ def numba_funcify_SolveTriangular(op, node, **kwargs): ...@@ -199,9 +206,9 @@ def numba_funcify_SolveTriangular(op, node, **kwargs):
b_ndim = op.b_ndim b_ndim = op.b_ndim
dtype = node.inputs[0].dtype dtype = node.inputs[0].dtype
if str(dtype).startswith("complex"): if dtype in complex_dtypes:
raise NotImplementedError( raise NotImplementedError(
"Complex inputs not currently supported by solve_triangular in Numba mode" _COMPLEX_DTYPE_NOT_SUPPORTED_MSG.format(op="Solve Triangular")
) )
@numba_basic.numba_njit(inline="always") @numba_basic.numba_njit(inline="always")
...@@ -299,10 +306,8 @@ def numba_funcify_Cholesky(op, node, **kwargs): ...@@ -299,10 +306,8 @@ def numba_funcify_Cholesky(op, node, **kwargs):
on_error = op.on_error on_error = op.on_error
dtype = node.inputs[0].dtype dtype = node.inputs[0].dtype
if str(dtype).startswith("complex"): if dtype in complex_dtypes:
raise NotImplementedError( raise NotImplementedError(_COMPLEX_DTYPE_NOT_SUPPORTED_MSG.format(op=op))
"Complex inputs not currently supported by cholesky in Numba mode"
)
@numba_basic.numba_njit(inline="always") @numba_basic.numba_njit(inline="always")
def nb_cholesky(a): def nb_cholesky(a):
...@@ -1089,10 +1094,8 @@ def numba_funcify_Solve(op, node, **kwargs): ...@@ -1089,10 +1094,8 @@ def numba_funcify_Solve(op, node, **kwargs):
transposed = False # TODO: Solve doesnt currently allow the transposed argument transposed = False # TODO: Solve doesnt currently allow the transposed argument
dtype = node.inputs[0].dtype dtype = node.inputs[0].dtype
if str(dtype).startswith("complex"): if dtype in complex_dtypes:
raise NotImplementedError( raise NotImplementedError(_COMPLEX_DTYPE_NOT_SUPPORTED_MSG.format(op=op))
"Complex inputs not currently supported by solve in Numba mode"
)
if assume_a == "gen": if assume_a == "gen":
solve_fn = _solve_gen solve_fn = _solve_gen
...@@ -1206,10 +1209,8 @@ def numba_funcify_CholeskySolve(op, node, **kwargs): ...@@ -1206,10 +1209,8 @@ def numba_funcify_CholeskySolve(op, node, **kwargs):
check_finite = op.check_finite check_finite = op.check_finite
dtype = node.inputs[0].dtype dtype = node.inputs[0].dtype
if str(dtype).startswith("complex"): if dtype in complex_dtypes:
raise NotImplementedError( raise NotImplementedError(_COMPLEX_DTYPE_NOT_SUPPORTED_MSG.format(op=op))
"Complex inputs not currently supported by cho_solve in Numba mode"
)
@numba_basic.numba_njit(inline="always") @numba_basic.numba_njit(inline="always")
def cho_solve(c, b): def cho_solve(c, b):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论