提交 3b722cec authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Numba Dot: Handle complex inputs

上级 3ff76039
...@@ -8,6 +8,7 @@ from numba.core.extending import overload ...@@ -8,6 +8,7 @@ from numba.core.extending import overload
from numpy.lib.array_utils import normalize_axis_index, normalize_axis_tuple from numpy.lib.array_utils import normalize_axis_index, normalize_axis_tuple
from numpy.lib.stride_tricks import as_strided from numpy.lib.stride_tricks import as_strided
from pytensor import config
from pytensor.graph.op import Op from pytensor.graph.op import Op
from pytensor.link.numba.cache import ( from pytensor.link.numba.cache import (
compile_numba_function_src, compile_numba_function_src,
...@@ -608,37 +609,54 @@ def numba_funcify_Dot(op, node, **kwargs): ...@@ -608,37 +609,54 @@ def numba_funcify_Dot(op, node, **kwargs):
x, y = node.inputs x, y = node.inputs
[out] = node.outputs [out] = node.outputs
x_dtype = x.type.dtype x_dtype = x.type.numpy_dtype
y_dtype = y.type.dtype y_dtype = y.type.numpy_dtype
dot_dtype = f"float{max((32, out.type.numpy_dtype.itemsize * 8))}"
out_dtype = out.type.dtype
if x_dtype == dot_dtype and y_dtype == dot_dtype: numba_dot_dtype = out_dtype = out.type.numpy_dtype
if out_dtype.kind not in "fc":
# Numba always returns non-integral outputs, so we need to cast to float
numba_dot_dtype = np.dtype(
f"float{max((32, out.type.numpy_dtype.itemsize * 8))}"
)
if config.compiler_verbose and not (
x_dtype == y_dtype == out_dtype == numba_dot_dtype
):
print( # noqa: T201
"Numba Dot requires a type casting of inputs and/or output: "
f"{x_dtype=}, {y_dtype=}, {out_dtype=}, {numba_dot_dtype=}"
)
if x_dtype == numba_dot_dtype and y_dtype == numba_dot_dtype:
@numba_basic.numba_njit @numba_basic.numba_njit
def dot(x, y): def dot(x, y):
return np.asarray(np.dot(x, y)) return np.asarray(np.dot(x, y))
elif x_dtype == dot_dtype and y_dtype != dot_dtype: elif x_dtype == numba_dot_dtype and y_dtype != numba_dot_dtype:
@numba_basic.numba_njit @numba_basic.numba_njit
def dot(x, y): def dot(x, y):
return np.asarray(np.dot(x, y.astype(dot_dtype))) return np.asarray(np.dot(x, y.astype(numba_dot_dtype)))
elif x_dtype != dot_dtype and y_dtype == dot_dtype: elif x_dtype != numba_dot_dtype and y_dtype == numba_dot_dtype:
@numba_basic.numba_njit @numba_basic.numba_njit
def dot(x, y): def dot(x, y):
return np.asarray(np.dot(x.astype(dot_dtype), y)) return np.asarray(np.dot(x.astype(numba_dot_dtype), y))
else: else:
@numba_basic.numba_njit @numba_basic.numba_njit
def dot(x, y): def dot(x, y):
return np.asarray(np.dot(x.astype(dot_dtype), y.astype(dot_dtype))) return np.asarray(
np.dot(x.astype(numba_dot_dtype), y.astype(numba_dot_dtype))
)
cache_version = 1
if out_dtype == dot_dtype: if out_dtype == numba_dot_dtype:
return dot return dot, cache_version
else: else:
...@@ -646,7 +664,7 @@ def numba_funcify_Dot(op, node, **kwargs): ...@@ -646,7 +664,7 @@ def numba_funcify_Dot(op, node, **kwargs):
def dot_with_cast(x, y): def dot_with_cast(x, y):
return dot(x, y).astype(out_dtype) return dot(x, y).astype(out_dtype)
return dot_with_cast return dot_with_cast, cache_version
@register_funcify_default_op_cache_key(BatchedDot) @register_funcify_default_op_cache_key(BatchedDot)
......
...@@ -718,6 +718,25 @@ class TestsBenchmark: ...@@ -718,6 +718,25 @@ class TestsBenchmark:
(pt.vector(dtype="int16"), rng.random(size=(2,)).astype(np.int16)), (pt.vector(dtype="int16"), rng.random(size=(2,)).astype(np.int16)),
(pt.vector(dtype="uint8"), rng.random(size=(2,)).astype(np.uint8)), (pt.vector(dtype="uint8"), rng.random(size=(2,)).astype(np.uint8)),
), ),
# Viewing an array whose last dimension has length 2 as complex128 means
# the first entry becomes the real part and the second entry the imaginary part
(
(
pt.matrix(dtype="complex128"),
rng.random(size=(5, 4, 2)).view("complex128").squeeze(-1),
),
(
pt.matrix(dtype="complex128"),
rng.random(size=(4, 3, 2)).view("complex128").squeeze(-1),
),
),
(
(pt.matrix(dtype="int64"), rng.random(size=(5, 4)).astype("int64")),
(
pt.matrix(dtype="complex128"),
rng.random(size=(4, 3, 2)).view("complex128").squeeze(-1),
),
),
], ],
) )
def test_Dot(x, y): def test_Dot(x, y):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论