提交 ef97287b authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Improve CAReduce Numba implementation

上级 9e24b10a
......@@ -15,7 +15,7 @@ from pytensor.compile.sharedvalue import SharedVariable
from pytensor.gradient import grad
from pytensor.graph.basic import Constant
from pytensor.graph.fg import FunctionGraph
from pytensor.tensor.elemwise import DimShuffle
from pytensor.tensor.elemwise import CAReduce, DimShuffle
from pytensor.tensor.math import All, Any, Max, Min, Prod, ProdWithoutZeros, Sum
from pytensor.tensor.special import LogSoftmax, Softmax, SoftmaxGrad
from tests.link.numba.test_basic import (
......@@ -23,7 +23,7 @@ from tests.link.numba.test_basic import (
scalar_my_multi_out,
set_test_value,
)
from tests.tensor.test_elemwise import TestElemwise
from tests.tensor.test_elemwise import TestElemwise, careduce_benchmark_tester
rng = np.random.default_rng(42849)
......@@ -249,12 +249,12 @@ def test_Dimshuffle_non_contiguous():
(
lambda x, axis=None, dtype=None, acc_dtype=None: All(axis)(x),
0,
set_test_value(pt.vector(), np.arange(3, dtype=config.floatX)),
set_test_value(pt.vector(dtype="bool"), np.array([False, True, False])),
),
(
lambda x, axis=None, dtype=None, acc_dtype=None: Any(axis)(x),
0,
set_test_value(pt.vector(), np.arange(3, dtype=config.floatX)),
set_test_value(pt.vector(dtype="bool"), np.array([False, True, False])),
),
(
lambda x, axis=None, dtype=None, acc_dtype=None: Sum(
......@@ -301,6 +301,24 @@ def test_Dimshuffle_non_contiguous():
pt.matrix(), np.arange(3 * 2, dtype=config.floatX).reshape((3, 2))
),
),
(
lambda x, axis=None, dtype=None, acc_dtype=None: Prod(
axis=axis, dtype=dtype, acc_dtype=acc_dtype
)(x),
(), # Empty axes would normally be rewritten away, but we want to test it still works
set_test_value(
pt.matrix(), np.arange(3 * 2, dtype=config.floatX).reshape((3, 2))
),
),
(
lambda x, axis=None, dtype=None, acc_dtype=None: Prod(
axis=axis, dtype=dtype, acc_dtype=acc_dtype
)(x),
None,
set_test_value(
pt.scalar(), np.array(99.0, dtype=config.floatX)
), # Scalar input would normally be rewritten away, but we want to test it still works
),
(
lambda x, axis=None, dtype=None, acc_dtype=None: Prod(
axis=axis, dtype=dtype, acc_dtype=acc_dtype
......@@ -367,7 +385,7 @@ def test_CAReduce(careduce_fn, axis, v):
g = careduce_fn(v, axis=axis)
g_fg = FunctionGraph(outputs=[g])
compare_numba_and_py(
fn, _ = compare_numba_and_py(
g_fg,
[
i.tag.test_value
......@@ -375,6 +393,10 @@ def test_CAReduce(careduce_fn, axis, v):
if not isinstance(i, SharedVariable | Constant)
],
)
# Confirm CAReduce is in the compiled function
fn.dprint()
[node] = fn.maker.fgraph.apply_nodes
assert isinstance(node.op, CAReduce)
def test_scalar_Elemwise_Clip():
......@@ -619,10 +641,10 @@ def test_logsumexp_benchmark(size, axis, benchmark):
X_lse_fn = pytensor.function([X], X_lse, mode="NUMBA")
# JIT compile first
_ = X_lse_fn(X_val)
res = benchmark(X_lse_fn, X_val)
res = X_lse_fn(X_val)
exp_res = scipy.special.logsumexp(X_val, axis=axis, keepdims=True)
np.testing.assert_array_almost_equal(res, exp_res)
benchmark(X_lse_fn, X_val)
def test_fused_elemwise_benchmark(benchmark):
......@@ -653,3 +675,19 @@ def test_elemwise_out_type():
x_val = np.broadcast_to(np.zeros((3,)), (6, 3))
assert func(x_val).shape == (18,)
@pytest.mark.parametrize(
    "axis",
    (0, 1, 2, (0, 1), (0, 2), (1, 2), None),
    ids=lambda x: f"axis={x}",
)
@pytest.mark.parametrize(
    "c_contiguous",
    (True, False),
    ids=lambda x: f"c_contiguous={x}",
)
def test_numba_careduce_benchmark(axis, c_contiguous, benchmark):
    """Benchmark a CAReduce sum compiled with the NUMBA mode."""
    return careduce_benchmark_tester(
        axis=axis, c_contiguous=c_contiguous, mode="NUMBA", benchmark=benchmark
    )
......@@ -983,27 +983,33 @@ class TestVectorize:
assert vect_node.inputs[0] is bool_tns
@pytest.mark.parametrize(
"axis",
(0, 1, 2, (0, 1), (0, 2), (1, 2), None),
ids=lambda x: f"axis={x}",
)
@pytest.mark.parametrize(
"c_contiguous",
(True, False),
ids=lambda x: f"c_contiguous={x}",
)
def test_careduce_benchmark(axis, c_contiguous, benchmark):
def careduce_benchmark_tester(axis, c_contiguous, mode, benchmark):
    """Shared benchmark body for a CAReduce (sum) across compilation modes.

    Builds a (256, 256, 256) shared tensor — transposed to a non-contiguous
    layout when ``c_contiguous`` is False — sums it along ``axis``, verifies
    the compiled result against NumPy, then hands the function to the
    ``benchmark`` fixture.
    """
    size = 256
    x_val = np.random.uniform(size=(size, size, size))

    # Identity permutation keeps the array C-contiguous; anything else breaks it.
    perm = (0, 1, 2) if c_contiguous else (2, 0, 1)

    x = pytensor.shared(x_val, name="x", shape=x_val.shape)
    out = x.transpose(perm).sum(axis=axis)
    fn = pytensor.function([], out, mode=mode)

    # Correctness check before timing anything.
    np.testing.assert_allclose(
        fn(),
        x_val.transpose(perm).sum(axis=axis),
    )

    benchmark(fn)
@pytest.mark.parametrize(
    "axis",
    (0, 1, 2, (0, 1), (0, 2), (1, 2), None),
    ids=lambda x: f"axis={x}",
)
@pytest.mark.parametrize(
    "c_contiguous",
    (True, False),
    ids=lambda x: f"c_contiguous={x}",
)
def test_c_careduce_benchmark(axis, c_contiguous, benchmark):
    """Benchmark a CAReduce sum compiled with the default C backend (FAST_RUN)."""
    return careduce_benchmark_tester(
        axis=axis, c_contiguous=c_contiguous, mode="FAST_RUN", benchmark=benchmark
    )
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论