提交 0339e02d authored 作者: nouiz's avatar nouiz

Merge pull request #451 from lamblin/fix_test_diag

Fix dtype of input value in test_diag
...@@ -470,22 +470,19 @@ def test_diag(): ...@@ -470,22 +470,19 @@ def test_diag():
f = theano.function([m], d) f = theano.function([m], d)
f2 = theano.function([m], d.shape) f2 = theano.function([m], d.shape)
for K in 1, 5: for K in 1, 5:
np_matrix = numpy.asarray(numpy.reshape(range(K**2),(K,K)),dtype='float64') np_matrix = numpy.asarray(numpy.reshape(range(K**2),(K,K)),
dtype=theano.config.floatX)
diag = numpy.diagonal(np_matrix) diag = numpy.diagonal(np_matrix)
sp_matrix = scipy.sparse.csc_matrix(np_matrix) sp_matrix = scipy.sparse.csc_matrix(np_matrix)
assert numpy.all(diag == f(sp_matrix)) assert numpy.all(diag == f(sp_matrix))
assert f2(sp_matrix) == diag.shape assert f2(sp_matrix) == diag.shape
def test_square_diagonal(): def test_square_diagonal():
for K in 1, 5: for K in 1, 5:
d = tensor.ivector() d = tensor.ivector()
sd = sp.square_diagonal(d) sd = sp.square_diagonal(d)
f = theano.function([d], sd) f = theano.function([d], sd)
n = numpy.zeros((K,K), dtype='int32') n = numpy.zeros((K,K), dtype='int32')
for i in range(K): for i in range(K):
n[i,i] = i n[i,i] = i
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论