提交 c1933179 authored 作者: bergstra@ip05.m's avatar bergstra@ip05.m

adding tests for a variety of sparse dtypes

上级 72528716
...@@ -696,7 +696,7 @@ class StructuredDot(gof.Op): ...@@ -696,7 +696,7 @@ class StructuredDot(gof.Op):
else: else:
raise Exception("a.shape=%s, b.shape=%s, variable.shape=%s ??? I have no idea why") raise Exception("a.shape=%s, b.shape=%s, variable.shape=%s ??? I have no idea why")
## Commenting this out because variable should be a numpy.ndarray since the assert above ## Commenting this out because variable should be a numpy.ndarray since the "assert _is_dense(variable)" above
## (JB 20090109) ## (JB 20090109)
# out[0] = numpy.asarray(variable) #TODO: fix this really bad implementation # out[0] = numpy.asarray(variable) #TODO: fix this really bad implementation
# #
...@@ -714,6 +714,7 @@ def structured_dot(x, y): ...@@ -714,6 +714,7 @@ def structured_dot(x, y):
""" """
@todo: Maybe the triple-transposition formulation (when x is dense) @todo: Maybe the triple-transposition formulation (when x is dense)
is slow. See if there is a direct way to do this. is slow. See if there is a direct way to do this.
(JB 20090528: Transposing tensors and sparse matrices is constant-time, inplace, and fast.)
""" """
if hasattr(x, 'getnnz'): x = as_sparse_variable(x) if hasattr(x, 'getnnz'): x = as_sparse_variable(x)
if hasattr(y, 'getnnz'): y = as_sparse_variable(y) if hasattr(y, 'getnnz'): y = as_sparse_variable(y)
......
...@@ -161,18 +161,25 @@ class test_structureddot(unittest.TestCase): ...@@ -161,18 +161,25 @@ class test_structureddot(unittest.TestCase):
def test_structuredot(self): def test_structuredot(self):
bsize = 2 bsize = 2
typenames = 'int8', 'int16', 'int32', 'int64', 'float32', 'float64', 'complex64', 'complex128'
# iterate 10 times just to make sure (cannot get this wrong !) for sparse_dtype in typenames:
for dense_dtype in typenames:
# iterate for a few different random graph patterns
for i in range(10): for i in range(10):
spmat = sp.lil_matrix((4,6)) spmat = sp.lil_matrix((4,6), dtype=sparse_dtype)
for i in range(5): for i in range(5):
# set non-zeros in random locations (row x, col y)
x = numpy.floor(numpy.random.rand()*spmat.shape[0]) x = numpy.floor(numpy.random.rand()*spmat.shape[0])
y = numpy.floor(numpy.random.rand()*spmat.shape[1]) y = numpy.floor(numpy.random.rand()*spmat.shape[1])
spmat[x,y] = numpy.random.rand()*10 spmat[x,y] = numpy.random.rand()*10
spmat = sp.csc_matrix(spmat) spmat = sp.csc_matrix(spmat)
kerns = tensor.dvector('kerns') kerns = tensor.Tensor(broadcastable=[False], dtype=sparse_dtype)('kerns')
images = tensor.dmatrix('images') images = tensor.Tensor(broadcastable=[False, False], dtype=dense_dtype)('images')
output_dtype = theano.scalar.upcast(sparse_dtype, dense_dtype)
assert output_dtype in (sparse_dtype, dense_dtype)
## ##
# Test compressed-sparse column matrices ### # Test compressed-sparse column matrices ###
...@@ -181,6 +188,7 @@ class test_structureddot(unittest.TestCase): ...@@ -181,6 +188,7 @@ class test_structureddot(unittest.TestCase):
# build symbolic theano graph # build symbolic theano graph
def buildgraphCSC(kerns,images): def buildgraphCSC(kerns,images):
csc = CSC(kerns, spmat.indices[:spmat.size], spmat.indptr, spmat.shape) csc = CSC(kerns, spmat.indices[:spmat.size], spmat.indptr, spmat.shape)
assert csc.type.dtype == output_dtype
return structured_dot(csc, images.T) return structured_dot(csc, images.T)
out = buildgraphCSC(kerns,images) out = buildgraphCSC(kerns,images)
f = theano.function([kerns,images], out) f = theano.function([kerns,images], out)
...@@ -192,7 +200,10 @@ class test_structureddot(unittest.TestCase): ...@@ -192,7 +200,10 @@ class test_structureddot(unittest.TestCase):
c = spmat * (imvals.T) c = spmat * (imvals.T)
assert _is_dense(c) assert _is_dense(c)
assert numpy.all(outvals == c) assert numpy.all(outvals == c)
assert str(outvals.dtype) == output_dtype
assert c.dtype == outvals.dtype
if sparse_dtype.startswith('float') and dense_dtype.startswith('float'):
utt.verify_grad(buildgraphCSC, [kernvals,imvals]) utt.verify_grad(buildgraphCSC, [kernvals,imvals])
## ##
...@@ -214,7 +225,11 @@ class test_structureddot(unittest.TestCase): ...@@ -214,7 +225,11 @@ class test_structureddot(unittest.TestCase):
c = spmat * (imvals.T) c = spmat * (imvals.T)
assert _is_dense(c) assert _is_dense(c)
assert numpy.all(outvals == c) assert numpy.all(outvals == c)
assert str(outvals.dtype) == output_dtype
assert c.dtype == outvals.dtype
# we could test more, but hopefully this suffices?
if sparse_dtype.startswith('float') and dense_dtype.startswith('float'):
utt.verify_grad( buildgraphCSR, [kernvals,imvals]) utt.verify_grad( buildgraphCSR, [kernvals,imvals])
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论