提交 2a28db96 authored 作者: Nicolas Bouchard's avatar Nicolas Bouchard 提交者: Frederic

Add missing methods to AddSSData

上级 4171ff4c
...@@ -248,8 +248,19 @@ def hstack(blocks, format=None, dtype=None): ...@@ -248,8 +248,19 @@ def hstack(blocks, format=None, dtype=None):
class AddSSData(gof.op.Op): class AddSSData(gof.op.Op):
'''Add two sparse matrices assuming they have the same sparsity """Add two sparse matrices assuming they have the same sparsity
pattern. ''' pattern.
:Parameters:
- `x`: Sparse matrix.
- `y`: Sparse matrix.
:return: The sum of the two sparse matrix element wise.
:note: `x` and `y` are assumed to have the same sparsity pattern.
The grad implemented is structured.
"""
def __eq__(self, other): def __eq__(self, other):
return (type(self) == type(other)) return (type(self) == type(other))
...@@ -270,8 +281,24 @@ class AddSSData(gof.op.Op): ...@@ -270,8 +281,24 @@ class AddSSData(gof.op.Op):
def perform(self, node, (x, y), (out, )): def perform(self, node, (x, y), (out, )):
assert _is_sparse(x) and _is_sparse(y) assert _is_sparse(x) and _is_sparse(y)
assert x.shape == y.shape assert x.shape == y.shape
assert x.data.shape == y.data.shape
out[0] = x.copy() out[0] = x.copy()
out[0].data += y.data out[0].data += y.data
def grad(self, inputs, (gz, )):
is_continuous = [(i.dtype in sparse.continuous_dtypes)
for i in inputs]
if all(is_continuous):
return [gz, gz]
else:
return [None] * len(inputs)
def infer_shape(self, node, ins_shapes):
return [ins_shapes[0]]
def __str__(self):
return self.__class__.__name__
add_s_s_data = AddSSData() add_s_s_data = AddSSData()
......
...@@ -202,6 +202,52 @@ def _hv_switch(op, expected_function): ...@@ -202,6 +202,52 @@ def _hv_switch(op, expected_function):
HStackTester = _hv_switch(S2.HStack, sp.hstack) HStackTester = _hv_switch(S2.HStack, sp.hstack)
VStackTester = _hv_switch(S2.VStack, sp.vstack) VStackTester = _hv_switch(S2.VStack, sp.vstack)
class AddSSDataTester(utt.InferShapeTester):
x = {}
a = {}
def setUp(self):
super(AddSSDataTester, self).setUp()
self.op_class = S2.AddSSData
for format in sparse.sparse_formats:
variable = getattr(theano.sparse, format + '_matrix')
rand = np.array(np.random.random_integers(3, size=(3, 4)) - 1,
dtype=theano.config.floatX)
constant = as_sparse_format(rand, format)
self.x[format] = [variable() for t in range(2)]
self.a[format] = [constant for t in range(2)]
def test_op(self):
for format in sparse.sparse_formats:
f = theano.function(
self.x[format],
S2.add_s_s_data(*self.x[format]))
tested = f(*self.a[format])
expected = 2 * self.a[format][0]
assert np.allclose(tested.toarray(), expected.toarray())
assert tested.format == expected.format
assert tested.dtype == expected.dtype
def test_infer_shape(self):
for format in sparse.sparse_formats:
self._compile_and_check(self.x[format],
[S2.add_s_s_data(*self.x[format])],
self.a[format],
self.op_class)
def test_grad(self):
for format in sparse.sparse_formats:
verify_grad_sparse(self.op_class(),
self.a[format],
structured=True)
class test_structured_add_s_v(unittest.TestCase): class test_structured_add_s_v(unittest.TestCase):
def setUp(self): def setUp(self):
utt.seed_rng() utt.seed_rng()
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论