提交 f485be15 authored 作者: ricardoV94's avatar ricardoV94 提交者: Ricardo Vieira

Reduce number of parametrizations in `test_tril_triu`

上级 43b3f401
......@@ -1044,7 +1044,7 @@ class TestTriangle:
assert np.allclose(result_indx, result_from)
assert result.dtype == np.dtype(dtype)
def check_l_batch(m, k=0):
def check_l_batch(m):
m_symb = tensor3(dtype=m.dtype)
k_symb = iscalar()
f = function([m_symb, k_symb], tril(m_symb, k_symb))
......@@ -1062,28 +1062,31 @@ class TestTriangle:
assert np.allclose(result, np.triu(m, k))
assert result.dtype == np.dtype(dtype)
for dtype in ["int32", "int64", "float32", "float64", "uint16", "complex64"]:
for dtype in ["int32", "int64", "uint16"]:
m = random_of_dtype((10, 10), dtype)
check_l(m, 0)
check_l(m, 1)
check_l(m, -1)
m = random_of_dtype((10, 5), dtype)
check_u(m, 0)
check_u(m, 1)
check_u(m, -1)
m = random_of_dtype((10, 5), dtype)
check_l(m, 0)
check_l(m, 1)
check_l(m, -1)
m = random_of_dtype((5, 5, 5), dtype)
check_l_batch(m)
check_u_batch(m)
for dtype in ["float32", "float64", "complex64"]:
m = random_of_dtype((10, 10), dtype)
check_u(m, 0)
check_u(m, 1)
check_u(m, -1)
m = random_of_dtype((5, 5, 5), dtype)
check_l_batch(m)
check_u_batch(m)
m = random_of_dtype((10, 5), dtype)
check_l(m, 0)
check_l(m, 1)
check_l(m, -1)
m = random_of_dtype((5, 10, 5), dtype)
check_l_batch(m)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论