提交 a2f0692e authored 作者: Nicolas Bouchard's avatar Nicolas Bouchard

Add regular grad to sp_sum.

上级 59edf4f8
......@@ -1371,8 +1371,10 @@ class SpSum(gof.op.Op):
:note:
- The grad implementation is controlled with the `sparse_grad`
parameter. `True` will provide a structured grad and `False`
will provide a regular grad.
- This op does not return a sparse matrix.
will provide a regular grad. For both choices, the grad
returns a sparse matrix having the same format as `x`.
- This op does not return a sparse matrix, but a dense
tensor.
"""
def __init__(self, axis=None, sparse_grad=False):
......@@ -1414,6 +1416,9 @@ class SpSum(gof.op.Op):
z[0] = numpy.asarray(x.sum(self.axis)).ravel()
def grad(self, (x,), (gz,)):
if x.dtype not in continuous_dtypes:
return [None]
if self.structured:
if self.axis is None:
r = gz * theano.sparse.sp_ones_like(x)
......@@ -1424,8 +1429,21 @@ class SpSum(gof.op.Op):
else:
raise ValueError('Illegal value for self.axis.')
else:
# TODO
raise NotImplementedError()
o_format = x.format
x = dense_from_sparse(x)
if _is_sparse_variable(gz):
gz = dense_from_sparse(gz)
if self.axis is None:
r = tensor.second(x, gz)
else:
ones = tensor.ones_like(x)
if self.axis == 0:
r = tensor.addbroadcast(gz.dimshuffle('x', 0), 0) * ones
elif self.axis == 1:
r = tensor.addbroadcast(gz.dimshuffle(0, 'x'), 1) * ones
else:
raise ValueError('Illegal value for self.axis.')
r = SparseFromDense(o_format)(r)
return [r]
def infer_shape(self, node, shapes):
......
......@@ -96,9 +96,6 @@ def sparse_random_inputs(format, shape, n=1, out_dtype=None, p=0.5):
assert len(shape) == 2
assert out_dtype in sparse.all_dtypes
variable = [getattr(theano.sparse, format + '_matrix')(dtype=out_dtype)
for k in range(n)]
def _rand():
where = numpy.random.binomial(1, p, size=shape).astype('int8')
......@@ -106,9 +103,10 @@ def sparse_random_inputs(format, shape, n=1, out_dtype=None, p=0.5):
value = numpy.random.randint(20, size=shape).astype(out_dtype)
else:
value = numpy.random.random(shape)
return where * value
variable = [getattr(theano.sparse, format + '_matrix')(dtype=out_dtype)
for k in range(n)]
data = [getattr(scipy.sparse, format + '_matrix')(_rand())
for k in range(n)]
......@@ -1410,7 +1408,7 @@ class SpSumTester(utt.InferShapeTester):
def test_grad(self):
for format in sparse.sparse_formats:
for axis in self.possible_axis:
for struct in [True]:
for struct in [True, False]:
variable, data = sparse_random_inputs(format,
shape=(10, 10))
verify_grad_sparse(
......@@ -1811,8 +1809,8 @@ def _hv_switch(op, expected_function):
:Parameters:
- `op`: HStack or VStack class.
- `expected_function`: function from scipy for comparison.
"""
class XStackTester(_HVStackTester):
op_class = op
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论