提交 0f785bac authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Implement Numba conversion for Eye Op

上级 8bbd2666
......@@ -36,6 +36,7 @@ from aesara.tensor.basic import (
AllocEmpty,
ARange,
ExtractDiag,
Eye,
Join,
MakeVector,
Rebroadcast,
......@@ -833,3 +834,15 @@ def numba_funcify_ExtractDiag(op, **kwargs):
return np.diag(x, k=offset)
return extract_diag
@numba_funcify.register(Eye)
def numba_funcify_Eye(op, **kwargs):
dtype = np.dtype(op.dtype)
dtype = numba.np.numpy_support.from_dtype(dtype)
@numba.njit
def eye(N, M, k):
return np.eye(to_scalar(N), to_scalar(M), to_scalar(k), dtype=dtype)
return eye
......@@ -1050,3 +1050,35 @@ def test_ExtractDiag(val, offset):
if not isinstance(i, (SharedVariable, Constant))
],
)
@pytest.mark.parametrize(
"n, m, k, dtype",
[
(set_test_value(aet.lscalar(), np.array(1, dtype=np.int64)), None, 0, None),
(
set_test_value(aet.lscalar(), np.array(1, dtype=np.int64)),
set_test_value(aet.lscalar(), np.array(2, dtype=np.int64)),
0,
"float32",
),
(
set_test_value(aet.lscalar(), np.array(1, dtype=np.int64)),
set_test_value(aet.lscalar(), np.array(2, dtype=np.int64)),
1,
"int64",
),
],
)
def test_Eye(n, m, k, dtype):
g = aet.eye(n, m, k, dtype=dtype)
g_fg = FunctionGraph(outputs=[g])
compare_numba_and_py(
g_fg,
[
i.tag.test_value
for i in g_fg.inputs
if not isinstance(i, (SharedVariable, Constant))
],
)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论