Commit 84ac684c, authored by lamblin

Merge pull request #412 from lamblin/sparse_indexing

Sparse indexing
...@@ -1330,11 +1330,16 @@ class _Linker(gof.link.LocalLinker): ...@@ -1330,11 +1330,16 @@ class _Linker(gof.link.LocalLinker):
r_vals_initialized = [] r_vals_initialized = []
for r in storage_map: for r in storage_map:
if (r.owner is None): if (r.owner is None):
if (storage_map[r][0] is None):
raise Exception('Missing input', r)
if not r.type.is_valid_value(storage_map[r][0]): if not r.type.is_valid_value(storage_map[r][0]):
# None may be a valid input value (for instance,
# for a Generic object). We only want to raise
# an error if it is not valid.
if (storage_map[r][0] is None):
raise InvalidValueError(r, storage_map[r][0],
hint="Graph Input '%s' is missing" % str(r))
raise InvalidValueError(r, storage_map[r][0], raise InvalidValueError(r, storage_map[r][0],
hint="Graph Input '%s' is missing" % str(r)) hint=("Graph Input '%s' has invalid value "
"%s" % (r, storage_map[r][0])))
r_vals[r] = storage_map[r][0] r_vals[r] = storage_map[r][0]
storage_map[r][0] = None storage_map[r][0] = None
r_vals_initialized.append(r) r_vals_initialized.append(r)
...@@ -1577,7 +1582,8 @@ class _Linker(gof.link.LocalLinker): ...@@ -1577,7 +1582,8 @@ class _Linker(gof.link.LocalLinker):
#print storage_map #print storage_map
for r in storage_map: for r in storage_map:
if (r.owner is None): if (r.owner is None):
assert storage_map[r][0] is not None if not r.type.is_valid_value(None):
assert storage_map[r][0] is not None
############### ###############
......
...@@ -391,7 +391,11 @@ class Constant(Value): ...@@ -391,7 +391,11 @@ class Constant(Value):
def __str__(self): def __str__(self):
if self.name is not None: if self.name is not None:
return self.name return self.name
return str(self.data) #+ "::" + str(self.type) else:
name = str(self.data)
if len(name) > 20:
name = name[:10] + '...' + name[-10]
return 'Constant{%s}' % name
def clone(self): def clone(self):
""" """
We clone this object, but we don't clone the data to lower memory requirement We clone this object, but we don't clone the data to lower memory requirement
......
...@@ -423,4 +423,7 @@ class Generic(SingletonType): ...@@ -423,4 +423,7 @@ class Generic(SingletonType):
Py_INCREF(py_%(name)s); Py_INCREF(py_%(name)s);
""" % locals() """ % locals()
def __str__(self):
return self.__class__.__name__
generic = Generic() generic = Generic()
...@@ -188,13 +188,11 @@ class _sparse_py_operators: ...@@ -188,13 +188,11 @@ class _sparse_py_operators:
if not isinstance(args, tuple): if not isinstance(args, tuple):
args = args, args = args,
scalar_var = tensor.iscalar()
if len(args) == 2: if len(args) == 2:
scalar_arg_1 = (numpy.isscalar(args[0]) or scalar_arg_1 = (numpy.isscalar(args[0]) or
getattr(args[0], 'type', None) == scalar_var.type) getattr(args[0], 'type', None) == tensor.iscalar)
scalar_arg_2 = (numpy.isscalar(args[1]) or scalar_arg_2 = (numpy.isscalar(args[1]) or
getattr(args[1], 'type', None) == scalar_var.type) getattr(args[1], 'type', None) == tensor.iscalar)
if scalar_arg_1 and scalar_arg_2: if scalar_arg_1 and scalar_arg_2:
ret = get_item_scalar(self, args) ret = get_item_scalar(self, args)
else: else:
...@@ -202,8 +200,8 @@ class _sparse_py_operators: ...@@ -202,8 +200,8 @@ class _sparse_py_operators:
else: else:
ret = get_item_2d(self, args) ret = get_item_2d(self, args)
return ret return ret
class SparseVariable(gof.Variable, _sparse_py_operators): class SparseVariable(gof.Variable, _sparse_py_operators):
dtype = property(lambda self: self.type.dtype) dtype = property(lambda self: self.type.dtype)
format = property(lambda self: self.type.format) format = property(lambda self: self.type.format)
...@@ -681,35 +679,57 @@ class GetItem2d(gof.op.Op): ...@@ -681,35 +679,57 @@ class GetItem2d(gof.op.Op):
assert len(index) in [1, 2] assert len(index) in [1, 2]
input_op = [x] input_op = [x]
generic_None = theano.gof.Constant(theano.gof.generic, None)
for ind in index: for ind in index:
if isinstance(ind, slice): if isinstance(ind, slice):
# in case of slice is written in theano variable # in case of slice is written in theano variable
start = ind.start start = ind.start
stop = ind.stop stop = ind.stop
if ind.step is not None:
# in case of slice is written in python int raise ValueError((
if isinstance(start, int): "Using a slice with non-default step when "
start = theano.tensor.constant(start) "indexing into a sparse matrix is not supported. "),
if isinstance(stop, int): ind, ind.step)
stop = theano.tensor.constant(stop)
# If start or stop are None, make them a Generic constant
#in case of indexing using python int # Else, they should be converted to Tensor Variables of
#elif isinstance(ind,int): # dimension 1 and int/uint dtype.
# start = theano.tensor.constant(ind) if start is None:
# stop = start + 1 start = generic_None
#elif ind.ndim == 0: else:
# start = ind if not isinstance(start, gof.Variable):
# stop = ind + 1 start = tensor.as_tensor_variable(start)
if not (start.ndim == 0 and start.dtype in tensor.discrete_dtypes):
else: raise ValueError((
raise NotImplemented( "Impossible to index into a sparse matrix with "
"slice where start=%s" % start),
start.ndim, start.dtype)
if stop is None:
stop = generic_None
else:
if not isinstance(stop, gof.Variable):
stop = tensor.as_tensor_variable(stop)
if not (stop.ndim == 0 and stop.dtype in tensor.discrete_dtypes):
raise ValueError((
"Impossible to index into a sparse matrix with "
"slice where stop=%s" % stop),
stop.ndim, stop.dtype)
elif ((isinstance(ind, gof.Variable) and
getattr(ind, 'ndim', -1) == 0)
or numpy.isscalar(ind)):
raise NotImplementedError(
'Theano has no sparse vector' + 'Theano has no sparse vector' +
'Use X[a:b,c:d], X[a:b,c:c+1] or X[a:b] instead.') 'Use X[a:b,c:d], X[a:b,c:c+1] or X[a:b] instead.')
else:
raise ValueError((
'Advanced indexing is not implemented for sparse '
'matrices. Argument not supported: %s' % ind))
input_op += [start, stop] input_op += [start, stop]
if len(index) == 1: if len(index) == 1:
i = theano.gof.Constant(theano.gof.generic, None) input_op += [generic_None, generic_None]
input_op += [i, i]
return gof.Apply(self, input_op, [x.type()]) return gof.Apply(self, input_op, [x.type()])
...@@ -765,7 +785,7 @@ class GetItemScalar(gof.op.Op): ...@@ -765,7 +785,7 @@ class GetItemScalar(gof.op.Op):
def perform(self, node, (x, ind1, ind2), (out, )): def perform(self, node, (x, ind1, ind2), (out, )):
assert _is_sparse(x) assert _is_sparse(x)
out[0] = x[ind1, ind2] out[0] = theano._asarray(x[ind1, ind2], x.dtype)
def __str__(self): def __str__(self):
return self.__class__.__name__ return self.__class__.__name__
......
Markdown formatting is supported
0%
You are adding 0 people to this discussion. Please proceed with caution.
Please finish editing this comment first!
Register or sign in to post a comment