提交 cb90af99 authored 作者: Hengjean's avatar Hengjean

Added support for sparse-dense matrices comparison, tests and updated documentation.

上级 8a6bdc91
...@@ -153,6 +153,8 @@ List of Implemented Operations ...@@ -153,6 +153,8 @@ List of Implemented Operations
- Basic Arithmetic - Basic Arithmetic
- :class:`Neg <theano.sparse.basic.Neg>`. - :class:`Neg <theano.sparse.basic.Neg>`.
The grad implemented is regular. The grad implemented is regular.
- :class:`equal <theano.sparse.basic.equal>`.
- :class:`notEqual <theano.sparse.basic.notEqual>`.
- :func:`add <theano.sparse.basic.add>`. - :func:`add <theano.sparse.basic.add>`.
The grad implemented is regular. The grad implemented is regular.
- :func:`sub <theano.sparse.basic.sub>`. - :func:`sub <theano.sparse.basic.sub>`.
......
...@@ -2208,6 +2208,45 @@ class EqualSS(gof.op.Op): ...@@ -2208,6 +2208,45 @@ class EqualSS(gof.op.Op):
equal_s_s = EqualSS() equal_s_s = EqualSS()
class EqualSD(gof.op.Op):
"""
:param x:sparse matrix
:param y:dense matrix
:return: x==y
"""
def __eq__(self, other):
return (type(self) == type(other))
def __hash__(self):
return hash(type(self))
def make_node(self, x, y):
x, y = as_sparse_variable(x), tensor.as_tensor_variable(y)
assert y.type.ndim == 2
return gof.Apply(self,
[x, y],
[SparseType(dtype='uint8',
format=x.type.format).make_variable()])
def perform(self, node, (x, y), (out, )):
assert _is_sparse(x)
assert x.shape == y.shape
assert _is_dense(y)
out[0] = (x == y).astype('uint8')
def infer_shape(self, node, ins_shapes):
return [ins_shapes[0]]
def __str__(self):
return self.__class__.__name__
equal_s_d = EqualSD()
def equal(x, y): def equal(x, y):
""" """
Add two matrices, the two of which are sparse. Add two matrices, the two of which are sparse.
...@@ -2236,9 +2275,9 @@ def equal(x, y): ...@@ -2236,9 +2275,9 @@ def equal(x, y):
if x_is_sparse_variable and y_is_sparse_variable: if x_is_sparse_variable and y_is_sparse_variable:
return equal_s_s(x, y) return equal_s_s(x, y)
elif x_is_sparse_variable and not y_is_sparse_variable: elif x_is_sparse_variable and not y_is_sparse_variable:
raise NotImplementedError() return equal_s_d(x, y)
elif y_is_sparse_variable and not x_is_sparse_variable: elif y_is_sparse_variable and not x_is_sparse_variable:
raise NotImplementedError() return equal_s_d(y, x)
else: else:
raise NotImplementedError() raise NotImplementedError()
...@@ -2283,6 +2322,45 @@ class NotEqualSS(gof.op.Op): ...@@ -2283,6 +2322,45 @@ class NotEqualSS(gof.op.Op):
not_equal_s_s = NotEqualSS() not_equal_s_s = NotEqualSS()
class NotEqualSD(gof.op.Op):
"""
:param x:sparse matrix
:param y:dense matrix
:return: x==y
"""
def __eq__(self, other):
return (type(self) == type(other))
def __hash__(self):
return hash(type(self))
def make_node(self, x, y):
x, y = as_sparse_variable(x), tensor.as_tensor_variable(y)
assert y.type.ndim == 2
return gof.Apply(self,
[x, y],
[SparseType(dtype='uint8',
format=x.type.format).make_variable()])
def perform(self, node, (x, y), (out, )):
assert _is_sparse(x)
assert x.shape == y.shape
assert _is_dense(y)
out[0] = (x != y).astype('uint8')
def infer_shape(self, node, ins_shapes):
return [ins_shapes[0]]
def __str__(self):
return self.__class__.__name__
not_equal_s_d = NotEqualSD()
def notEqual(x, y): def notEqual(x, y):
""" """
Add two matrices, the two of which are sparse. Add two matrices, the two of which are sparse.
...@@ -2311,9 +2389,9 @@ def notEqual(x, y): ...@@ -2311,9 +2389,9 @@ def notEqual(x, y):
if x_is_sparse_variable and y_is_sparse_variable: if x_is_sparse_variable and y_is_sparse_variable:
return not_equal_s_s(x, y) return not_equal_s_s(x, y)
elif x_is_sparse_variable and not y_is_sparse_variable: elif x_is_sparse_variable and not y_is_sparse_variable:
raise NotImplementedError() return not_equal_s_d(x, y)
elif y_is_sparse_variable and not x_is_sparse_variable: elif y_is_sparse_variable and not x_is_sparse_variable:
raise NotImplementedError() return not_equal_s_d(y, x)
else: else:
raise NotImplementedError() raise NotImplementedError()
......
...@@ -651,6 +651,11 @@ class test_comparison(unittest.TestCase): ...@@ -651,6 +651,11 @@ class test_comparison(unittest.TestCase):
def setUp(self): def setUp(self):
utt.seed_rng() utt.seed_rng()
#took from tensor basic_test.py
def _rand_ranged(self, min, max, shape):
return numpy.asarray(numpy.random.rand(*shape) * (max - min) + min,
dtype=config.floatX)
def test_equalss_csr(self): def test_equalss_csr(self):
x = sparse.csr_matrix() x = sparse.csr_matrix()
y = sparse.csr_matrix() y = sparse.csr_matrix()
...@@ -703,6 +708,58 @@ class test_comparison(unittest.TestCase): ...@@ -703,6 +708,58 @@ class test_comparison(unittest.TestCase):
self.assertTrue(numpy.array_equal(f(m1, m2).data, (m1 != m2).data)) self.assertTrue(numpy.array_equal(f(m1, m2).data, (m1 != m2).data))
def test_equalsd_csr(self):
x = sparse.csr_matrix()
y = theano.tensor.matrix()
equality = equal(x, y)
f = theano.function([x, y], equality)
m1 = sp.csr_matrix(random_lil((10, 40), config.floatX, 3))
m2 = self._rand_ranged(1000, -1000, [10, 40])
self.assertTrue(numpy.array_equal(f(m1, m2).data, (m1 == m2).data))
def test_equalsd_csc(self):
x = sparse.csc_matrix()
y = theano.tensor.matrix()
equality = equal(x, y)
f = theano.function([x, y], equality)
m1 = sp.csc_matrix(random_lil((10, 40), config.floatX, 3))
m2 = self._rand_ranged(1000, -1000, [10, 40])
self.assertTrue(numpy.array_equal(f(m1, m2).data, (m1 == m2).data))
def test_not_equalsd_csr(self):
x = sparse.csr_matrix()
y = theano.tensor.matrix()
unequality = notEqual(x, y)
f = theano.function([x, y], unequality)
m1 = sp.csr_matrix(random_lil((10, 40), config.floatX, 3))
m2 = self._rand_ranged(1000, -1000, [10, 40])
self.assertTrue(numpy.array_equal(f(m1, m2).data, (m1 != m2).data))
def test_not_equalsd_csc(self):
x = sparse.csc_matrix()
y = theano.tensor.matrix()
unequality = notEqual(x, y)
f = theano.function([x, y], unequality)
m1 = sp.csc_matrix(random_lil((10, 40), config.floatX, 3))
m2 = self._rand_ranged(1000, -1000, [10, 40])
self.assertTrue(numpy.array_equal(f(m1, m2).data, (m1 != m2).data))
class T_conversion(unittest.TestCase): class T_conversion(unittest.TestCase):
def setUp(self): def setUp(self):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论