提交 70b4f56e authored 作者: Li Yao's avatar Li Yao

indexing subtensor ops for sparse matrix 2nd commit

上级 495aa9d3
......@@ -188,9 +188,9 @@ class _sparse_py_operators:
if not isinstance(args, tuple):
args = args,
scalar_var = tensor.scalar(dtype='int32')
scalar_var = tensor.iscalar()
if len(args) is not 1:
if len(args) == 2:
scalar_arg_1 = (numpy.isscalar(args[0]) or
getattr(args[0], 'type', None) == scalar_var.type)
scalar_arg_2 = (numpy.isscalar(args[1]) or
......@@ -652,7 +652,14 @@ class GetItem2d(gof.op.Op):
If you want to take only one element of a sparse matrix see the class GetItemScalar
that return a tensor scalar.
:note: that subtensor selection always returns a matrix, even when one index is a scalar.
:note:
that subtensor selection always returns a matrix so indexing with [a:b, c:d] is forced.
If one index is a scalar, e.g. x[a:b, c] and x[a, b:c], generate an error. Use instead
x[a:b, c:c+1] and x[a:a+1, b:c].
The above indexing methods are not supported because the rval would be a sparse
matrix rather than a sparse vector, which is a deviation from numpy indexing rule.
This decision is made largely for keeping the consistency between numpy and theano.
Subjected to modification when sparse vector is supported.
"""
def __eq__(self, other):
return (type(self) == type(other))
......@@ -683,15 +690,17 @@ class GetItem2d(gof.op.Op):
if isinstance(stop,int):
stop = theano.tensor.constant(stop)
# in case of indexing using python int
elif isinstance(ind,int):
start = theano.tensor.constant(ind)
stop = start + 1
elif ind.ndim == 0:
start = ind
stop = ind + 1
#in case of indexing using python int
#elif isinstance(ind,int):
# start = theano.tensor.constant(ind)
# stop = start + 1
#elif ind.ndim == 0:
# start = ind
# stop = ind + 1
else:
raise NotImplemented()
raise NotImplemented('Theano has no sparse vector'+
'Use X[a:b,c:d], X[a:b,c:c+1] or X[a:b] instead.')
input_op += [start, stop]
if len(index)==1:
i = theano.gof.Constant(theano.gof.generic, None)
......
......@@ -930,7 +930,7 @@ def test_size():
check()
def test_GetItem2D():
sparse_formats = ('csc','csr')
sparse_formats = ('csc', 'csr')
for format in sparse_formats:
x = theano.sparse.matrix(format)
a = theano.tensor.iscalar()
......@@ -944,7 +944,8 @@ def test_GetItem2D():
p = 10
q = 15
vx = as_sparse_format(numpy.random.binomial(1, 0.5, (100, 100)), format).astype(theano.config.floatX)
vx = as_sparse_format(numpy.random.binomial(1, 0.5, (100, 100)), format).astype(
theano.config.floatX)
#mode_no_debug = theano.compile.mode.get_default_mode()
#if isinstance(mode_no_debug, theano.compile.DebugMode):
......@@ -955,19 +956,38 @@ def test_GetItem2D():
t1 = vx[m:n, p:q]
assert r1.shape == t1.shape
assert numpy.all(t1.toarray() == r1.toarray())
""""
Important: based on a discussion with both Fred and James
The following indexing methods is not supported because the rval would be a sparse
matrix rather than a sparse vector, which is a deviation from numpy indexing rule.
This decision is made largely for keeping the consistency between numpy and theano.
f2 = theano.function([x, a, b, c], x[a:b, c])
r2 = f2(vx, m, n, p)
t2 = vx[m:n,p]
t2 = vx[m:n, p]
assert r2.shape == t2.shape
assert numpy.all(t2.toarray() == r2.toarray())
f3 = theano.function([x, a, b, c], x[a, b:c])
r3 = f3(vx, m, n, p)
t3 = vx[m,n:p]
t3 = vx[m, n:p]
assert r3.shape == t3.shape
assert numpy.all(t3.toarray() == r3.toarray())
f5 = theano.function([x], x[1:2,3])
r5 = f5(vx)
t5 = vx[1:2, 3]
assert r5.shape == t5.shape
assert numpy.all(r5.toarray() == t5.toarray())
f7 = theano.function([x], x[50])
r7 = f7(vx)
t7 = vx[50]
assert r7.shape == t7.shape
assert numpy.all(r7.toarray() == t7.toarray())
"""
f4 = theano.function([x, a, b], x[a:b])
r4 = f4(vx, m, n)
t4 = vx[m:n]
......@@ -975,39 +995,29 @@ def test_GetItem2D():
assert numpy.all(t4.toarray() == r4.toarray())
#-----------------------------------------------------------
# test cases using int indexing instead of theano variable
f5 = theano.function([x],x[1:2,3])
r5 = f5(vx)
t5 = vx[1:2,3]
assert r5.shape == t5.shape
assert numpy.all(r5.toarray() == t5.toarray())
f6 = theano.function([x],x[1:10,10:20])
f6 = theano.function([x], x[1:10,10:20])
r6 = f6(vx)
t6 = vx[1:10,10:20]
assert r6.shape == t6.shape
assert numpy.all(r6.toarray() == t6.toarray())
f7 = theano.function([x],x[50])
r7 = f7(vx)
t7 = vx[50]
assert r7.shape == t7.shape
assert numpy.all(r7.toarray() == t7.toarray())
#----------------------------------------------------------
# test cases with indexing both with theano variable and int
f8 = theano.function([x,a,b],x[a:b,10:20])
r8 = f8(vx,m,n)
t8 = vx[m:n,10:20]
f8 = theano.function([x,a,b], x[a:b,10:20])
r8 = f8(vx, m, n)
t8 = vx[m:n, 10:20]
assert r8.shape == t8.shape
assert numpy.all(r8.toarray() == t8.toarray())
f9 = theano.function([x,a,b],x[1:a,1:b])
r9 = f9(vx,p,q)
t9 = vx[1:p,1:q]
f9 = theano.function([x, a, b],x[1:a, 1:b])
r9 = f9(vx, p, q)
t9 = vx[1:p, 1:q]
assert r9.shape == t9.shape
assert numpy.all(r9.toarray() == t9.toarray())
def test_GetItemScalar():
sparse_formats = ('csc','csr')
sparse_formats = ('csc', 'csr')
for format in sparse_formats:
x = theano.sparse.csc_matrix('x')
a = theano.tensor.iscalar()
......@@ -1016,29 +1026,30 @@ def test_GetItemScalar():
m = 50
n = 50
vx = as_sparse_format(numpy.random.binomial(1, 0.5, (100, 100)), format).astype(theano.config.floatX)
vx = as_sparse_format(numpy.random.binomial(1, 0.5, (100, 100)), format).astype(
theano.config.floatX)
f1 = theano.function([x, a, b], x[a,b])
f1 = theano.function([x, a, b], x[a, b])
r1 = f1(vx, 10, 10)
t1 = vx[10,10]
t1 = vx[10, 10]
assert r1.shape == t1.shape
assert numpy.all(t1 == r1)
f2 = theano.function([x, a], x[50,a])
f2 = theano.function([x, a], x[50, a])
r2 = f2(vx, m)
t2 = vx[50,m]
t2 = vx[50, m]
assert r2.shape == t2.shape
assert numpy.all(t2 == r2)
f3 = theano.function([x, a], x[a,50])
f3 = theano.function([x, a], x[a, 50])
r3 = f3(vx, m)
t3 = vx[m,50]
t3 = vx[m, 50]
assert r3.shape == t3.shape
assert numpy.all(t3 == r3)
f4 = theano.function([x], x[50,50])
f4 = theano.function([x], x[50, 50])
r4 = f4(vx)
t4 = vx[m,n]
t4 = vx[m, n]
assert r3.shape == t3.shape
assert numpy.all(t4 == r4)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论