提交 c72a48d7 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Move dot Op dispatchers to Elemwise

They are actually defined in tensor/math.py, but this is better than being in `basic.py`
上级 78f4d2da
......@@ -24,8 +24,6 @@ from pytensor.link.utils import (
)
from pytensor.scalar.basic import ScalarType
from pytensor.sparse import SparseTensorType
from pytensor.tensor.blas import BatchedDot
from pytensor.tensor.math import Dot
from pytensor.tensor.type import TensorType
......@@ -364,71 +362,6 @@ def int_to_float_fn(inputs, out_dtype):
return inputs_cast
@numba_funcify.register(Dot)
def numba_funcify_Dot(op, node, **kwargs):
# Numba's `np.dot` does not support integer dtypes, so we need to cast to float.
x, y = node.inputs
[out] = node.outputs
x_dtype = x.type.dtype
y_dtype = y.type.dtype
dot_dtype = f"float{max((32, out.type.numpy_dtype.itemsize * 8))}"
out_dtype = out.type.dtype
if x_dtype == dot_dtype and y_dtype == dot_dtype:
@numba_njit
def dot(x, y):
return np.asarray(np.dot(x, y))
elif x_dtype == dot_dtype and y_dtype != dot_dtype:
@numba_njit
def dot(x, y):
return np.asarray(np.dot(x, y.astype(dot_dtype)))
elif x_dtype != dot_dtype and y_dtype == dot_dtype:
@numba_njit
def dot(x, y):
return np.asarray(np.dot(x.astype(dot_dtype), y))
else:
@numba_njit()
def dot(x, y):
return np.asarray(np.dot(x.astype(dot_dtype), y.astype(dot_dtype)))
if out_dtype == dot_dtype:
return dot
else:
@numba_njit
def dot_with_cast(x, y):
return dot(x, y).astype(out_dtype)
return dot_with_cast
@numba_funcify.register(BatchedDot)
def numba_funcify_BatchedDot(op, node, **kwargs):
dtype = node.outputs[0].type.numpy_dtype
@numba_njit
def batched_dot(x, y):
# Numba does not support 3D matmul
# https://github.com/numba/numba/issues/3804
shape = x.shape[:-1] + y.shape[2:]
z0 = np.empty(shape, dtype=dtype)
for i in range(z0.shape[0]):
z0[i] = np.dot(x[i], y[i])
return z0
return batched_dot
@numba_funcify.register(IfElse)
def numba_funcify_IfElse(op, **kwargs):
n_outs = op.n_outs
......
......@@ -35,8 +35,9 @@ from pytensor.scalar.basic import (
scalar_maximum,
)
from pytensor.scalar.basic import add as add_as
from pytensor.tensor.blas import BatchedDot
from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise
from pytensor.tensor.math import Argmax, MulWithoutZeros, Sum
from pytensor.tensor.math import Argmax, Dot, MulWithoutZeros, Sum
from pytensor.tensor.special import LogSoftmax, Softmax, SoftmaxGrad
......@@ -599,3 +600,68 @@ def numba_funcify_Argmax(op, node, **kwargs):
return max_idx_res
return argmax
@numba_funcify.register(Dot)
def numba_funcify_Dot(op, node, **kwargs):
# Numba's `np.dot` does not support integer dtypes, so we need to cast to float.
x, y = node.inputs
[out] = node.outputs
x_dtype = x.type.dtype
y_dtype = y.type.dtype
dot_dtype = f"float{max((32, out.type.numpy_dtype.itemsize * 8))}"
out_dtype = out.type.dtype
if x_dtype == dot_dtype and y_dtype == dot_dtype:
@numba_njit
def dot(x, y):
return np.asarray(np.dot(x, y))
elif x_dtype == dot_dtype and y_dtype != dot_dtype:
@numba_njit
def dot(x, y):
return np.asarray(np.dot(x, y.astype(dot_dtype)))
elif x_dtype != dot_dtype and y_dtype == dot_dtype:
@numba_njit
def dot(x, y):
return np.asarray(np.dot(x.astype(dot_dtype), y))
else:
@numba_njit()
def dot(x, y):
return np.asarray(np.dot(x.astype(dot_dtype), y.astype(dot_dtype)))
if out_dtype == dot_dtype:
return dot
else:
@numba_njit
def dot_with_cast(x, y):
return dot(x, y).astype(out_dtype)
return dot_with_cast
@numba_funcify.register(BatchedDot)
def numba_funcify_BatchedDot(op, node, **kwargs):
dtype = node.outputs[0].type.numpy_dtype
@numba_njit
def batched_dot(x, y):
# Numba does not support 3D matmul
# https://github.com/numba/numba/issues/3804
shape = x.shape[:-1] + y.shape[2:]
z0 = np.empty(shape, dtype=dtype)
for i in range(z0.shape[0]):
z0[i] = np.dot(x[i], y[i])
return z0
return batched_dot
......@@ -14,7 +14,6 @@ numba = pytest.importorskip("numba")
import pytensor.scalar as ps
import pytensor.tensor as pt
import pytensor.tensor.math as ptm
from pytensor import config, shared
from pytensor.compile.builders import OpFromGraph
from pytensor.compile.function import function
......@@ -29,7 +28,6 @@ from pytensor.link.numba.dispatch import basic as numba_basic
from pytensor.link.numba.linker import NumbaLinker
from pytensor.raise_op import assert_op
from pytensor.scalar.basic import ScalarOp, as_scalar
from pytensor.tensor import blas, tensor
from pytensor.tensor.elemwise import Elemwise
......@@ -407,86 +405,6 @@ def test_perform_type_convert():
compare_numba_and_py([x], out, [x_test_value])
@pytest.mark.parametrize(
"x, y",
[
(
(pt.matrix(), rng.random(size=(3, 2)).astype(config.floatX)),
(pt.vector(), rng.random(size=(2,)).astype(config.floatX)),
),
(
(pt.matrix(dtype="float64"), rng.random(size=(3, 2)).astype("float64")),
(pt.vector(dtype="float32"), rng.random(size=(2,)).astype("float32")),
),
(
(pt.lmatrix(), rng.poisson(size=(3, 2))),
(pt.fvector(), rng.random(size=(2,)).astype("float32")),
),
(
(pt.lvector(), rng.random(size=(2,)).astype(np.int64)),
(pt.lvector(), rng.random(size=(2,)).astype(np.int64)),
),
(
(pt.vector(dtype="int16"), rng.random(size=(2,)).astype(np.int16)),
(pt.vector(dtype="uint8"), rng.random(size=(2,)).astype(np.uint8)),
),
],
)
def test_Dot(x, y):
x, x_test_value = x
y, y_test_value = y
g = ptm.dot(x, y)
compare_numba_and_py(
[x, y],
[g],
[x_test_value, y_test_value],
)
@pytest.mark.parametrize(
"x, y, exc",
[
(
(
pt.dtensor3(),
rng.random(size=(2, 3, 3)).astype("float64"),
),
(
pt.dtensor3(),
rng.random(size=(2, 3, 3)).astype("float64"),
),
None,
),
(
(
pt.dtensor3(),
rng.random(size=(2, 3, 3)).astype("float64"),
),
(
pt.ltensor3(),
rng.poisson(size=(2, 3, 3)).astype("int64"),
),
None,
),
],
)
def test_BatchedDot(x, y, exc):
x, x_test_value = x
y, y_test_value = y
g = blas.BatchedDot()(x, y)
cm = contextlib.suppress() if exc is None else pytest.warns(exc)
with cm:
compare_numba_and_py(
[x, y],
g,
[x_test_value, y_test_value],
)
def test_shared():
a = shared(np.array([1, 2, 3], dtype=config.floatX))
......@@ -716,18 +634,3 @@ def test_function_overhead(mode, benchmark):
assert np.sum(fn(test_x)) == 1000
benchmark(fn, test_x)
@pytest.mark.parametrize("dtype", ("float64", "float32", "mixed"))
def test_mat_vec_dot_performance(dtype, benchmark):
A = tensor("A", shape=(512, 512), dtype="float64" if dtype == "mixed" else dtype)
x = tensor("x", shape=(512,), dtype="float32" if dtype == "mixed" else dtype)
out = ptm.dot(A, x)
fn = function([A, x], out, mode="NUMBA", trust_input=True)
rng = np.random.default_rng(948)
A_test = rng.standard_normal(size=A.type.shape, dtype=A.type.dtype)
x_test = rng.standard_normal(size=x.type.shape, dtype=x.type.dtype)
np.testing.assert_allclose(fn(A_test, x_test), np.dot(A_test, x_test), atol=1e-4)
benchmark(fn, A_test, x_test)
......@@ -13,6 +13,7 @@ from pytensor.compile import get_mode
from pytensor.compile.ops import deep_copy_op
from pytensor.gradient import grad
from pytensor.scalar import Composite, float64
from pytensor.tensor import blas, tensor
from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise
from pytensor.tensor.math import All, Any, Max, Min, Prod, ProdWithoutZeros, Sum
from pytensor.tensor.special import LogSoftmax, Softmax, SoftmaxGrad
......@@ -670,3 +671,98 @@ class TestsBenchmark:
@pytest.mark.parametrize("c_contiguous", (True, False))
def test_dimshuffle(self, c_contiguous, benchmark):
dimshuffle_benchmark("NUMBA", c_contiguous, benchmark)
@pytest.mark.parametrize(
"x, y",
[
(
(pt.matrix(), rng.random(size=(3, 2)).astype(config.floatX)),
(pt.vector(), rng.random(size=(2,)).astype(config.floatX)),
),
(
(pt.matrix(dtype="float64"), rng.random(size=(3, 2)).astype("float64")),
(pt.vector(dtype="float32"), rng.random(size=(2,)).astype("float32")),
),
(
(pt.lmatrix(), rng.poisson(size=(3, 2))),
(pt.fvector(), rng.random(size=(2,)).astype("float32")),
),
(
(pt.lvector(), rng.random(size=(2,)).astype(np.int64)),
(pt.lvector(), rng.random(size=(2,)).astype(np.int64)),
),
(
(pt.vector(dtype="int16"), rng.random(size=(2,)).astype(np.int16)),
(pt.vector(dtype="uint8"), rng.random(size=(2,)).astype(np.uint8)),
),
],
)
def test_Dot(x, y):
x, x_test_value = x
y, y_test_value = y
g = ptm.dot(x, y)
compare_numba_and_py(
[x, y],
[g],
[x_test_value, y_test_value],
)
@pytest.mark.parametrize(
"x, y, exc",
[
(
(
pt.dtensor3(),
rng.random(size=(2, 3, 3)).astype("float64"),
),
(
pt.dtensor3(),
rng.random(size=(2, 3, 3)).astype("float64"),
),
None,
),
(
(
pt.dtensor3(),
rng.random(size=(2, 3, 3)).astype("float64"),
),
(
pt.ltensor3(),
rng.poisson(size=(2, 3, 3)).astype("int64"),
),
None,
),
],
)
def test_BatchedDot(x, y, exc):
x, x_test_value = x
y, y_test_value = y
g = blas.BatchedDot()(x, y)
cm = contextlib.suppress() if exc is None else pytest.warns(exc)
with cm:
compare_numba_and_py(
[x, y],
g,
[x_test_value, y_test_value],
)
@pytest.mark.parametrize("dtype", ("float64", "float32", "mixed"))
def test_mat_vec_dot_performance(dtype, benchmark):
A = tensor("A", shape=(512, 512), dtype="float64" if dtype == "mixed" else dtype)
x = tensor("x", shape=(512,), dtype="float32" if dtype == "mixed" else dtype)
out = ptm.dot(A, x)
fn = function([A, x], out, mode="NUMBA", trust_input=True)
rng = np.random.default_rng(948)
A_test = rng.standard_normal(size=A.type.shape, dtype=A.type.dtype)
x_test = rng.standard_normal(size=x.type.shape, dtype=x.type.dtype)
np.testing.assert_allclose(fn(A_test, x_test), np.dot(A_test, x_test), atol=1e-4)
benchmark(fn, A_test, x_test)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论