提交 e70f9594 authored 作者: Nicolas Bouchard's avatar Nicolas Bouchard 提交者: Frederic

complete eliminate_zeros and added tests

上级 13de103e
......@@ -387,6 +387,10 @@ multinomial = Multinomial()
class EliminateZeros(gof.op.Op):
"""Eliminate zeros from the data of the matrix.
This wrap the method eliminate_zeros from scipy.
"""
def __eq__(self, other):
return (type(self) == type(other))
......@@ -401,6 +405,15 @@ class EliminateZeros(gof.op.Op):
assert _is_sparse(x)
out[0] = x.copy()
out[0].eliminate_zeros()
def grad(self, inputs, outputs_gradients):
return outputs_gradients
def infer_shape(self, node, ins_shapes):
return ins_shapes
def __str__(self):
return self.__class__.__name__
eliminate_zeros = EliminateZeros()
......
......@@ -21,6 +21,7 @@ from theano.sparse.sandbox import sp2 as S2
from theano.tests import unittest_tools as utt
from theano.sparse.basic import verify_grad_sparse
def as_sparse_format(data, format):
if format == 'csc':
return scipy.sparse.csc_matrix(data)
......@@ -112,6 +113,54 @@ class TestCast(utt.InferShapeTester):
verify_grad_sparse(S2.Cast('float64'), [a])
class EliminateZerosTester(utt.InferShapeTester):
indptr = np.array([0, 2, 3, 6])
indices = np.array([0, 2, 2, 0, 1, 2])
data = np.array([1, 0, 3, 0, 5, 6], dtype='float32')
properties = (data, indices, indptr)
x_csc = S.csc_matrix('csc', dtype='float32')
x_csr = S.csr_matrix('csr', dtype='float32')
def setUp(self):
super(EliminateZerosTester, self).setUp()
self.op_class = S2.EliminateZeros
def test_eliminate_zeros(self):
f_csc = theano.function([self.x_csc], S2.eliminate_zeros(self.x_csc))
f_csr = theano.function([self.x_csr], S2.eliminate_zeros(self.x_csr))
a = sp.csc_matrix(self.properties, dtype='float32')
b = a.copy()
b.eliminate_zeros()
assert np.all(f_csc(a).todense() == b.todense())
a = sp.csr_matrix(self.properties)
b = a.copy()
b.eliminate_zeros()
assert np.all(f_csr(a).todense() == b.todense())
def test_infer_shape(self):
a = sp.csc_matrix(self.properties, dtype='float32')
self._compile_and_check([self.x_csc],
[S2.eliminate_zeros(self.x_csc)],
[a],
self.op_class)
a = sp.csr_matrix(self.properties, dtype='float32')
self._compile_and_check([self.x_csr],
[S2.eliminate_zeros(self.x_csr)],
[a],
self.op_class)
def test_grad(self):
a = sp.csc_matrix(self.properties, dtype='float32')
verify_grad_sparse(S2.eliminate_zeros, [a])
a = sp.csr_matrix(self.properties, dtype='float32')
verify_grad_sparse(S2.eliminate_zeros, [a])
class test_structured_add_s_v(unittest.TestCase):
def setUp(self):
utt.seed_rng()
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论