提交 d3c3e1b1 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Implement Numba conversion for Join Op

上级 18d1a7a1
...@@ -35,6 +35,7 @@ from aesara.tensor.basic import ( ...@@ -35,6 +35,7 @@ from aesara.tensor.basic import (
AllocDiag, AllocDiag,
AllocEmpty, AllocEmpty,
ARange, ARange,
Join,
MakeVector, MakeVector,
Rebroadcast, Rebroadcast,
ScalarFromTensor, ScalarFromTensor,
...@@ -801,3 +802,20 @@ def numba_funcify_ARange(op, **kwargs): ...@@ -801,3 +802,20 @@ def numba_funcify_ARange(op, **kwargs):
) )
return arange return arange
@numba_funcify.register(Join)
def numba_funcify_Join(op, **kwargs):
view = op.view
if view != -1:
# TODO: Where (and why) is this `Join.view` even being used? From a
# quick search, the answer appears to be "nowhere", so we should
# probably just remove it.
raise NotImplementedError("The `view` parameter to `Join` is not supported")
@numba.njit
def join(axis, *tensors):
return np.concatenate(tensors, to_scalar(axis))
return join
...@@ -935,3 +935,89 @@ def test_CAReduce(careduce_fn, axis, v): ...@@ -935,3 +935,89 @@ def test_CAReduce(careduce_fn, axis, v):
if not isinstance(i, (SharedVariable, Constant)) if not isinstance(i, (SharedVariable, Constant))
], ],
) )
@pytest.mark.parametrize(
"vals, axis",
[
(
(
set_test_value(
aet.matrix(), np.random.normal(size=(1, 2)).astype(config.floatX)
),
set_test_value(
aet.matrix(), np.random.normal(size=(1, 2)).astype(config.floatX)
),
),
0,
),
(
(
set_test_value(
aet.matrix(), np.random.normal(size=(2, 1)).astype(config.floatX)
),
set_test_value(
aet.matrix(), np.random.normal(size=(3, 1)).astype(config.floatX)
),
),
0,
),
(
(
set_test_value(
aet.matrix(), np.random.normal(size=(1, 2)).astype(config.floatX)
),
set_test_value(
aet.matrix(), np.random.normal(size=(1, 2)).astype(config.floatX)
),
),
1,
),
(
(
set_test_value(
aet.matrix(), np.random.normal(size=(2, 2)).astype(config.floatX)
),
set_test_value(
aet.matrix(), np.random.normal(size=(2, 1)).astype(config.floatX)
),
),
1,
),
],
)
def test_Join(vals, axis):
g = aet.join(axis, *vals)
g_fg = FunctionGraph(outputs=[g])
compare_numba_and_py(
g_fg,
[
i.tag.test_value
for i in g_fg.inputs
if not isinstance(i, (SharedVariable, Constant))
],
)
def test_Join_view():
vals = (
set_test_value(
aet.matrix(), np.random.normal(size=(2, 2)).astype(config.floatX)
),
set_test_value(
aet.matrix(), np.random.normal(size=(2, 2)).astype(config.floatX)
),
)
g = aetb.Join(view=1)(1, *vals)
g_fg = FunctionGraph(outputs=[g])
with pytest.raises(NotImplementedError):
compare_numba_and_py(
g_fg,
[
i.tag.test_value
for i in g_fg.inputs
if not isinstance(i, (SharedVariable, Constant))
],
)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论