提交 8bbd2666 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Implement Numba conversion for ExtractDiag

上级 d3c3e1b1
...@@ -35,6 +35,7 @@ from aesara.tensor.basic import ( ...@@ -35,6 +35,7 @@ from aesara.tensor.basic import (
AllocDiag, AllocDiag,
AllocEmpty, AllocEmpty,
ARange, ARange,
ExtractDiag,
Join, Join,
MakeVector, MakeVector,
Rebroadcast, Rebroadcast,
...@@ -819,3 +820,16 @@ def numba_funcify_Join(op, **kwargs): ...@@ -819,3 +820,16 @@ def numba_funcify_Join(op, **kwargs):
return np.concatenate(tensors, to_scalar(axis)) return np.concatenate(tensors, to_scalar(axis))
return join return join
@numba_funcify.register(ExtractDiag)
def numba_funcify_ExtractDiag(op, **kwargs):
offset = op.offset
# axis1 = op.axis1
# axis2 = op.axis2
@numba.njit
def extract_diag(x):
return np.diag(x, k=offset)
return extract_diag
...@@ -1021,3 +1021,32 @@ def test_Join_view(): ...@@ -1021,3 +1021,32 @@ def test_Join_view():
if not isinstance(i, (SharedVariable, Constant)) if not isinstance(i, (SharedVariable, Constant))
], ],
) )
@pytest.mark.parametrize(
"val, offset",
[
(
set_test_value(
aet.matrix(), np.arange(10 * 10, dtype=config.floatX).reshape((10, 10))
),
0,
),
(
set_test_value(aet.vector(), np.arange(10, dtype=config.floatX)),
0,
),
],
)
def test_ExtractDiag(val, offset):
g = aet.diag(val, offset)
g_fg = FunctionGraph(outputs=[g])
compare_numba_and_py(
g_fg,
[
i.tag.test_value
for i in g_fg.inputs
if not isinstance(i, (SharedVariable, Constant))
],
)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论