提交 ced9939d authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Numba linalg: Fallback to objmode with complex inputs

上级 f6986e40
...@@ -5,6 +5,7 @@ import numpy as np ...@@ -5,6 +5,7 @@ import numpy as np
from pytensor import config from pytensor import config
from pytensor.link.numba.dispatch import basic as numba_basic from pytensor.link.numba.dispatch import basic as numba_basic
from pytensor.link.numba.dispatch.basic import ( from pytensor.link.numba.dispatch.basic import (
generate_fallback_impl,
numba_funcify, numba_funcify,
register_funcify_default_op_cache_key, register_funcify_default_op_cache_key,
) )
...@@ -44,12 +45,6 @@ from pytensor.tensor.slinalg import ( ...@@ -44,12 +45,6 @@ from pytensor.tensor.slinalg import (
from pytensor.tensor.type import complex_dtypes, integer_dtypes from pytensor.tensor.type import complex_dtypes, integer_dtypes
_COMPLEX_DTYPE_NOT_SUPPORTED_MSG = (
"Complex dtype for {op} not supported in numba mode. "
"If you need this functionality, please open an issue at: https://github.com/pymc-devs/pytensor"
)
@numba_funcify.register(Cholesky) @numba_funcify.register(Cholesky)
def numba_funcify_Cholesky(op, node, **kwargs): def numba_funcify_Cholesky(op, node, **kwargs):
""" """
...@@ -65,7 +60,7 @@ def numba_funcify_Cholesky(op, node, **kwargs): ...@@ -65,7 +60,7 @@ def numba_funcify_Cholesky(op, node, **kwargs):
inp_dtype = node.inputs[0].type.numpy_dtype inp_dtype = node.inputs[0].type.numpy_dtype
if inp_dtype.kind == "c": if inp_dtype.kind == "c":
raise NotImplementedError(_COMPLEX_DTYPE_NOT_SUPPORTED_MSG.format(op=op)) return generate_fallback_impl(op, node=node, **kwargs)
discrete_inp = inp_dtype.kind in "ibu" discrete_inp = inp_dtype.kind in "ibu"
if discrete_inp and config.compiler_verbose: if discrete_inp and config.compiler_verbose:
print("Cholesky requires casting discrete input to float") # noqa: T201 print("Cholesky requires casting discrete input to float") # noqa: T201
...@@ -125,7 +120,7 @@ def pivot_to_permutation(op, node, **kwargs): ...@@ -125,7 +120,7 @@ def pivot_to_permutation(op, node, **kwargs):
def numba_funcify_LU(op, node, **kwargs): def numba_funcify_LU(op, node, **kwargs):
inp_dtype = node.inputs[0].type.numpy_dtype inp_dtype = node.inputs[0].type.numpy_dtype
if inp_dtype.kind == "c": if inp_dtype.kind == "c":
raise NotImplementedError(_COMPLEX_DTYPE_NOT_SUPPORTED_MSG.format(op=op)) return generate_fallback_impl(op, node=node, **kwargs)
discrete_inp = inp_dtype.kind in "ibu" discrete_inp = inp_dtype.kind in "ibu"
if discrete_inp and config.compiler_verbose: if discrete_inp and config.compiler_verbose:
print("LU requires casting discrete input to float") # noqa: T201 print("LU requires casting discrete input to float") # noqa: T201
...@@ -192,7 +187,7 @@ def numba_funcify_LU(op, node, **kwargs): ...@@ -192,7 +187,7 @@ def numba_funcify_LU(op, node, **kwargs):
def numba_funcify_LUFactor(op, node, **kwargs): def numba_funcify_LUFactor(op, node, **kwargs):
inp_dtype = node.inputs[0].type.numpy_dtype inp_dtype = node.inputs[0].type.numpy_dtype
if inp_dtype.kind == "c": if inp_dtype.kind == "c":
NotImplementedError(_COMPLEX_DTYPE_NOT_SUPPORTED_MSG.format(op=op)) return generate_fallback_impl(op, node=node, **kwargs)
discrete_inp = inp_dtype.kind in "ibu" discrete_inp = inp_dtype.kind in "ibu"
if discrete_inp and config.compiler_verbose: if discrete_inp and config.compiler_verbose:
print("LUFactor requires casting discrete input to float") # noqa: T201 print("LUFactor requires casting discrete input to float") # noqa: T201
...@@ -252,7 +247,7 @@ def numba_funcify_Solve(op, node, **kwargs): ...@@ -252,7 +247,7 @@ def numba_funcify_Solve(op, node, **kwargs):
out_dtype = node.outputs[0].type.numpy_dtype out_dtype = node.outputs[0].type.numpy_dtype
if A_dtype.kind == "c" or b_dtype.kind == "c": if A_dtype.kind == "c" or b_dtype.kind == "c":
raise NotImplementedError(_COMPLEX_DTYPE_NOT_SUPPORTED_MSG.format(op=op)) raise generate_fallback_impl(op, node=node, **kwargs)
must_cast_A = A_dtype != out_dtype must_cast_A = A_dtype != out_dtype
if must_cast_A and config.compiler_verbose: if must_cast_A and config.compiler_verbose:
print("Solve requires casting first input `A`") # noqa: T201 print("Solve requires casting first input `A`") # noqa: T201
...@@ -326,7 +321,7 @@ def numba_funcify_SolveTriangular(op, node, **kwargs): ...@@ -326,7 +321,7 @@ def numba_funcify_SolveTriangular(op, node, **kwargs):
out_dtype = node.outputs[0].type.numpy_dtype out_dtype = node.outputs[0].type.numpy_dtype
if A_dtype.kind == "c" or b_dtype.kind == "c": if A_dtype.kind == "c" or b_dtype.kind == "c":
raise NotImplementedError(_COMPLEX_DTYPE_NOT_SUPPORTED_MSG.format(op=op)) raise generate_fallback_impl(op, node=node, **kwargs)
must_cast_A = A_dtype != out_dtype must_cast_A = A_dtype != out_dtype
if must_cast_A and config.compiler_verbose: if must_cast_A and config.compiler_verbose:
print("SolveTriangular requires casting first input `A`") # noqa: T201 print("SolveTriangular requires casting first input `A`") # noqa: T201
...@@ -377,7 +372,7 @@ def numba_funcify_CholeskySolve(op, node, **kwargs): ...@@ -377,7 +372,7 @@ def numba_funcify_CholeskySolve(op, node, **kwargs):
out_dtype = node.outputs[0].type.numpy_dtype out_dtype = node.outputs[0].type.numpy_dtype
if c_dtype.kind == "c" or b_dtype.kind == "c": if c_dtype.kind == "c" or b_dtype.kind == "c":
raise NotImplementedError(_COMPLEX_DTYPE_NOT_SUPPORTED_MSG.format(op=op)) raise generate_fallback_impl(op, node=node, **kwargs)
must_cast_c = c_dtype != out_dtype must_cast_c = c_dtype != out_dtype
if must_cast_c and config.compiler_verbose: if must_cast_c and config.compiler_verbose:
print("CholeskySolve requires casting first input `c`") # noqa: T201 print("CholeskySolve requires casting first input `c`") # noqa: T201
...@@ -425,7 +420,7 @@ def numba_funcify_QR(op, node, **kwargs): ...@@ -425,7 +420,7 @@ def numba_funcify_QR(op, node, **kwargs):
dtype = node.inputs[0].dtype dtype = node.inputs[0].dtype
if dtype in complex_dtypes: if dtype in complex_dtypes:
raise NotImplementedError(_COMPLEX_DTYPE_NOT_SUPPORTED_MSG.format(op=op)) return generate_fallback_impl(op, node=node, **kwargs)
integer_input = dtype in integer_dtypes integer_input = dtype in integer_dtypes
in_dtype = config.floatX if integer_input else dtype in_dtype = config.floatX if integer_input else dtype
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论