提交 0339e02d authored 作者: nouiz's avatar nouiz

Merge pull request #451 from lamblin/fix_test_diag

Fix dtype of input value in test_diag
......@@ -470,22 +470,19 @@ def test_diag():
f = theano.function([m], d)
f2 = theano.function([m], d.shape)
for K in 1, 5:
np_matrix = numpy.asarray(numpy.reshape(range(K**2),(K,K)),dtype='float64')
np_matrix = numpy.asarray(numpy.reshape(range(K**2),(K,K)),
dtype=theano.config.floatX)
diag = numpy.diagonal(np_matrix)
sp_matrix = scipy.sparse.csc_matrix(np_matrix)
assert numpy.all(diag == f(sp_matrix))
assert f2(sp_matrix) == diag.shape
def test_square_diagonal():
for K in 1, 5:
d = tensor.ivector()
sd = sp.square_diagonal(d)
f = theano.function([d], sd)
n = numpy.zeros((K,K), dtype='int32')
for i in range(K):
n[i,i] = i
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论