提交 505317ba authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Enable use of partial slices in sparse indexing

For instance, m[:,a:b], or m[a:,:b]
上级 65d27127
......@@ -679,35 +679,51 @@ class GetItem2d(gof.op.Op):
assert len(index) in [1, 2]
input_op = [x]
generic_None = theano.gof.Constant(theano.gof.generic, None)
for ind in index:
if isinstance(ind, slice):
# in case of slice is written in theano variable
start = ind.start
stop = ind.stop
# in case of slice is written in python int
if isinstance(start, int):
start = theano.tensor.constant(start)
if isinstance(stop, int):
stop = theano.tensor.constant(stop)
#in case of indexing using python int
#elif isinstance(ind,int):
# start = theano.tensor.constant(ind)
# stop = start + 1
#elif ind.ndim == 0:
# start = ind
# stop = ind + 1
if ind.step is not None:
raise ValueError((
"Using a slice with non-default step when "
"indexing into a sparse matrix is not supported. "),
ind, ind.step)
# If start or step are None, make them a Generic constant
# Else, they should be converted to Tensor Variables of
# dimension 1 and int/uint dtype.
if start is None:
start = generic_None
else:
if not isinstance(start, gof.Variable):
start = tensor.as_tensor_variable(start)
if not (start.ndim == 0 and start.dtype in tensor.discrete_dtypes):
raise ValueError((
"Impossible to index into a sparse matrix with "
"slice where start=%s" % start),
start.ndim, start.dtype)
if stop is None:
stop = generic_None
else:
if not isinstance(stop, gof.Variable):
stop = tensor.as_tensor_variable(stop)
if not (stop.ndim == 0 and stop.dtype in tensor.discrete_dtypes):
raise ValueError((
"Impossible to index into a sparse matrix with "
"slice where stop=%s" % stop),
stop.ndim, stop.dtype)
else:
raise NotImplemented(
raise NotImplementedError(
'Theano has no sparse vector' +
'Use X[a:b,c:d], X[a:b,c:c+1] or X[a:b] instead.')
input_op += [start, stop]
if len(index) == 1:
i = theano.gof.Constant(theano.gof.generic, None)
input_op += [i, i]
input_op += [generic_None, generic_None]
return gof.Apply(self, input_op, [x.type()])
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论