提交 0877fabf authored 作者: Frédéric Bastien's avatar Frédéric Bastien

Merge pull request #1851 from Hengjean/sparseComparisonOp

Added comparisons operators for sparse matrices
...@@ -153,6 +153,12 @@ List of Implemented Operations ...@@ -153,6 +153,12 @@ List of Implemented Operations
- Basic Arithmetic - Basic Arithmetic
- :class:`Neg <theano.sparse.basic.Neg>`. - :class:`Neg <theano.sparse.basic.Neg>`.
The grad implemented is regular. The grad implemented is regular.
- :func:`eq <theano.sparse.basic.eq>`.
- :func:`neq <theano.sparse.basic.neq>`.
- :func:`gt <theano.sparse.basic.gt>`.
- :func:`ge <theano.sparse.basic.ge>`.
- :func:`lt <theano.sparse.basic.lt>`.
- :func:`le <theano.sparse.basic.le>`.
- :func:`add <theano.sparse.basic.add>`. - :func:`add <theano.sparse.basic.add>`.
The grad implemented is regular. The grad implemented is regular.
- :func:`sub <theano.sparse.basic.sub>`. - :func:`sub <theano.sparse.basic.sub>`.
......
...@@ -284,6 +284,20 @@ class _sparse_py_operators: ...@@ -284,6 +284,20 @@ class _sparse_py_operators:
def __rmul__(left, right): def __rmul__(left, right):
return mul(left, right) return mul(left, right)
# comparison operators
def __lt__(self, other):
return lt(self, other)
def __le__(self, other):
return le(self, other)
def __gt__(self, other):
return gt(self, other)
def __ge__(self, other):
return ge(self, other)
# extra pseudo-operator symbols # extra pseudo-operator symbols
def __dot__(left, right): def __dot__(left, right):
...@@ -337,7 +351,7 @@ class _sparse_py_operators: ...@@ -337,7 +351,7 @@ class _sparse_py_operators:
return ret return ret
class SparseVariable(gof.Variable, _sparse_py_operators): class SparseVariable(_sparse_py_operators, gof.Variable):
dtype = property(lambda self: self.type.dtype) dtype = property(lambda self: self.type.dtype)
format = property(lambda self: self.type.format) format = property(lambda self: self.type.format)
...@@ -2178,6 +2192,371 @@ def mul(x, y): ...@@ -2178,6 +2192,371 @@ def mul(x, y):
raise NotImplementedError() raise NotImplementedError()
class __ComparisonOpSS(gof.op.Op):
"""
Used as a superclass for all comparisons between
two sparses matrices
:param x:first compared sparse matrix
:param y:second compared sparse matrix
:return: Comparison(x,y)
"""
#Function to override
def comparison(self, x, y):
raise NotImplementedError()
def __eq__(self, other):
return (type(self) == type(other))
def __hash__(self):
return hash(type(self))
def make_node(self, x, y):
x = as_sparse_variable(x)
y = as_sparse_variable(y)
if x.type.format != y.type.format:
raise NotImplementedError()
return gof.Apply(self,
[x, y],
[SparseType(dtype='uint8',
format=x.type.format).make_variable()])
def perform(self, node, (x, y), (out, )):
assert _is_sparse(x) and _is_sparse(y)
assert x.shape == y.shape
out[0] = self.comparison(x, y).astype('uint8')
def infer_shape(self, node, ins_shapes):
return [ins_shapes[0]]
def __str__(self):
return self.__class__.__name__
class __ComparisonOpSD(gof.op.Op):
"""
Used as a superclass for all comparisons between
sparse and dense matrix
:param x:sparse matrix
:param y:dense matrix
:return: Comparison(x,y)
"""
#Function to override
def comparison(self, x, y):
raise NotImplementedError()
def __eq__(self, other):
return (type(self) == type(other))
def __hash__(self):
return hash(type(self))
def make_node(self, x, y):
x, y = as_sparse_variable(x), tensor.as_tensor_variable(y)
assert y.type.ndim == 2
return gof.Apply(self,
[x, y],
[SparseType(dtype='uint8',
format=x.type.format).make_variable()])
def perform(self, node, (x, y), (out, )):
assert _is_sparse(x)
assert x.shape == y.shape
assert _is_dense(y)
out[0] = self.comparison(x, y).astype('uint8')
def infer_shape(self, node, ins_shapes):
return [ins_shapes[0]]
def __str__(self):
return self.__class__.__name__
def __ComparisonSwitch(SS, SD, DS):
"""
:param SS: function to apply between two sparses matrices.
:param SD: function to apply between a sparse and a dense matrix.
:param DS: function to apply between a dense and a sparse matrix.
:return: switch function taking two matrices as input
:note: At least one of `x` and `y` must be a sparse matrix.
:note: DS swap input as a dense matrix cannot be a left operand.
"""
def helper(x, y):
scipy_ver = [int(n) for n in scipy.__version__.split('.')[:2]]
assert scipy_ver >= [0, 13]
if hasattr(x, 'getnnz'):
x = as_sparse_variable(x)
if hasattr(y, 'getnnz'):
y = as_sparse_variable(y)
if not isinstance(x, theano.Variable):
x = theano.tensor.as_tensor_variable(x)
if not isinstance(y, theano.Variable):
y = theano.tensor.as_tensor_variable(y)
x_is_sparse_variable = _is_sparse_variable(x)
y_is_sparse_variable = _is_sparse_variable(y)
assert x_is_sparse_variable or y_is_sparse_variable
if x_is_sparse_variable and y_is_sparse_variable:
return SS(x, y)
elif x_is_sparse_variable and not y_is_sparse_variable:
return SD(x, y)
elif y_is_sparse_variable and not x_is_sparse_variable:
return DS(y, x)
else:
raise NotImplementedError()
return helper
class EqualSS(__ComparisonOpSS):
"""
:param x:first compared sparse matrix
:param y:second compared sparse matrix
:return: x==y
"""
def comparison(self, x, y):
return x == y
equal_s_s = EqualSS()
class EqualSD(__ComparisonOpSD):
"""
:param x:sparse matrix
:param y:dense matrix
:return: x==y
"""
def comparison(self, x, y):
return x == y
equal_s_d = EqualSD()
class NotEqualSS(__ComparisonOpSS):
"""
:param x:first compared sparse matrix
:param y:second compared sparse matrix
:return: x!=y
"""
def comparison(self, x, y):
return x != y
not_equal_s_s = NotEqualSS()
class NotEqualSD(__ComparisonOpSD):
"""
:param x:sparse matrix
:param y:dense matrix
:return: x!=y
"""
def comparison(self, x, y):
return x != y
not_equal_s_d = NotEqualSD()
class LessThanSS(__ComparisonOpSS):
"""
:param x:first compared sparse matrix
:param y:second compared sparse matrix
:return: x<y
"""
def comparison(self, x, y):
return x < y
less_than_s_s = LessThanSS()
class LessThanSD(__ComparisonOpSD):
"""
:param x:sparse matrix
:param y:dense matrix
:return: x<y
"""
def comparison(self, x, y):
return x < y
less_than_s_d = LessThanSD()
class GreaterThanSS(__ComparisonOpSS):
"""
:param x:first compared sparse matrix
:param y:second compared sparse matrix
:return: x>y
"""
def comparison(self, x, y):
return x > y
greater_than_s_s = GreaterThanSS()
class GreaterThanSD(__ComparisonOpSD):
"""
:param x:sparse matrix
:param y:dense matrix
:return: x>y
"""
def comparison(self, x, y):
return x > y
greater_than_s_d = GreaterThanSD()
class LessEqualSS(__ComparisonOpSS):
"""
:param x:first compared sparse matrix
:param y:second compared sparse matrix
:return: x<=y
"""
def comparison(self, x, y):
return x <= y
less_equal_s_s = LessEqualSS()
class LessEqualSD(__ComparisonOpSD):
"""
:param x:sparse matrix
:param y:dense matrix
:return: x<=y
"""
def comparison(self, x, y):
return x <= y
less_equal_s_d = LessEqualSD()
class GreaterEqualSS(__ComparisonOpSS):
"""
:param x:first compared sparse matrix
:param y:second compared sparse matrix
:return: x>=y
"""
def comparison(self, x, y):
return x >= y
greater_equal_s_s = GreaterEqualSS()
class GreaterEqualSD(__ComparisonOpSD):
"""
:param x:sparse matrix
:param y:dense matrix
:return: x>=y
"""
def comparison(self, x, y):
return x >= y
greater_equal_s_d = GreaterEqualSD()
"""
:param x: A matrix variable.
:param y: A matrix variable.
:return: `x` == `y`
:note: At least one of `x` and `y` must be a sparse matrix.
"""
eq = __ComparisonSwitch(equal_s_s, equal_s_d, equal_s_d)
"""
:param x: A matrix variable.
:param y: A matrix variable.
:return: `x` != `y`
:note: At least one of `x` and `y` must be a sparse matrix.
"""
neq = __ComparisonSwitch(not_equal_s_s, not_equal_s_d, not_equal_s_d)
"""
:param x: A matrix variable.
:param y: A matrix variable.
:return: `x` < `y`
:note: At least one of `x` and `y` must be a sparse matrix.
"""
lt = __ComparisonSwitch(less_than_s_s, less_than_s_d, greater_than_s_d)
"""
:param x: A matrix variable.
:param y: A matrix variable.
:return: `x` > `y`
:note: At least one of `x` and `y` must be a sparse matrix.
"""
gt = __ComparisonSwitch(greater_than_s_s, greater_than_s_d, less_than_s_d)
"""
:param x: A matrix variable.
:param y: A matrix variable.
:return: `x` <= `y`
:note: At least one of `x` and `y` must be a sparse matrix.
"""
le = __ComparisonSwitch(less_equal_s_s, less_equal_s_d, greater_equal_s_d)
"""
:param x: A matrix variable.
:param y: A matrix variable.
:return: `x` >= `y`
:note: At least one of `x` and `y` must be a sparse matrix.
"""
ge = __ComparisonSwitch(greater_equal_s_s, greater_equal_s_d,
less_equal_s_d)
class HStack(gof.op.Op): class HStack(gof.op.Op):
"""Stack sparse matrices horizontally (column wise). """Stack sparse matrices horizontally (column wise).
......
...@@ -40,7 +40,7 @@ from theano.sparse import ( ...@@ -40,7 +40,7 @@ from theano.sparse import (
Diag, diag, SquareDiagonal, square_diagonal, Diag, diag, SquareDiagonal, square_diagonal,
EnsureSortedIndices, ensure_sorted_indices, clean, EnsureSortedIndices, ensure_sorted_indices, clean,
ConstructSparseFromList, construct_sparse_from_list, ConstructSparseFromList, construct_sparse_from_list,
TrueDot, true_dot) TrueDot, true_dot, eq, neq, le, ge, gt, lt)
# Probability distributions are currently tested in test_sp2.py # Probability distributions are currently tested in test_sp2.py
#from theano.sparse import ( #from theano.sparse import (
...@@ -647,6 +647,329 @@ class T_AddMul(unittest.TestCase): ...@@ -647,6 +647,329 @@ class T_AddMul(unittest.TestCase):
verify_grad_sparse(op, [a, b], structured=False) verify_grad_sparse(op, [a, b], structured=False)
class test_comparison(unittest.TestCase):
def setUp(self):
utt.seed_rng()
#took from tensor basic_test.py
def _rand_ranged(self, min, max, shape):
return numpy.asarray(numpy.random.rand(*shape) * (max - min) + min,
dtype=config.floatX)
def __generalized_ss_test(self, theanop, symbolicType, testOp, scipyType):
scipy_ver = [int(n) for n in scipy.__version__.split('.')[:2]]
if (bool(scipy_ver < [0, 13])):
raise SkipTest("comparison operators need newer release of scipy")
x = symbolicType()
y = symbolicType()
op = theanop(x, y)
f = theano.function([x, y], op)
m1 = scipyType(random_lil((10, 40), config.floatX, 3))
m2 = scipyType(random_lil((10, 40), config.floatX, 3))
self.assertTrue(numpy.array_equal(f(m1, m2).data, testOp(m1, m2).data))
def __generalized_sd_test(self, theanop, symbolicType, testOp, scipyType):
scipy_ver = [int(n) for n in scipy.__version__.split('.')[:2]]
if (bool(scipy_ver < [0, 13])):
raise SkipTest("comparison operators need newer release of scipy")
x = symbolicType()
y = theano.tensor.matrix()
op = theanop(x, y)
f = theano.function([x, y], op)
m1 = scipyType(random_lil((10, 40), config.floatX, 3))
m2 = self._rand_ranged(1000, -1000, [10, 40])
self.assertTrue(numpy.array_equal(f(m1, m2).data, testOp(m1, m2).data))
def __generalized_ds_test(self, theanop, symbolicType, testOp, scipyType):
scipy_ver = [int(n) for n in scipy.__version__.split('.')[:2]]
if (bool(scipy_ver < [0, 13])):
raise SkipTest("comparison operators need newer release of scipy")
x = symbolicType()
y = theano.tensor.matrix()
op = theanop(y, x)
f = theano.function([y, x], op)
m1 = scipyType(random_lil((10, 40), config.floatX, 3))
m2 = self._rand_ranged(1000, -1000, [10, 40])
self.assertTrue(numpy.array_equal(f(m2, m1).data, testOp(m2, m1).data))
def test_equalss_csr(self):
self.__generalized_ss_test(eq, sparse.csr_matrix,
lambda x, y: x == y, sp.csr_matrix)
def test_equalss_csc(self):
self.__generalized_ss_test(eq, sparse.csc_matrix,
lambda x, y: x == y, sp.csc_matrix)
def test_not_equalss_csr(self):
self.__generalized_ss_test(neq, sparse.csr_matrix,
lambda x, y: x != y, sp.csr_matrix)
def test_not_equalss_csc(self):
self.__generalized_ss_test(neq, sparse.csc_matrix,
lambda x, y: x != y, sp.csc_matrix)
def test_less_equalss_csr(self):
opT = lambda x, y: x <= y
self.__generalized_ss_test(le, sparse.csr_matrix,
opT, sp.csr_matrix)
def test_less_equalss_csc(self):
opT = lambda x, y: x <= y
self.__generalized_ss_test(le, sparse.csc_matrix,
opT, sp.csc_matrix)
def test_less_thanss_csr(self):
opT = lambda x, y: x < y
self.__generalized_ss_test(opT, sparse.csr_matrix,
opT, sp.csr_matrix)
def test_less_thanss_csc(self):
opT = lambda x, y: x < y
self.__generalized_ss_test(opT, sparse.csc_matrix,
opT, sp.csc_matrix)
def test_greater_equalss_csr(self):
opT = lambda x, y: x >= y
self.__generalized_ss_test(opT, sparse.csr_matrix,
opT, sp.csr_matrix)
def test_greater_equalss_csc(self):
opT = lambda x, y: x >= y
self.__generalized_ss_test(opT, sparse.csc_matrix,
opT, sp.csc_matrix)
def test_greater_thanss_csr(self):
opT = lambda x, y: x > y
self.__generalized_ss_test(opT, sparse.csr_matrix,
opT, sp.csr_matrix)
def test_greater_thanss_csc(self):
opT = lambda x, y: x > y
self.__generalized_ss_test(opT, sparse.csc_matrix,
opT, sp.csc_matrix)
def test_equalsd_csr(self):
self.__generalized_sd_test(eq, sparse.csr_matrix,
lambda x, y: x == y, sp.csr_matrix)
def test_equalsd_csc(self):
self.__generalized_sd_test(eq, sparse.csc_matrix,
lambda x, y: x == y, sp.csc_matrix)
def test_not_equalsd_csr(self):
self.__generalized_sd_test(neq, sparse.csr_matrix,
lambda x, y: x != y, sp.csr_matrix)
def test_not_equalsd_csc(self):
self.__generalized_sd_test(neq, sparse.csc_matrix,
lambda x, y: x != y, sp.csc_matrix)
def test_less_equalsd_csr(self):
opT = lambda x, y: x <= y
self.__generalized_sd_test(le, sparse.csr_matrix,
opT, sp.csr_matrix)
def test_less_equalsd_csc(self):
opT = lambda x, y: x <= y
self.__generalized_sd_test(le, sparse.csc_matrix,
opT, sp.csc_matrix)
def test_less_thansd_csr(self):
opT = lambda x, y: x < y
self.__generalized_sd_test(opT, sparse.csr_matrix,
opT, sp.csr_matrix)
def test_less_thansd_csc(self):
opT = lambda x, y: x < y
self.__generalized_sd_test(opT, sparse.csc_matrix,
opT, sp.csc_matrix)
def test_greater_equalsd_csr(self):
opT = lambda x, y: x >= y
self.__generalized_sd_test(opT, sparse.csr_matrix,
opT, sp.csr_matrix)
def test_greater_equalsd_csc(self):
opT = lambda x, y: x >= y
self.__generalized_sd_test(opT, sparse.csc_matrix,
opT, sp.csc_matrix)
def test_greater_thansd_csr(self):
opT = lambda x, y: x > y
self.__generalized_sd_test(opT, sparse.csr_matrix,
opT, sp.csr_matrix)
def test_greater_thansd_csc(self):
opT = lambda x, y: x > y
self.__generalized_sd_test(opT, sparse.csc_matrix,
opT, sp.csc_matrix)
def test_equalds_csr(self):
self.__generalized_ds_test(eq, sparse.csr_matrix,
lambda x, y: x == y, sp.csr_matrix)
def test_equalds_csc(self):
self.__generalized_ds_test(eq, sparse.csc_matrix,
lambda x, y: x == y, sp.csc_matrix)
def test_not_equalds_csr(self):
self.__generalized_ds_test(neq, sparse.csr_matrix,
lambda x, y: x != y, sp.csr_matrix)
def test_not_equalds_csc(self):
self.__generalized_ds_test(neq, sparse.csc_matrix,
lambda x, y: x != y, sp.csc_matrix)
def test_less_equalds_csr(self):
opT = lambda x, y: x <= y
self.__generalized_ds_test(le, sparse.csr_matrix,
opT, sp.csr_matrix)
def test_less_equalds_csc(self):
opT = lambda x, y: x <= y
self.__generalized_ds_test(le, sparse.csc_matrix,
opT, sp.csc_matrix)
def test_less_thands_csr(self):
opT = lambda x, y: x < y
self.__generalized_ds_test(lt, sparse.csr_matrix,
opT, sp.csr_matrix)
def test_less_thands_csc(self):
opT = lambda x, y: x < y
self.__generalized_ds_test(lt, sparse.csc_matrix,
opT, sp.csc_matrix)
def test_greater_equalds_csr(self):
opT = lambda x, y: x >= y
self.__generalized_ds_test(ge, sparse.csr_matrix,
opT, sp.csr_matrix)
def test_greater_equalds_csc(self):
opT = lambda x, y: x >= y
self.__generalized_ds_test(ge, sparse.csc_matrix,
opT, sp.csc_matrix)
def test_greater_thands_csr(self):
opT = lambda x, y: x > y
self.__generalized_ds_test(gt, sparse.csr_matrix,
opT, sp.csr_matrix)
def test_greater_thands_csc(self):
opT = lambda x, y: x > y
self.__generalized_ds_test(gt, sparse.csc_matrix,
opT, sp.csc_matrix)
def test_equality_case(self):
"""
Test assuring normal behaviour when values
in the matrices are equal
"""
scipy_ver = [int(n) for n in scipy.__version__.split('.')[:2]]
if (bool(scipy_ver < [0, 13])):
raise SkipTest("comparison operators need newer release of scipy")
x = sparse.csc_matrix()
y = theano.tensor.matrix()
m1 = sp.csc_matrix((2, 2), dtype=theano.config.floatX)
m2 = numpy.asarray([[0, 0], [0, 0]])
test = {gt: lambda x, y: x > y, lt: lambda x, y: x < y,
ge: lambda x, y: x >= y, le: lambda x, y: x <= y}
for func in test:
op = func(y, x)
f = theano.function([y, x], op)
self.assertTrue(numpy.array_equal(f(m2, m1),
test[func](m2, m1)))
class T_conversion(unittest.TestCase): class T_conversion(unittest.TestCase):
def setUp(self): def setUp(self):
utt.seed_rng() utt.seed_rng()
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论