提交 edb1b205 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Numba int_to_float: Remove buggy helper

* It did not handle complex values correctly * It increased compile time with the nested function
上级 d4b8260c
......@@ -224,36 +224,6 @@ def direct_cast(typingctx, val, typ):
return sig, codegen
def int_to_float_fn(inputs, out_dtype):
"""Create a Numba function that converts integer and boolean ``ndarray``s to floats."""
if (
all(inp.type.dtype == out_dtype for inp in inputs)
and np.dtype(out_dtype).kind == "f"
):
@numba_njit(inline="always")
def inputs_cast(x):
return x
elif any(i.type.numpy_dtype.kind in "uib" for i in inputs):
args_dtype = np.dtype(f"f{out_dtype.itemsize}")
@numba_njit(inline="always")
def inputs_cast(x):
return x.astype(args_dtype)
else:
args_dtype_sz = max(_arg.type.numpy_dtype.itemsize for _arg in inputs)
args_dtype = np.dtype(f"f{args_dtype_sz}")
@numba_njit(inline="always")
def inputs_cast(x):
return x.astype(args_dtype)
return inputs_cast
@singledispatch
def numba_typify(data, dtype=None, **kwargs):
return data
......
......@@ -4,9 +4,9 @@ import numba
import numpy as np
import pytensor.link.numba.dispatch.basic as numba_basic
from pytensor import config
from pytensor.link.numba.dispatch.basic import (
get_numba_type,
int_to_float_fn,
register_funcify_default_op_cache_key,
)
from pytensor.tensor.nlinalg import (
......@@ -26,65 +26,88 @@ def numba_funcify_SVD(op, node, **kwargs):
compute_uv = op.compute_uv
out_dtype = np.dtype(node.outputs[0].dtype)
inputs_cast = int_to_float_fn(node.inputs, out_dtype)
discrete_input = node.inputs[0].type.numpy_dtype.kind in "ibu"
if discrete_input and config.compiler_verbose:
print("SVD requires casting discrete input to float") # noqa: T201
if not compute_uv:
@numba_basic.numba_njit
def svd(x):
_, ret, _ = np.linalg.svd(inputs_cast(x), full_matrices)
if discrete_input:
x = x.astype(out_dtype)
_, ret, _ = np.linalg.svd(x, full_matrices)
return ret
else:
@numba_basic.numba_njit
def svd(x):
return np.linalg.svd(inputs_cast(x), full_matrices)
if discrete_input:
x = x.astype(out_dtype)
return np.linalg.svd(x, full_matrices)
return svd
cache_version = 1
return svd, cache_version
@register_funcify_default_op_cache_key(Det)
def numba_funcify_Det(op, node, **kwargs):
out_dtype = node.outputs[0].type.numpy_dtype
inputs_cast = int_to_float_fn(node.inputs, out_dtype)
discrete_input = node.inputs[0].type.numpy_dtype.kind in "ibu"
if discrete_input and config.compiler_verbose:
print("Det requires casting discrete input to float") # noqa: T201
@numba_basic.numba_njit
def det(x):
return np.array(np.linalg.det(inputs_cast(x))).astype(out_dtype)
if discrete_input:
x = x.astype(out_dtype)
return np.array(np.linalg.det(x), dtype=out_dtype)
return det
cache_version = 1
return det, cache_version
@register_funcify_default_op_cache_key(SLogDet)
def numba_funcify_SLogDet(op, node, **kwargs):
out_dtype_1 = node.outputs[0].type.numpy_dtype
out_dtype_2 = node.outputs[1].type.numpy_dtype
out_dtype_sign = node.outputs[0].type.numpy_dtype
out_dtype_det = node.outputs[1].type.numpy_dtype
inputs_cast = int_to_float_fn(node.inputs, out_dtype_1)
discrete_input = node.inputs[0].type.numpy_dtype.kind in "ibu"
if discrete_input and config.compiler_verbose:
print("SLogDet requires casting discrete input to float") # noqa: T201
@numba_basic.numba_njit
def slogdet(x):
sign, det = np.linalg.slogdet(inputs_cast(x))
if discrete_input:
x = x.astype(out_dtype_det)
sign, det = np.linalg.slogdet(x)
return (
np.array(sign).astype(out_dtype_1),
np.array(det).astype(out_dtype_2),
np.array(sign, dtype=out_dtype_sign),
np.array(det, dtype=out_dtype_det),
)
return slogdet
cache_version = 1
return slogdet, cache_version
@register_funcify_default_op_cache_key(Eig)
def numba_funcify_Eig(op, node, **kwargs):
w_dtype = node.outputs[0].type.numpy_dtype
inputs_cast = int_to_float_fn(node.inputs, w_dtype)
non_complex_input = node.inputs[0].type.numpy_dtype.kind != "c"
if non_complex_input and config.compiler_verbose:
print("Eig requires casting input to complex") # noqa: T201
@numba_basic.numba_njit
def eig(x):
w, v = np.linalg.eig(inputs_cast(x))
if non_complex_input:
# Even floats are better cast to complex, otherwise numba may raise
# ValueError: eig() argument must not cause a domain change.
x = x.astype(w_dtype)
w, v = np.linalg.eig(x)
return w.astype(w_dtype), v.astype(w_dtype)
cache_version = 1
cache_version = 2
return eig, cache_version
......@@ -125,22 +148,32 @@ def numba_funcify_Eigh(op, node, **kwargs):
@register_funcify_default_op_cache_key(MatrixInverse)
def numba_funcify_MatrixInverse(op, node, **kwargs):
out_dtype = node.outputs[0].type.numpy_dtype
inputs_cast = int_to_float_fn(node.inputs, out_dtype)
discrete_input = node.inputs[0].type.numpy_dtype.kind in "ibu"
if discrete_input and config.compiler_verbose:
print("MatrixInverse requires casting discrete input to float") # noqa: T201
@numba_basic.numba_njit
def matrix_inverse(x):
return np.linalg.inv(inputs_cast(x)).astype(out_dtype)
if discrete_input:
x = x.astype(out_dtype)
return np.linalg.inv(x)
return matrix_inverse
cache_version = 1
return matrix_inverse, cache_version
@register_funcify_default_op_cache_key(MatrixPinv)
def numba_funcify_MatrixPinv(op, node, **kwargs):
out_dtype = node.outputs[0].type.numpy_dtype
inputs_cast = int_to_float_fn(node.inputs, out_dtype)
discrete_input = node.inputs[0].type.numpy_dtype.kind in "ibu"
if discrete_input and config.compiler_verbose:
print("MatrixPinv requires casting discrete input to float") # noqa: T201
@numba_basic.numba_njit
def matrixpinv(x):
return np.linalg.pinv(inputs_cast(x)).astype(out_dtype)
def matrix_pinv(x):
if discrete_input:
x = x.astype(out_dtype)
return np.linalg.pinv(x)
return matrixpinv
cache_version = 1
return matrix_pinv, cache_version
......@@ -4,7 +4,6 @@ import numpy as np
import pytest
import pytensor.tensor as pt
from pytensor import config
from pytensor.tensor import nlinalg
from tests.link.numba.test_basic import compare_numba_and_py
......@@ -52,23 +51,35 @@ y = np.array(
)
@pytest.mark.parametrize("input_dtype", ["float", "int"])
@pytest.mark.parametrize("input_dtype", ["int64", "float64", "complex128"])
@pytest.mark.parametrize("symmetric", [True, False], ids=["symmetric", "general"])
def test_Eig(input_dtype, symmetric):
x = pt.dmatrix("x")
if input_dtype == "float":
x_val = rng.normal(size=(3, 3)).astype(config.floatX)
x = pt.matrix("x", dtype=input_dtype)
if x.type.numpy_dtype.kind in "fc":
x_val = rng.normal(size=(3, 3)).astype(input_dtype)
else:
x_val = rng.integers(1, 10, size=(3, 3)).astype("int64")
if symmetric:
x_val = x_val + x_val.T
def assert_fn(x, y):
# eig can return equivalent values with some sign flips depending on impl, allow for that
np.testing.assert_allclose(np.abs(x), np.abs(y), strict=True)
g = nlinalg.eig(x)
compare_numba_and_py(
_, [eigen_values, eigen_vectors] = compare_numba_and_py(
graph_inputs=[x],
graph_outputs=g,
test_inputs=[x_val],
assert_fn=assert_fn,
)
# Check eig is correct
np.testing.assert_allclose(
x_val @ eigen_vectors,
eigen_vectors @ np.diag(eigen_values),
atol=1e-7,
rtol=1e-5,
)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论