提交 8a6bdc91 authored 作者: Hengjean's avatar Hengjean

Added equal and notEqual operation between sparse matrices

上级 daf7fc95
......@@ -284,6 +284,25 @@ class _sparse_py_operators:
def __rmul__(left, right):
return mul(left, right)
# comparison operators
def __lt__(self, other):
pass
def __le__(self, other):
pass
def __gt__(self, other):
pass
def __ge__(self, other):
pass
def __ne__(self, other):
pass
# extra pseudo-operator symbols
def __dot__(left, right):
......@@ -2149,6 +2168,156 @@ def mul(x, y):
raise NotImplementedError()
class EqualSS(gof.op.Op):
"""
:param x:first compared sparse matrix
:param y:second compared sparse matrix
:return: x==y
"""
def __eq__(self, other):
return (type(self) == type(other))
def __hash__(self):
return hash(type(self))
def make_node(self, x, y):
x = as_sparse_variable(x)
y = as_sparse_variable(y)
if x.type.format != y.type.format:
raise NotImplementedError()
return gof.Apply(self,
[x, y],
[SparseType(dtype='uint8',
format=x.type.format).make_variable()])
def perform(self, node, (x, y), (out, )):
assert _is_sparse(x) and _is_sparse(y)
assert x.shape == y.shape
out[0] = (x == y).astype('uint8')
def infer_shape(self, node, ins_shapes):
return [ins_shapes[0]]
def __str__(self):
return self.__class__.__name__
equal_s_s = EqualSS()
def equal(x, y):
"""
Add two matrices, the two of which are sparse.
:param x: A matrix variable.
:param y: A matrix variable.
:return: `x` == `y`
:note: At least one of `x` and `y` must be a sparse matrix.
"""
if hasattr(x, 'getnnz'):
x = as_sparse_variable(x)
if hasattr(y, 'getnnz'):
y = as_sparse_variable(y)
if not isinstance(x, theano.Variable):
x = theano.tensor.as_tensor_variable(x)
if not isinstance(y, theano.Variable):
y = theano.tensor.as_tensor_variable(y)
x_is_sparse_variable = _is_sparse_variable(x)
y_is_sparse_variable = _is_sparse_variable(y)
assert x_is_sparse_variable or y_is_sparse_variable
if x_is_sparse_variable and y_is_sparse_variable:
return equal_s_s(x, y)
elif x_is_sparse_variable and not y_is_sparse_variable:
raise NotImplementedError()
elif y_is_sparse_variable and not x_is_sparse_variable:
raise NotImplementedError()
else:
raise NotImplementedError()
class NotEqualSS(gof.op.Op):
"""
:param x:first compared sparse matrix
:param y:second compared sparse matrix
:return: x==y
"""
def __eq__(self, other):
return (type(self) == type(other))
def __hash__(self):
return hash(type(self))
def make_node(self, x, y):
x = as_sparse_variable(x)
y = as_sparse_variable(y)
if x.type.format != y.type.format:
raise NotImplementedError()
return gof.Apply(self,
[x, y],
[SparseType(dtype='uint8',
format=x.type.format).make_variable()])
def perform(self, node, (x, y), (out, )):
assert _is_sparse(x) and _is_sparse(y)
assert x.shape == y.shape
out[0] = (x != y).astype('uint8')
def infer_shape(self, node, ins_shapes):
return [ins_shapes[0]]
def __str__(self):
return self.__class__.__name__
not_equal_s_s = NotEqualSS()
def notEqual(x, y):
"""
Add two matrices, the two of which are sparse.
:param x: A matrix variable.
:param y: A matrix variable.
:return: `x` == `y`
:note: At least one of `x` and `y` must be a sparse matrix.
"""
if hasattr(x, 'getnnz'):
x = as_sparse_variable(x)
if hasattr(y, 'getnnz'):
y = as_sparse_variable(y)
if not isinstance(x, theano.Variable):
x = theano.tensor.as_tensor_variable(x)
if not isinstance(y, theano.Variable):
y = theano.tensor.as_tensor_variable(y)
x_is_sparse_variable = _is_sparse_variable(x)
y_is_sparse_variable = _is_sparse_variable(y)
assert x_is_sparse_variable or y_is_sparse_variable
if x_is_sparse_variable and y_is_sparse_variable:
return not_equal_s_s(x, y)
elif x_is_sparse_variable and not y_is_sparse_variable:
raise NotImplementedError()
elif y_is_sparse_variable and not x_is_sparse_variable:
raise NotImplementedError()
else:
raise NotImplementedError()
class HStack(gof.op.Op):
"""Stack sparse matrices horizontally (column wise).
......
......@@ -40,7 +40,7 @@ from theano.sparse import (
Diag, diag, SquareDiagonal, square_diagonal,
EnsureSortedIndices, ensure_sorted_indices, clean,
ConstructSparseFromList, construct_sparse_from_list,
TrueDot, true_dot)
TrueDot, true_dot, equal, notEqual)
# Probability distributions are currently tested in test_sp2.py
#from theano.sparse import (
......@@ -647,6 +647,63 @@ class T_AddMul(unittest.TestCase):
verify_grad_sparse(op, [a, b], structured=False)
class test_comparison(unittest.TestCase):
def setUp(self):
utt.seed_rng()
def test_equalss_csr(self):
x = sparse.csr_matrix()
y = sparse.csr_matrix()
equality = equal(x, y)
f = theano.function([x, y], equality)
m1 = sp.csr_matrix(random_lil((10, 40), config.floatX, 3))
m2 = sp.csr_matrix(random_lil((10, 40), config.floatX, 3))
self.assertTrue(numpy.array_equal(f(m1, m2).data, (m1 == m2).data))
def test_equalss_csc(self):
x = sparse.csc_matrix()
y = sparse.csc_matrix()
equality = equal(x, y)
f = theano.function([x, y], equality)
m1 = sp.csc_matrix(random_lil((10, 40), config.floatX, 3))
m2 = sp.csc_matrix(random_lil((10, 40), config.floatX, 3))
self.assertTrue(numpy.array_equal(f(m1, m2).data, (m1 == m2).data))
def test_not_equalss_csr(self):
x = sparse.csr_matrix()
y = sparse.csr_matrix()
unequality = notEqual(x, y)
f = theano.function([x, y], unequality)
m1 = sp.csr_matrix(random_lil((10, 40), config.floatX, 3))
m2 = sp.csr_matrix(random_lil((10, 40), config.floatX, 3))
self.assertTrue(numpy.array_equal(f(m1, m2).data, (m1 != m2).data))
def test_not_equalss_csc(self):
x = sparse.csc_matrix()
y = sparse.csc_matrix()
unequality = notEqual(x, y)
f = theano.function([x, y], unequality)
m1 = sp.csc_matrix(random_lil((10, 40), config.floatX, 3))
m2 = sp.csc_matrix(random_lil((10, 40), config.floatX, 3))
self.assertTrue(numpy.array_equal(f(m1, m2).data, (m1 != m2).data))
class T_conversion(unittest.TestCase):
def setUp(self):
utt.seed_rng()
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论