提交 652d9e7a authored 作者: Jeremiah Lowin's avatar Jeremiah Lowin

add triangle tests

上级 c819f527
......@@ -32,7 +32,7 @@ from theano.tensor import (_shared, wvector, bvector, autocast_float_as,
tensor4, permute_row_elements, Flatten, fmatrix, fscalars, grad,
inplace, iscalar, matrix, minimum, matrices, maximum, mul, neq,
Reshape, row, scalar, scalars, second, smallest, stack, sub, Tensor,
tensor_copy, tensordot, TensorType, unbroadcast,
tensor_copy, tensordot, TensorType, Tri, tri, tril, triu, unbroadcast,
var, Join, shape, MaxAndArgmax, lscalar, zvector, exp,
get_scalar_constant_value, ivector, reshape, scalar_from_tensor, scal,
iscalars, arange, dscalars, fvector, imatrix, numeric_grad,
......@@ -1825,6 +1825,76 @@ def test_eye():
yield check, dtype, 5, 3, -1
def test_tri():
def check(dtype, N, M_=None, k=0):
# Theano does not accept None as a tensor.
# So we must use a real value.
M = M_
# Currently DebugMode does not support None as inputs even if this is
# allowed.
if M is None and theano.config.mode in ['DebugMode', 'DEBUG_MODE']:
M = N
N_symb = tensor.iscalar()
M_symb = tensor.iscalar()
k_symb = tensor.iscalar()
f = function([N_symb, M_symb, k_symb],
tri(N_symb, M_symb, k_symb, dtype=dtype))
result = f(N, M, k)
assert numpy.allclose(result, numpy.tri(N, M_, k, dtype=dtype))
assert result.dtype == numpy.dtype(dtype)
for dtype in ALL_DTYPES:
yield check, dtype, 3
# M != N, k = 0
yield check, dtype, 3, 5
yield check, dtype, 5, 3
# N == M, k != 0
yield check, dtype, 3, 3, 1
yield check, dtype, 3, 3, -1
# N < M, k != 0
yield check, dtype, 3, 5, 1
yield check, dtype, 3, 5, -1
# N > M, k != 0
yield check, dtype, 5, 3, 1
yield check, dtype, 5, 3, -1
def test_tril_triu():
def check_l(m, k=0):
m_symb = matrix(dtype=m.dtype)
k_symb = iscalar()
f = function([m_symb, k_symb], tril(m_symb, k_symb))
result = f(m, k)
assert numpy.allclose(result, numpy.tril(m, k))
assert result.dtype == numpy.dtype(dtype)
def check_u(m, k=0):
m_symb = matrix(dtype=m.dtype)
k_symb = iscalar()
f = function([m_symb, k_symb], triu(m_symb, k_symb))
result = f(m, k)
assert numpy.allclose(result, numpy.triu(m, k))
assert result.dtype == numpy.dtype(dtype)
for dtype in ALL_DTYPES:
m = rand_of_dtype((10, 10), dtype)
yield check_l, m, 0
yield check_l, m, 1
yield check_l, m, -1
yield check_u, m, 0
yield check_u, m, 1
yield check_u, m, -1
m = rand_of_dtype((10, 5), dtype)
yield check_l, m, 0
yield check_l, m, 1
yield check_l, m, -1
yield check_u, m, 0
yield check_u, m, 1
yield check_u, m, -1
def test_identity():
def check(dtype):
obj = rand_of_dtype((2,), dtype)
......@@ -6470,6 +6540,22 @@ class TestInferShape(utt.InferShapeTester):
[Eye()(aiscal, biscal, ciscal)],
[3, 5, 0], Eye)
# Tri
aiscal = iscalar()
biscal = iscalar()
ciscal = iscalar()
self._compile_and_check([aiscal, biscal, ciscal],
[Tri()(aiscal, biscal, ciscal)],
[4, 4, 0], Tri)
self._compile_and_check([aiscal, biscal, ciscal],
[Tri()(aiscal, biscal, ciscal)],
[4, 5, 0], Tri)
self._compile_and_check([aiscal, biscal, ciscal],
[Tri()(aiscal, biscal, ciscal)],
[3, 5, 0], Tri)
# Diagonal
atens3 = tensor3()
atens3_val = rand(4, 5, 3)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论