提交 3ed908b2 authored 作者: Adrian Seyboldt's avatar Adrian Seyboldt 提交者: Adrian Seyboldt

Allow svd(compute_uv=False) in numba

上级 429ba6c9
...@@ -25,31 +25,20 @@ from pytensor.tensor.nlinalg import ( ...@@ -25,31 +25,20 @@ from pytensor.tensor.nlinalg import (
def numba_funcify_SVD(op, node, **kwargs): def numba_funcify_SVD(op, node, **kwargs):
full_matrices = op.full_matrices full_matrices = op.full_matrices
compute_uv = op.compute_uv compute_uv = op.compute_uv
out_dtype = np.dtype(node.outputs[0].dtype)
if not compute_uv: inputs_cast = int_to_float_fn(node.inputs, out_dtype)
warnings.warn(
(
"Numba will use object mode to allow the "
"`compute_uv` argument to `numpy.linalg.svd`."
),
UserWarning,
)
ret_sig = get_numba_type(node.outputs[0].type) if not compute_uv:
@numba_basic.numba_njit @numba_basic.numba_njit()
def svd(x): def svd(x):
with numba.objmode(ret=ret_sig): _, ret, _ = np.linalg.svd(inputs_cast(x), full_matrices)
ret = np.linalg.svd(x, full_matrices, compute_uv)
return ret return ret
else: else:
out_dtype = node.outputs[0].type.numpy_dtype @numba_basic.numba_njit()
inputs_cast = int_to_float_fn(node.inputs, out_dtype)
@numba_basic.numba_njit(inline="always")
def svd(x): def svd(x):
return np.linalg.svd(inputs_cast(x), full_matrices) return np.linalg.svd(inputs_cast(x), full_matrices)
......
...@@ -477,7 +477,7 @@ def test_QRFull(x, mode, exc): ...@@ -477,7 +477,7 @@ def test_QRFull(x, mode, exc):
), ),
True, True,
False, False,
UserWarning, None,
), ),
], ],
) )
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论