提交 2138cd67 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Restrict diag input ndim to 1 and 2 like numpy

上级 31d593d8
...@@ -3630,7 +3630,7 @@ def diag(v, k=0): ...@@ -3630,7 +3630,7 @@ def diag(v, k=0):
A helper function for two ops: `ExtractDiag` and A helper function for two ops: `ExtractDiag` and
`AllocDiag`. The name `diag` is meant to keep it consistent `AllocDiag`. The name `diag` is meant to keep it consistent
with numpy. It both accepts tensor vector and tensor matrix. with numpy. It both accepts tensor vector and tensor matrix.
While the passed tensor variable `v` has `v.ndim>=2`, it builds a While the passed tensor variable `v` has `v.ndim==2`, it builds a
`ExtractDiag` instance, and returns a vector with its entries equal to `ExtractDiag` instance, and returns a vector with its entries equal to
`v`'s main diagonal; otherwise if `v.ndim` is `1`, it builds an `AllocDiag` `v`'s main diagonal; otherwise if `v.ndim` is `1`, it builds an `AllocDiag`
instance, and returns a matrix with `v` at its k-th diaogonal. instance, and returns a matrix with `v` at its k-th diaogonal.
...@@ -3651,10 +3651,10 @@ def diag(v, k=0): ...@@ -3651,10 +3651,10 @@ def diag(v, k=0):
if _v.ndim == 1: if _v.ndim == 1:
return AllocDiag(k)(_v) return AllocDiag(k)(_v)
elif _v.ndim >= 2: elif _v.ndim == 2:
return diagonal(_v, offset=k) return diagonal(_v, offset=k)
else: else:
raise ValueError("Number of dimensions of `v` must be greater than one.") raise ValueError("Input must be 1- or 2-d.")
def stacklists(arg): def stacklists(arg):
......
...@@ -3593,17 +3593,18 @@ class TestDiag: ...@@ -3593,17 +3593,18 @@ class TestDiag:
# The right matrix is created # The right matrix is created
assert (r == v).all() assert (r == v).all()
# Test scalar input
xx = scalar()
with pytest.raises(ValueError):
diag(xx)
# Test passing a list # Test passing a list
xx = [[1, 2], [3, 4]] xx = [[1, 2], [3, 4]]
g = diag(xx) g = diag(xx)
f = function([], g) f = function([], g)
assert np.array_equal(f(), np.diag(xx)) assert np.array_equal(f(), np.diag(xx))
@pytest.mark.parametrize("inp", (scalar, tensor3))
def test_diag_invalid_input_ndim(self, inp):
x = inp()
with pytest.raises(ValueError, match="Input must be 1- or 2-d."):
diag(x)
class TestExtractDiag: class TestExtractDiag:
@pytest.mark.parametrize("axis1, axis2", [(0, 1), (1, 0)]) @pytest.mark.parametrize("axis1, axis2", [(0, 1), (1, 0)])
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论