提交 3ed908b2 authored 作者: Adrian Seyboldt's avatar Adrian Seyboldt 提交者: Adrian Seyboldt

Allow svd(compute_uv=False) in numba

上级 429ba6c9
......@@ -25,31 +25,20 @@ from pytensor.tensor.nlinalg import (
def numba_funcify_SVD(op, node, **kwargs):
full_matrices = op.full_matrices
compute_uv = op.compute_uv
out_dtype = np.dtype(node.outputs[0].dtype)
if not compute_uv:
warnings.warn(
(
"Numba will use object mode to allow the "
"`compute_uv` argument to `numpy.linalg.svd`."
),
UserWarning,
)
inputs_cast = int_to_float_fn(node.inputs, out_dtype)
ret_sig = get_numba_type(node.outputs[0].type)
if not compute_uv:
@numba_basic.numba_njit
@numba_basic.numba_njit()
def svd(x):
with numba.objmode(ret=ret_sig):
ret = np.linalg.svd(x, full_matrices, compute_uv)
_, ret, _ = np.linalg.svd(inputs_cast(x), full_matrices)
return ret
else:
out_dtype = node.outputs[0].type.numpy_dtype
inputs_cast = int_to_float_fn(node.inputs, out_dtype)
@numba_basic.numba_njit(inline="always")
@numba_basic.numba_njit()
def svd(x):
return np.linalg.svd(inputs_cast(x), full_matrices)
......
......@@ -477,7 +477,7 @@ def test_QRFull(x, mode, exc):
),
True,
False,
UserWarning,
None,
),
],
)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论