提交 3a9e8326 authored 作者: Olivier Delalleau's avatar Olivier Delalleau

Cosmetic changes

上级 668f41e0
......@@ -169,15 +169,17 @@ class test_structureddot(unittest.TestCase):
# iterate for a few different random graph patterns
for i in range(10):
spmat = sp.csc_matrix((4,6), dtype=sparse_dtype)
for i in range(5):
for k in range(5):
# set non-zeros in random locations (row x, col y)
x = numpy.floor(numpy.random.rand()*spmat.shape[0])
y = numpy.floor(numpy.random.rand()*spmat.shape[1])
spmat[x,y] = numpy.random.rand()*10
spmat = sp.csc_matrix(spmat)
kerns = tensor.Tensor(broadcastable=[False], dtype=sparse_dtype)('kerns')
images = tensor.Tensor(broadcastable=[False, False], dtype=dense_dtype)('images')
kerns = tensor.Tensor(broadcastable=[False],
dtype=sparse_dtype)('kerns')
images = tensor.Tensor(broadcastable=[False, False],
dtype=dense_dtype)('images')
output_dtype = theano.scalar.upcast(sparse_dtype, dense_dtype)
##
......@@ -186,7 +188,8 @@ class test_structureddot(unittest.TestCase):
# build symbolic theano graph
def buildgraphCSC(kerns,images):
csc = CSC(kerns, spmat.indices[:spmat.size], spmat.indptr, spmat.shape)
csc = CSC(kerns, spmat.indices[:spmat.size],
spmat.indptr, spmat.shape)
assert csc.type.dtype == sparse_dtype
rval = structured_dot(csc, images.T)
assert rval.type.dtype == output_dtype
......@@ -197,8 +200,12 @@ class test_structureddot(unittest.TestCase):
# compute theano outputs
kernvals = spmat.data[:spmat.size]
imvals = 1.0 + 1.0 * numpy.array(numpy.arange(bsize*spmat.shape[1]).\
imvals = 1.0 + 1.0 * numpy.array(
numpy.arange(bsize*spmat.shape[1]).\
reshape(bsize,spmat.shape[1]), dtype=dense_dtype)
#print('dense_dtype=%s' % dense_dtype)
#print('sparse_dtype=%s' % sparse_dtype)
#print('i=%s' % i)
print 'kerntype', str(kernvals.dtype), kernvals.dtype.num
outvals = f(kernvals,imvals)
print 'YAY'
......@@ -210,9 +217,10 @@ class test_structureddot(unittest.TestCase):
assert _is_dense(c)
assert str(outvals.dtype) == output_dtype
assert numpy.all(numpy.abs(outvals -
numpy.array(c, dtype=output_dtype)) < 1e-4)
numpy.array(c, dtype=output_dtype)) < 1e-4)
if sparse_dtype.startswith('float') and dense_dtype.startswith('float'):
if (sparse_dtype.startswith('float') and
dense_dtype.startswith('float')):
utt.verify_grad(buildgraphCSC,
[kernvals, imvals])
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论