提交 deea8dd3 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Fully support ExtractDiag in numba

上级 2138cd67
......@@ -150,14 +150,44 @@ def numba_funcify_Split(op, **kwargs):
@numba_funcify.register(ExtractDiag)
def numba_funcify_ExtractDiag(op, **kwargs):
offset = op.offset
# axis1 = op.axis1
# axis2 = op.axis2
@numba_basic.numba_njit(inline="always")
def extract_diag(x):
return np.diag(x, k=offset)
def numba_funcify_ExtractDiag(op, node, **kwargs):
view = op.view
axis1, axis2, offset = op.axis1, op.axis2, op.offset
if node.inputs[0].type.ndim == 2:
@numba_basic.numba_njit(inline="always")
def extract_diag(x):
out = np.diag(x, k=offset)
if not view:
out = out.copy()
return out
else:
axis1p1 = axis1 + 1
axis2p1 = axis2 + 1
leading_dims = (slice(None),) * axis1
middle_dims = (slice(None),) * (axis2 - axis1 - 1)
@numba_basic.numba_njit(inline="always")
def extract_diag(x):
if offset >= 0:
diag_len = min(x.shape[axis1], max(0, x.shape[axis2] - offset))
else:
diag_len = min(x.shape[axis2], max(0, x.shape[axis1] + offset))
base_shape = x.shape[:axis1] + x.shape[axis1p1:axis2] + x.shape[axis2p1:]
out_shape = base_shape + (diag_len,)
out = np.empty(out_shape)
for i in range(diag_len):
if offset >= 0:
new_entry = x[leading_dims + (i,) + middle_dims + (i + offset,)]
else:
new_entry = x[leading_dims + (i - offset,) + middle_dims + (i,)]
out[..., i] = new_entry
return out
return extract_diag
......
......@@ -17,6 +17,10 @@ from tests.link.numba.test_basic import (
)
pytest.importorskip("numba")
from pytensor.link.numba.dispatch import numba_funcify
rng = np.random.default_rng(42849)
......@@ -366,6 +370,12 @@ def test_Split_view():
),
0,
),
(
set_test_value(
at.matrix(), np.arange(10 * 10, dtype=config.floatX).reshape((10, 10))
),
-1,
),
(
set_test_value(at.vector(), np.arange(10, dtype=config.floatX)),
0,
......@@ -386,6 +396,23 @@ def test_ExtractDiag(val, offset):
)
@pytest.mark.parametrize("k", range(-5, 4))
@pytest.mark.parametrize(
"axis1, axis2", ((0, 1), (0, 2), (0, 3), (1, 2), (1, 3), (2, 3))
)
@pytest.mark.parametrize("reverse_axis", (False, True))
def test_ExtractDiag_exhaustive(k, axis1, axis2, reverse_axis):
if reverse_axis:
axis1, axis2 = axis2, axis1
x = at.tensor4("x")
x_shape = (2, 3, 4, 5)
x_test = np.arange(np.prod(x_shape)).reshape(x_shape)
out = at.diagonal(x, k, axis1, axis2)
numba_fn = numba_funcify(out.owner.op, out.owner)
np.testing.assert_allclose(numba_fn(x_test), np.diagonal(x_test, k, axis1, axis2))
@pytest.mark.parametrize(
"n, m, k, dtype",
[
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论