提交 06ba1976 authored 作者: Pascal Lamblin's avatar Pascal Lamblin

More tests for sparse indexing.

上级 5e4efff0
...@@ -930,7 +930,11 @@ def test_size(): ...@@ -930,7 +930,11 @@ def test_size():
check() check()
def test_GetItem2D(): class Test_getitem(unittest.TestCase):
def setUp(self):
self.rng = numpy.random.RandomState(utt.fetch_seed())
def test_GetItem2D(self):
sparse_formats = ('csc', 'csr') sparse_formats = ('csc', 'csr')
for format in sparse_formats: for format in sparse_formats:
x = theano.sparse.matrix(format, name='x') x = theano.sparse.matrix(format, name='x')
...@@ -945,7 +949,7 @@ def test_GetItem2D(): ...@@ -945,7 +949,7 @@ def test_GetItem2D():
p = 10 p = 10
q = 15 q = 15
vx = as_sparse_format(numpy.random.binomial(1, 0.5, (100, 100)), vx = as_sparse_format(self.rng.binomial(1, 0.5, (100, 100)),
format).astype(theano.config.floatX) format).astype(theano.config.floatX)
#mode_no_debug = theano.compile.mode.get_default_mode() #mode_no_debug = theano.compile.mode.get_default_mode()
#if isinstance(mode_no_debug, theano.compile.DebugMode): #if isinstance(mode_no_debug, theano.compile.DebugMode):
...@@ -1016,8 +1020,46 @@ def test_GetItem2D(): ...@@ -1016,8 +1020,46 @@ def test_GetItem2D():
assert r9.shape == t9.shape assert r9.shape == t9.shape
assert numpy.all(r9.toarray() == t9.toarray()) assert numpy.all(r9.toarray() == t9.toarray())
#-----------------------------------------------------------
def test_GetItemScalar(): # Test mixing None and variables
f10 = theano.function([x, a, b], x[:a, :b])
r10 = f10(vx, p, q)
t10 = vx[:p, :q]
assert r10.shape == t10.shape
assert numpy.all(r10.toarray() == t10.toarray())
f11 = theano.function([x, a], x[:,a:])
r11 = f11(vx, p)
t11 = vx[:, p:]
assert r11.shape == t11.shape
assert numpy.all(r11.toarray() == t11.toarray())
#------------------------------------------------------------
# Invalid things
# The syntax is a bit awkward because assertRaises forbids
# the [] shortcut for getitem.
# x[a:b] is not accepted because we don't have sparse vectors
self.assertRaises(NotImplementedError,
x.__getitem__, (slice(a, b), c))
# x[a:b:step, c:d] is not accepted because scipy silently drops
# the step (!)
self.assertRaises(ValueError,
x.__getitem__, (slice(a, b, -1), slice(c, d)))
self.assertRaises(ValueError,
x.__getitem__, (slice(a, b), slice(c, d, 2)))
# Advanced indexing is not supported
self.assertRaises(ValueError,
x.__getitem__, (tensor.ivector('l'), slice(a, b)))
# Indexing with random things is not supported either
self.assertRaises(ValueError,
x.__getitem__, slice(tensor.fscalar('f'), None))
self.assertRaises(ValueError,
x.__getitem__, (slice(None), slice([1,3,4], None)))
def test_GetItemScalar(self):
sparse_formats = ('csc', 'csr') sparse_formats = ('csc', 'csr')
for format in sparse_formats: for format in sparse_formats:
x = theano.sparse.csc_matrix('x') x = theano.sparse.csc_matrix('x')
...@@ -1027,7 +1069,7 @@ def test_GetItemScalar(): ...@@ -1027,7 +1069,7 @@ def test_GetItemScalar():
m = 50 m = 50
n = 50 n = 50
vx = as_sparse_format(numpy.random.binomial(1, 0.5, (100, 100)), vx = as_sparse_format(self.rng.binomial(1, 0.5, (100, 100)),
format).astype(theano.config.floatX) format).astype(theano.config.floatX)
f1 = theano.function([x, a, b], x[a, b]) f1 = theano.function([x, a, b], x[a, b])
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论