提交 2f44247c authored 作者: Purna Chandra Mansingh's avatar Purna Chandra Mansingh 提交者: Ricardo Vieira

Made tril and triu work beyond 2D arrays

上级 ccfe2d3d
......@@ -1199,10 +1199,12 @@ def tril(m, k=0):
Lower triangle of an array.
Return a copy of an array with elements above the `k`-th diagonal zeroed.
For arrays with ``ndim`` exceeding 2, `tril` will apply to the final two
axes.
Parameters
----------
m : array_like, shape (M, N)
m : array_like, shape (..., M, N)
Input array.
k : int, optional
Diagonal above which to zero elements. `k = 0` (the default) is the
......@@ -1210,23 +1212,48 @@ def tril(m, k=0):
Returns
-------
array, shape (M, N)
tril : ndarray, shape (..., M, N)
Lower triangle of `m`, of same shape and data-type as `m`.
See Also
--------
triu : Same thing, only for the upper triangle.
Examples
--------
>>> at.tril(np.arange(1,13).reshape(4,3), -1).eval()
array([[ 0, 0, 0],
[ 4, 0, 0],
[ 7, 8, 0],
[10, 11, 12]])
>>> at.tril(np.arange(3*4*5).reshape(3, 4, 5)).eval()
array([[[ 0, 0, 0, 0, 0],
[ 5, 6, 0, 0, 0],
[10, 11, 12, 0, 0],
[15, 16, 17, 18, 0]],
[[20, 0, 0, 0, 0],
[25, 26, 0, 0, 0],
[30, 31, 32, 0, 0],
[35, 36, 37, 38, 0]],
[[40, 0, 0, 0, 0],
[45, 46, 0, 0, 0],
[50, 51, 52, 0, 0],
[55, 56, 57, 58, 0]]])
"""
return m * tri(m.shape[0], m.shape[1], k=k, dtype=m.dtype)
return m * tri(*m.shape[-2:], k=k, dtype=m.dtype)
def triu(m, k=0):
"""
Upper triangle of an array.
Return a copy of a matrix with the elements below the `k`-th diagonal
zeroed.
Return a copy of an array with the elements below the `k`-th diagonal
zeroed. For arrays with ``ndim`` exceeding 2, `triu` will apply to the
final two axes.
Please refer to the documentation for `tril` for further details.
......@@ -1234,10 +1261,32 @@ def triu(m, k=0):
--------
tril : Lower triangle of an array.
Examples
--------
>>> at.triu(np.arange(1,13).reshape(4,3), -1).eval()
array([[ 1, 2, 3],
[ 4, 5, 6],
[ 0, 8, 9],
[ 0, 0, 12]])
>>> at.triu(np.arange(3*4*5).reshape(3, 4, 5)).eval()
array([[[ 0, 1, 2, 3, 4],
[ 0, 6, 7, 8, 9],
[ 0, 0, 12, 13, 14],
[ 0, 0, 0, 18, 19]],
[[20, 21, 22, 23, 24],
[ 0, 26, 27, 28, 29],
[ 0, 0, 32, 33, 34],
[ 0, 0, 0, 38, 39]],
[[40, 41, 42, 43, 44],
[ 0, 46, 47, 48, 49],
[ 0, 0, 52, 53, 54],
[ 0, 0, 0, 58, 59]]])
"""
return m * (
constant(1, dtype=m.dtype) - tri(m.shape[0], m.shape[1], k=k - 1, dtype=m.dtype)
)
return m * (constant(1, dtype=m.dtype) - tri(*m.shape[-2:], k=k - 1, dtype=m.dtype))
def tril_indices(
......
......@@ -923,6 +923,24 @@ class TestTriangle:
assert np.allclose(result_indx, result_from)
assert result.dtype == np.dtype(dtype)
def check_l_batch(m, k=0):
m_symb = tensor3(dtype=m.dtype)
k_symb = iscalar()
f = function([m_symb, k_symb], tril(m_symb, k_symb))
for k in [-1, 0, 1]:
result = f(m, k)
assert np.allclose(result, np.tril(m, k))
assert result.dtype == np.dtype(dtype)
def check_u_batch(m):
m_symb = tensor3(dtype=m.dtype)
k_symb = iscalar()
f = function([m_symb, k_symb], triu(m_symb, k_symb))
for k in [-1, 0, 1]:
result = f(m, k)
assert np.allclose(result, np.triu(m, k))
assert result.dtype == np.dtype(dtype)
for dtype in ALL_DTYPES:
m = random_of_dtype((10, 10), dtype)
check_l(m, 0)
......@@ -942,6 +960,14 @@ class TestTriangle:
check_u(m, 1)
check_u(m, -1)
m = random_of_dtype((5, 5, 5), dtype)
check_l_batch(m)
check_u_batch(m)
m = random_of_dtype((5, 10, 5), dtype)
check_l_batch(m)
check_u_batch(m)
m = random_of_dtype((10,), dtype)
for fn in (triu_indices_from, tril_indices_from):
with pytest.raises(ValueError, match="must be two dimensional"):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论