提交 5a462e98 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Fix slow dot in numba

上级 2d414d41
......@@ -565,18 +565,19 @@ def numba_funcify_SpecifyShape(op, node, **kwargs):
def int_to_float_fn(inputs, out_dtype):
"""Create a Numba function that converts integer and boolean ``ndarray``s to floats."""
if all(
input.type.numpy_dtype == np.dtype(out_dtype) for input in inputs
) and isinstance(np.dtype(out_dtype), np.floating):
if (
all(inp.type.dtype == out_dtype for inp in inputs)
and np.dtype(out_dtype).kind == "f"
):
@numba_njit
@numba_njit(inline="always")
def inputs_cast(x):
return x
elif any(i.type.numpy_dtype.kind in "ib" for i in inputs):
elif any(i.type.numpy_dtype.kind in "uib" for i in inputs):
args_dtype = np.dtype(f"f{out_dtype.itemsize}")
@numba_njit
@numba_njit(inline="always")
def inputs_cast(x):
return x.astype(args_dtype)
......@@ -584,7 +585,7 @@ def int_to_float_fn(inputs, out_dtype):
args_dtype_sz = max(_arg.type.numpy_dtype.itemsize for _arg in inputs)
args_dtype = np.dtype(f"f{args_dtype_sz}")
@numba_njit
@numba_njit(inline="always")
def inputs_cast(x):
return x.astype(args_dtype)
......@@ -593,17 +594,49 @@ def int_to_float_fn(inputs, out_dtype):
@numba_funcify.register(Dot)
def numba_funcify_Dot(op, node, **kwargs):
# Numba's `np.dot` does not support integer dtypes, so we need to cast to
# float.
# Numba's `np.dot` does not support integer dtypes, so we need to cast to float.
x, y = node.inputs
[out] = node.outputs
out_dtype = node.outputs[0].type.numpy_dtype
inputs_cast = int_to_float_fn(node.inputs, out_dtype)
x_dtype = x.type.dtype
y_dtype = y.type.dtype
dot_dtype = f"float{max((32, out.type.numpy_dtype.itemsize * 8))}"
out_dtype = out.type.dtype
@numba_njit
def dot(x, y):
return np.asarray(np.dot(inputs_cast(x), inputs_cast(y))).astype(out_dtype)
if x_dtype == dot_dtype and y_dtype == dot_dtype:
@numba_njit
def dot(x, y):
return np.asarray(np.dot(x, y))
elif x_dtype == dot_dtype and y_dtype != dot_dtype:
@numba_njit
def dot(x, y):
return np.asarray(np.dot(x, y.astype(dot_dtype)))
elif x_dtype != dot_dtype and y_dtype == dot_dtype:
@numba_njit
def dot(x, y):
return np.asarray(np.dot(x.astype(dot_dtype), y))
else:
@numba_njit()
def dot(x, y):
return np.asarray(np.dot(x.astype(dot_dtype), y.astype(dot_dtype)))
if out_dtype == dot_dtype:
return dot
else:
@numba_njit
def dot_with_cast(x, y):
return dot(x, y).astype(out_dtype)
return dot
return dot_with_cast
@numba_funcify.register(Solve)
......
......@@ -30,7 +30,7 @@ from pytensor.link.numba.dispatch import basic as numba_basic
from pytensor.link.numba.linker import NumbaLinker
from pytensor.raise_op import assert_op
from pytensor.scalar.basic import ScalarOp, as_scalar
from pytensor.tensor import blas
from pytensor.tensor import blas, tensor
from pytensor.tensor.elemwise import Elemwise
from pytensor.tensor.shape import Reshape, Shape, Shape_i, SpecifyShape
from pytensor.tensor.sort import ArgSortOp, SortOp
......@@ -603,43 +603,41 @@ def test_perform_type_convert():
@pytest.mark.parametrize(
"x, y, exc",
"x, y",
[
(
(pt.matrix(), rng.random(size=(3, 2)).astype(config.floatX)),
(pt.vector(), rng.random(size=(2,)).astype(config.floatX)),
None,
),
(
(pt.matrix(dtype="float64"), rng.random(size=(3, 2)).astype("float64")),
(pt.vector(dtype="float32"), rng.random(size=(2,)).astype("float32")),
None,
),
(
(pt.lmatrix(), rng.poisson(size=(3, 2))),
(pt.fvector(), rng.random(size=(2,)).astype("float32")),
None,
),
(
(pt.lvector(), rng.random(size=(2,)).astype(np.int64)),
(pt.lvector(), rng.random(size=(2,)).astype(np.int64)),
None,
),
(
(pt.vector(dtype="int16"), rng.random(size=(2,)).astype(np.int16)),
(pt.vector(dtype="uint8"), rng.random(size=(2,)).astype(np.uint8)),
),
],
)
def test_Dot(x, y, exc):
def test_Dot(x, y):
x, x_test_value = x
y, y_test_value = y
g = ptm.Dot()(x, y)
cm = contextlib.suppress() if exc is None else pytest.warns(exc)
with cm:
compare_numba_and_py(
[x, y],
[g],
[x_test_value, y_test_value],
)
compare_numba_and_py(
[x, y],
[g],
[x_test_value, y_test_value],
)
@pytest.mark.parametrize(
......@@ -937,3 +935,18 @@ def test_Nonzero(input_data):
compare_numba_and_py(
graph_inputs=[a], graph_outputs=graph_outputs, test_inputs=[input_data]
)
@pytest.mark.parametrize("dtype", ("float64", "float32", "mixed"))
def test_mat_vec_dot_performance(dtype, benchmark):
A = tensor("A", shape=(512, 512), dtype="float64" if dtype == "mixed" else dtype)
x = tensor("x", shape=(512,), dtype="float32" if dtype == "mixed" else dtype)
out = ptm.dot(A, x)
fn = function([A, x], out, mode="NUMBA", trust_input=True)
rng = np.random.default_rng(948)
A_test = rng.standard_normal(size=A.type.shape, dtype=A.type.dtype)
x_test = rng.standard_normal(size=x.type.shape, dtype=x.type.dtype)
np.testing.assert_allclose(fn(A_test, x_test), np.dot(A_test, x_test), atol=1e-4)
benchmark(fn, A_test, x_test)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论