提交 39704d10 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Add call for issue in not implemented complex lapack routines

上级 8c97bb20
......@@ -24,6 +24,13 @@ from pytensor.tensor.slinalg import (
Solve,
SolveTriangular,
)
from pytensor.tensor.type import complex_dtypes
_COMPLEX_DTYPE_NOT_SUPPORTED_MSG = (
"Complex dtype for {op} not supported in numba mode. "
"If you need this functionality, please open an issue at: https://github.com/pymc-devs/pytensor"
)
@numba_basic.numba_njit(inline="always")
......@@ -199,9 +206,9 @@ def numba_funcify_SolveTriangular(op, node, **kwargs):
b_ndim = op.b_ndim
dtype = node.inputs[0].dtype
if str(dtype).startswith("complex"):
if dtype in complex_dtypes:
raise NotImplementedError(
"Complex inputs not currently supported by solve_triangular in Numba mode"
_COMPLEX_DTYPE_NOT_SUPPORTED_MSG.format(op="Solve Triangular")
)
@numba_basic.numba_njit(inline="always")
......@@ -299,10 +306,8 @@ def numba_funcify_Cholesky(op, node, **kwargs):
on_error = op.on_error
dtype = node.inputs[0].dtype
if str(dtype).startswith("complex"):
raise NotImplementedError(
"Complex inputs not currently supported by cholesky in Numba mode"
)
if dtype in complex_dtypes:
raise NotImplementedError(_COMPLEX_DTYPE_NOT_SUPPORTED_MSG.format(op=op))
@numba_basic.numba_njit(inline="always")
def nb_cholesky(a):
......@@ -1089,10 +1094,8 @@ def numba_funcify_Solve(op, node, **kwargs):
transposed = False # TODO: Solve doesnt currently allow the transposed argument
dtype = node.inputs[0].dtype
if str(dtype).startswith("complex"):
raise NotImplementedError(
"Complex inputs not currently supported by solve in Numba mode"
)
if dtype in complex_dtypes:
raise NotImplementedError(_COMPLEX_DTYPE_NOT_SUPPORTED_MSG.format(op=op))
if assume_a == "gen":
solve_fn = _solve_gen
......@@ -1206,10 +1209,8 @@ def numba_funcify_CholeskySolve(op, node, **kwargs):
check_finite = op.check_finite
dtype = node.inputs[0].dtype
if str(dtype).startswith("complex"):
raise NotImplementedError(
"Complex inputs not currently supported by cho_solve in Numba mode"
)
if dtype in complex_dtypes:
raise NotImplementedError(_COMPLEX_DTYPE_NOT_SUPPORTED_MSG.format(op=op))
@numba_basic.numba_njit(inline="always")
def cho_solve(c, b):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论