提交 3efd4b7a authored 作者: Shawn Tan's avatar Shawn Tan

Modified test cases for `k != 0`

上级 6237660f
......@@ -392,7 +392,7 @@ def test_gpujoin_gpualloc():
def test_gpueye():
def check(dtype, N, M_=None):
def check(dtype, N, M_=None, k=0):
# Theano does not accept None as a tensor.
# So we must use a real value.
M = M_
......@@ -402,13 +402,14 @@ def test_gpueye():
M = N
N_symb = T.iscalar()
M_symb = T.iscalar()
k_symb = np.asarray(0)
out = T.eye(N_symb, M_symb, k_symb, dtype=dtype)
f = theano.function([N_symb, M_symb],
T.stack(out),
k_symb = T.iscalar()
out = T.eye(N_symb, M_symb, k_symb, dtype=dtype) + np.array(1).astype(dtype)
f = theano.function([N_symb, M_symb, k_symb],
out,
mode=mode_with_gpu)
result = np.asarray(f(N, M))
assert np.allclose(result, np.eye(N, M_, dtype=dtype))
result = np.asarray(f(N, M, k)) - np.array(1).astype(dtype)
assert np.allclose(result, np.eye(N, M_, k, dtype=dtype))
assert result.dtype == np.dtype(dtype)
assert any([isinstance(node.op, GpuEye)
for node in f.maker.fgraph.toposort()])
......@@ -418,6 +419,15 @@ def test_gpueye():
# M != N, k = 0
yield check, dtype, 3, 5
yield check, dtype, 5, 3
# N == M, k != 0
yield check, dtype, 3, 3, 1
yield check, dtype, 3, 3, -1
# N < M, k != 0
yield check, dtype, 3, 5, 1
yield check, dtype, 3, 5, -1
# N > M, k != 0
yield check, dtype, 5, 3, 1
yield check, dtype, 5, 3, -1
def test_hostfromgpu_shape_i():
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论