提交 c011572c authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Validate axis in AllocDiag and ExtractDiag

上级 ed43f023
......@@ -3408,6 +3408,12 @@ class ExtractDiag(Op):
if self.view:
self.view_map = {0: [0]}
self.offset = offset
if axis1 < 0 or axis2 < 0:
raise NotImplementedError(
"ExtractDiag does not support negative axis. Use pytensor.tensor.diagonal instead."
)
if axis1 == axis2:
raise ValueError("axis1 and axis2 cannot be the same")
self.axis1 = axis1
self.axis2 = axis2
......@@ -3502,6 +3508,8 @@ def diagonal(a, offset=0, axis1=0, axis2=1):
tensor : symbolic tensor
"""
a = as_tensor_variable(a)
axis1, axis2 = normalize_axis_tuple((axis1, axis2), ndim=a.type.ndim)
return ExtractDiag(offset, axis1, axis2)(a)
......@@ -3529,6 +3537,10 @@ class AllocDiag(Op):
the diagonals will be allocated. Defaults to second axis (i.e. 1).
"""
self.offset = offset
if axis1 < 0 or axis2 < 0:
raise NotImplementedError("AllocDiag does not support negative axis")
if axis1 == axis2:
raise ValueError("axis1 and axis2 cannot be the same")
self.axis1 = axis1
self.axis2 = axis2
......
......@@ -3714,6 +3714,14 @@ class TestAllocDiag:
assert np.all(true_grad_input == grad_input)
def test_diagonal_negative_axis():
x = np.arange(2 * 3 * 3).reshape((2, 3, 3))
np.testing.assert_allclose(
at.diagonal(x, axis1=-1, axis2=-2).eval(),
np.diagonal(x, axis1=-1, axis2=-2),
)
def test_transpose():
x1 = dvector("x1")
x2 = dmatrix("x2")
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论