提交 9d360389 authored 作者: Adrian Seyboldt's avatar Adrian Seyboldt 提交者: Brandon T. Willard

Add numba implementation for split

上级 9e0434bf
......@@ -15,6 +15,7 @@ from aesara.tensor.basic import (
Join,
MakeVector,
ScalarFromTensor,
Split,
TensorFromScalar,
)
from aesara.tensor.shape import Unbroadcast
......@@ -138,6 +139,18 @@ def numba_funcify_Join(op, **kwargs):
return join
@numba_funcify.register(Split)
def numba_funcify_Split(op, **kwargs):
@numba_basic.numba_njit
def split(tensor, axis, indices):
# Work around for https://github.com/numba/numba/issues/8257
axis = axis % tensor.ndim
axis = numba_basic.to_scalar(axis)
return np.split(tensor, np.cumsum(indices)[:-1], axis=axis)
return split
@numba_funcify.register(ExtractDiag)
def numba_funcify_ExtractDiag(op, **kwargs):
offset = op.offset
......
......@@ -1274,6 +1274,66 @@ def test_Join_view():
)
@pytest.mark.parametrize(
"n_splits, axis, values, sizes",
[
(
0,
0,
set_test_value(at.vector(), rng.normal(size=20).astype(config.floatX)),
set_test_value(at.vector(dtype="int64"), []),
),
(
5,
0,
set_test_value(at.vector(), rng.normal(size=5).astype(config.floatX)),
set_test_value(
at.vector(dtype="int64"), rng.multinomial(5, np.ones(5) / 5)
),
),
(
5,
0,
set_test_value(at.vector(), rng.normal(size=10).astype(config.floatX)),
set_test_value(
at.vector(dtype="int64"), rng.multinomial(10, np.ones(5) / 5)
),
),
(
5,
-1,
set_test_value(at.matrix(), rng.normal(size=(11, 7)).astype(config.floatX)),
set_test_value(
at.vector(dtype="int64"), rng.multinomial(7, np.ones(5) / 5)
),
),
(
5,
-2,
set_test_value(at.matrix(), rng.normal(size=(11, 7)).astype(config.floatX)),
set_test_value(
at.vector(dtype="int64"), rng.multinomial(11, np.ones(5) / 5)
),
),
],
)
def test_Split(n_splits, axis, values, sizes):
g = at.split(values, sizes, n_splits, axis=axis)
assert len(g) == n_splits
if n_splits == 0:
return
g_fg = FunctionGraph(outputs=[g] if n_splits == 1 else g)
compare_numba_and_py(
g_fg,
[
i.tag.test_value
for i in g_fg.inputs
if not isinstance(i, (SharedVariable, Constant))
],
)
@pytest.mark.parametrize(
"val, offset",
[
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论