提交 b4b82e6b authored 作者: ChienliMa's avatar ChienliMa

Sparse matrix slicing now support step. #2379

上级 5ec4c302
...@@ -1219,15 +1219,22 @@ class GetItem2d(gof.op.Op): ...@@ -1219,15 +1219,22 @@ class GetItem2d(gof.op.Op):
# in case of slice is written in theano variable # in case of slice is written in theano variable
start = ind.start start = ind.start
stop = ind.stop stop = ind.stop
if ind.step is not None: step = ind.step
# If start or stop or step are None, make them a Generic
# constant. Else, they should be converted to Tensor Variables
# of dimension 1 and int/uint dtype.
if ind.step is None:
step = generic_None
else:
if not isinstance(step, gof.Variable):
step = tensor.as_tensor_variable(step)
if not (step.ndim == 0 and step.dtype in
tensor.discrete_dtypes):
raise ValueError(( raise ValueError((
"Using a slice with non-default step when " "Impossible to index into a sparse matrix with "
"indexing into a sparse matrix is not supported. "), "slice where start=%s" % step),
ind, ind.step) step.ndim, step.dtype)
# If start or stop are None, make them a Generic constant
# Else, they should be converted to Tensor Variables of
# dimension 1 and int/uint dtype.
if start is None: if start is None:
start = generic_None start = generic_None
else: else:
...@@ -1262,15 +1269,15 @@ class GetItem2d(gof.op.Op): ...@@ -1262,15 +1269,15 @@ class GetItem2d(gof.op.Op):
raise ValueError(( raise ValueError((
'Advanced indexing is not implemented for sparse ' 'Advanced indexing is not implemented for sparse '
'matrices. Argument not supported: %s' % ind)) 'matrices. Argument not supported: %s' % ind))
input_op += [start, stop] input_op += [start, stop, step]
if len(index) == 1: if len(index) == 1:
input_op += [generic_None, generic_None] input_op += [generic_None, generic_None, generic_None]
return gof.Apply(self, input_op, [x.type()]) return gof.Apply(self, input_op, [x.type()])
def perform(self, node, (x, start1, stop1, start2, stop2), (out, )): def perform(self, node, (x, start1, stop1, step1, start2, stop2, step2), (out, )):
assert _is_sparse(x) assert _is_sparse(x)
out[0] = x[start1:stop1, start2:stop2] out[0] = x[start1:stop1:step1, start2:stop2:step2]
def __str__(self): def __str__(self):
return self.__class__.__name__ return self.__class__.__name__
......
...@@ -2103,6 +2103,7 @@ class Test_getitem(unittest.TestCase): ...@@ -2103,6 +2103,7 @@ class Test_getitem(unittest.TestCase):
verify_grad_sparse(op_with_fixed_index, x_val) verify_grad_sparse(op_with_fixed_index, x_val)
def test_GetItem2D(self): def test_GetItem2D(self):
sparse_formats = ('csc', 'csr') sparse_formats = ('csc', 'csr')
for format in sparse_formats: for format in sparse_formats:
...@@ -2111,21 +2112,25 @@ class Test_getitem(unittest.TestCase): ...@@ -2111,21 +2112,25 @@ class Test_getitem(unittest.TestCase):
b = theano.tensor.iscalar('b') b = theano.tensor.iscalar('b')
c = theano.tensor.iscalar('c') c = theano.tensor.iscalar('c')
d = theano.tensor.iscalar('d') d = theano.tensor.iscalar('d')
e = theano.tensor.iscalar('e')
f = theano.tensor.iscalar('f')
# index # index
m = 1 m = 1
n = 5 n = 5
p = 10 p = 10
q = 15 q = 15
j = 2
k = 3
vx = as_sparse_format(self.rng.binomial(1, 0.5, (100, 97)), vx = as_sparse_format(self.rng.binomial(1, 0.5, (100, 97)),
format).astype(theano.config.floatX) format).astype(theano.config.floatX)
#mode_no_debug = theano.compile.mode.get_default_mode() #mode_no_debug = theano.compile.mode.get_default_mode()
#if isinstance(mode_no_debug, theano.compile.DebugMode): #if isinstance(mode_no_debug, theano.compile.DebugMode):
# mode_no_debug = 'FAST_RUN' # mode_no_debug = 'FAST_RUN'
f1 = theano.function([x, a, b, c, d], x[a:b, c:d]) f1 = theano.function([x, a, b, c, d, e, f], x[a:b:e, c:d:f])
r1 = f1(vx, m, n, p, q) r1 = f1(vx, m, n, p, q, j, k)
t1 = vx[m:n, p:q] t1 = vx[m:n:j, p:q:k]
assert r1.shape == t1.shape assert r1.shape == t1.shape
assert numpy.all(t1.toarray() == r1.toarray()) assert numpy.all(t1.toarray() == r1.toarray())
...@@ -2161,31 +2166,31 @@ class Test_getitem(unittest.TestCase): ...@@ -2161,31 +2166,31 @@ class Test_getitem(unittest.TestCase):
assert numpy.all(r7.toarray() == t7.toarray()) assert numpy.all(r7.toarray() == t7.toarray())
""" """
f4 = theano.function([x, a, b], x[a:b]) f4 = theano.function([x, a, b, e], x[a:b:e])
r4 = f4(vx, m, n) r4 = f4(vx, m, n, j)
t4 = vx[m:n] t4 = vx[m:n:j]
assert r4.shape == t4.shape assert r4.shape == t4.shape
assert numpy.all(t4.toarray() == r4.toarray()) assert numpy.all(t4.toarray() == r4.toarray())
#----------------------------------------------------------- #-----------------------------------------------------------
# test cases using int indexing instead of theano variable # test cases using int indexing instead of theano variable
f6 = theano.function([x], x[1:10, 10:20]) f6 = theano.function([x], x[1:10:1, 10:20:2])
r6 = f6(vx) r6 = f6(vx)
t6 = vx[1:10, 10:20] t6 = vx[1:10:1, 10:20:2]
assert r6.shape == t6.shape assert r6.shape == t6.shape
assert numpy.all(r6.toarray() == t6.toarray()) assert numpy.all(r6.toarray() == t6.toarray())
#---------------------------------------------------------- #----------------------------------------------------------
# test cases with indexing both with theano variable and int # test cases with indexing both with theano variable and int
f8 = theano.function([x, a, b], x[a:b, 10:20]) f8 = theano.function([x, a, b, e], x[a:b:e, 10:20:1])
r8 = f8(vx, m, n) r8 = f8(vx, m, n, j)
t8 = vx[m:n, 10:20] t8 = vx[m:n:j, 10:20:1]
assert r8.shape == t8.shape assert r8.shape == t8.shape
assert numpy.all(r8.toarray() == t8.toarray()) assert numpy.all(r8.toarray() == t8.toarray())
f9 = theano.function([x, a, b], x[1:a, 1:b]) f9 = theano.function([x, a, b], x[1:a:2, 1:b:2])
r9 = f9(vx, p, q) r9 = f9(vx, p, q)
t9 = vx[1:p, 1:q] t9 = vx[1:p:2, 1:q:2]
assert r9.shape == t9.shape assert r9.shape == t9.shape
assert numpy.all(r9.toarray() == t9.toarray()) assert numpy.all(r9.toarray() == t9.toarray())
...@@ -2219,12 +2224,12 @@ class Test_getitem(unittest.TestCase): ...@@ -2219,12 +2224,12 @@ class Test_getitem(unittest.TestCase):
self.assertRaises(NotImplementedError, self.assertRaises(NotImplementedError,
x.__getitem__, (slice(a, b), c)) x.__getitem__, (slice(a, b), c))
# x[a:b:step, c:d] is not accepted because scipy silently drops # # x[a:b:step, c:d] is not accepted because scipy silently drops
# the step (!) # # the step (!)
self.assertRaises(ValueError, # self.assertRaises(ValueError,
x.__getitem__, (slice(a, b, -1), slice(c, d))) # x.__getitem__, (slice(a, b, -1), slice(c, d)))
self.assertRaises(ValueError, # self.assertRaises(ValueError,
x.__getitem__, (slice(a, b), slice(c, d, 2))) # x.__getitem__, (slice(a, b), slice(c, d, 2)))
# Advanced indexing is not supported # Advanced indexing is not supported
self.assertRaises(ValueError, self.assertRaises(ValueError,
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论