提交 a2f0692e authored 作者: Nicolas Bouchard's avatar Nicolas Bouchard

Add regular grad to sp_sum.

上级 59edf4f8
...@@ -1371,8 +1371,10 @@ class SpSum(gof.op.Op): ...@@ -1371,8 +1371,10 @@ class SpSum(gof.op.Op):
:note: :note:
- The grad implementation is controlled with the `sparse_grad` - The grad implementation is controlled with the `sparse_grad`
parameter. `True` will provide a structured grad and `False` parameter. `True` will provide a structured grad and `False`
will provide a regular grad. will provide a regular grad. For both choices, the grad
- This op does not return a sparse matrix. returns a sparse matrix having the same format as `x`.
- This op does not return a sparse matrix, but a dense tensor
matrix.
""" """
def __init__(self, axis=None, sparse_grad=False): def __init__(self, axis=None, sparse_grad=False):
...@@ -1414,6 +1416,9 @@ class SpSum(gof.op.Op): ...@@ -1414,6 +1416,9 @@ class SpSum(gof.op.Op):
z[0] = numpy.asarray(x.sum(self.axis)).ravel() z[0] = numpy.asarray(x.sum(self.axis)).ravel()
def grad(self, (x,), (gz,)): def grad(self, (x,), (gz,)):
if x.dtype not in continuous_dtypes:
return [None]
if self.structured: if self.structured:
if self.axis is None: if self.axis is None:
r = gz * theano.sparse.sp_ones_like(x) r = gz * theano.sparse.sp_ones_like(x)
...@@ -1424,8 +1429,21 @@ class SpSum(gof.op.Op): ...@@ -1424,8 +1429,21 @@ class SpSum(gof.op.Op):
else: else:
raise ValueError('Illegal value for self.axis.') raise ValueError('Illegal value for self.axis.')
else: else:
# TODO o_format = x.format
raise NotImplementedError() x = dense_from_sparse(x)
if _is_sparse_variable(gz):
gz = dense_from_sparse(gz)
if self.axis is None:
r = tensor.second(x, gz)
else:
ones = tensor.ones_like(x)
if self.axis == 0:
r = tensor.addbroadcast(gz.dimshuffle('x', 0), 0) * ones
elif self.axis == 1:
r = tensor.addbroadcast(gz.dimshuffle(0, 'x'), 1) * ones
else:
raise ValueError('Illegal value for self.axis.')
r = SparseFromDense(o_format)(r)
return [r] return [r]
def infer_shape(self, node, shapes): def infer_shape(self, node, shapes):
......
...@@ -96,9 +96,6 @@ def sparse_random_inputs(format, shape, n=1, out_dtype=None, p=0.5): ...@@ -96,9 +96,6 @@ def sparse_random_inputs(format, shape, n=1, out_dtype=None, p=0.5):
assert len(shape) == 2 assert len(shape) == 2
assert out_dtype in sparse.all_dtypes assert out_dtype in sparse.all_dtypes
variable = [getattr(theano.sparse, format + '_matrix')(dtype=out_dtype)
for k in range(n)]
def _rand(): def _rand():
where = numpy.random.binomial(1, p, size=shape).astype('int8') where = numpy.random.binomial(1, p, size=shape).astype('int8')
...@@ -106,9 +103,10 @@ def sparse_random_inputs(format, shape, n=1, out_dtype=None, p=0.5): ...@@ -106,9 +103,10 @@ def sparse_random_inputs(format, shape, n=1, out_dtype=None, p=0.5):
value = numpy.random.randint(20, size=shape).astype(out_dtype) value = numpy.random.randint(20, size=shape).astype(out_dtype)
else: else:
value = numpy.random.random(shape) value = numpy.random.random(shape)
return where * value return where * value
variable = [getattr(theano.sparse, format + '_matrix')(dtype=out_dtype)
for k in range(n)]
data = [getattr(scipy.sparse, format + '_matrix')(_rand()) data = [getattr(scipy.sparse, format + '_matrix')(_rand())
for k in range(n)] for k in range(n)]
...@@ -1410,7 +1408,7 @@ class SpSumTester(utt.InferShapeTester): ...@@ -1410,7 +1408,7 @@ class SpSumTester(utt.InferShapeTester):
def test_grad(self): def test_grad(self):
for format in sparse.sparse_formats: for format in sparse.sparse_formats:
for axis in self.possible_axis: for axis in self.possible_axis:
for struct in [True]: for struct in [True, False]:
variable, data = sparse_random_inputs(format, variable, data = sparse_random_inputs(format,
shape=(10, 10)) shape=(10, 10))
verify_grad_sparse( verify_grad_sparse(
...@@ -1811,8 +1809,8 @@ def _hv_switch(op, expected_function): ...@@ -1811,8 +1809,8 @@ def _hv_switch(op, expected_function):
:Parameters: :Parameters:
- `op`: HStack or VStack class. - `op`: HStack or VStack class.
- `expected_function`: function from scipy for comparison.
""" """
class XStackTester(_HVStackTester): class XStackTester(_HVStackTester):
op_class = op op_class = op
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论