提交 6c9e2e80 authored 作者: Hengjean's avatar Hengjean

made fixes and tweaks

上级 16dc7e64
...@@ -2164,7 +2164,6 @@ def mul(x, y): ...@@ -2164,7 +2164,6 @@ def mul(x, y):
class __ComparisonOpSS(gof.op.Op): class __ComparisonOpSS(gof.op.Op):
""" """
Used as a superclass for all comparisons between Used as a superclass for all comparisons between
two sparses matrices two sparses matrices
...@@ -2177,7 +2176,7 @@ class __ComparisonOpSS(gof.op.Op): ...@@ -2177,7 +2176,7 @@ class __ComparisonOpSS(gof.op.Op):
#Function to override #Function to override
def comparison(self, x, y): def comparison(self, x, y):
return x raise NotImplementedError()
def __eq__(self, other): def __eq__(self, other):
return (type(self) == type(other)) return (type(self) == type(other))
...@@ -2209,7 +2208,6 @@ class __ComparisonOpSS(gof.op.Op): ...@@ -2209,7 +2208,6 @@ class __ComparisonOpSS(gof.op.Op):
class __ComparisonOpSD(gof.op.Op): class __ComparisonOpSD(gof.op.Op):
""" """
Used as a superclass for all comparisons between Used as a superclass for all comparisons between
sparse and dense matrix sparse and dense matrix
...@@ -2222,7 +2220,7 @@ class __ComparisonOpSD(gof.op.Op): ...@@ -2222,7 +2220,7 @@ class __ComparisonOpSD(gof.op.Op):
#Function to override #Function to override
def comparison(self, x, y): def comparison(self, x, y):
return x raise NotImplementedError()
def __eq__(self, other): def __eq__(self, other):
return (type(self) == type(other)) return (type(self) == type(other))
...@@ -2261,6 +2259,7 @@ def __ComparisonSwitch(SS, SD, DS): ...@@ -2261,6 +2259,7 @@ def __ComparisonSwitch(SS, SD, DS):
:return: switch function taking two matrices as input :return: switch function taking two matrices as input
:note: At least one of `x` and `y` must be a sparse matrix. :note: At least one of `x` and `y` must be a sparse matrix.
:note: DS swap input as a dense matrix cannot be a left operand.
""" """
def helper(x, y): def helper(x, y):
...@@ -2295,7 +2294,6 @@ def __ComparisonSwitch(SS, SD, DS): ...@@ -2295,7 +2294,6 @@ def __ComparisonSwitch(SS, SD, DS):
class EqualSS(__ComparisonOpSS): class EqualSS(__ComparisonOpSS):
""" """
:param x:first compared sparse matrix :param x:first compared sparse matrix
:param y:second compared sparse matrix :param y:second compared sparse matrix
...@@ -2311,7 +2309,6 @@ equal_s_s = EqualSS() ...@@ -2311,7 +2309,6 @@ equal_s_s = EqualSS()
class EqualSD(__ComparisonOpSD): class EqualSD(__ComparisonOpSD):
""" """
:param x:sparse matrix :param x:sparse matrix
:param y:dense matrix :param y:dense matrix
...@@ -2325,22 +2322,7 @@ class EqualSD(__ComparisonOpSD): ...@@ -2325,22 +2322,7 @@ class EqualSD(__ComparisonOpSD):
equal_s_d = EqualSD() equal_s_d = EqualSD()
def eq(x, y):
"""
:param x: A matrix variable.
:param y: A matrix variable.
:return: `x` == `y`
:note: At least one of `x` and `y` must be a sparse matrix.
"""
fE = __ComparisonSwitch(equal_s_s, equal_s_d, equal_s_d)
return fE(x, y)
class NotEqualSS(__ComparisonOpSS): class NotEqualSS(__ComparisonOpSS):
""" """
:param x:first compared sparse matrix :param x:first compared sparse matrix
:param y:second compared sparse matrix :param y:second compared sparse matrix
...@@ -2355,7 +2337,6 @@ not_equal_s_s = NotEqualSS() ...@@ -2355,7 +2337,6 @@ not_equal_s_s = NotEqualSS()
class NotEqualSD(__ComparisonOpSD): class NotEqualSD(__ComparisonOpSD):
""" """
:param x:sparse matrix :param x:sparse matrix
:param y:dense matrix :param y:dense matrix
...@@ -2369,22 +2350,7 @@ class NotEqualSD(__ComparisonOpSD): ...@@ -2369,22 +2350,7 @@ class NotEqualSD(__ComparisonOpSD):
not_equal_s_d = NotEqualSD() not_equal_s_d = NotEqualSD()
def neq(x, y):
"""
:param x: A matrix variable.
:param y: A matrix variable.
:return: `x` == `y`
:note: At least one of `x` and `y` must be a sparse matrix.
"""
fNE = __ComparisonSwitch(not_equal_s_s, not_equal_s_d, not_equal_s_d)
return fNE(x, y)
class LessThanSS(__ComparisonOpSS): class LessThanSS(__ComparisonOpSS):
""" """
:param x:first compared sparse matrix :param x:first compared sparse matrix
:param y:second compared sparse matrix :param y:second compared sparse matrix
...@@ -2399,7 +2365,6 @@ less_than_s_s = LessThanSS() ...@@ -2399,7 +2365,6 @@ less_than_s_s = LessThanSS()
class LessThanSD(__ComparisonOpSD): class LessThanSD(__ComparisonOpSD):
""" """
:param x:sparse matrix :param x:sparse matrix
:param y:dense matrix :param y:dense matrix
...@@ -2413,20 +2378,6 @@ class LessThanSD(__ComparisonOpSD): ...@@ -2413,20 +2378,6 @@ class LessThanSD(__ComparisonOpSD):
less_than_s_d = LessThanSD() less_than_s_d = LessThanSD()
def lt(x, y):
"""
:param x: A matrix variable.
:param y: A matrix variable.
:return: `x` < `y`
:note: At least one of `x` and `y` must be a sparse matrix.
"""
fL = __ComparisonSwitch(less_than_s_s, less_than_s_d, greater_than_s_d)
return fL(x, y)
class GreaterThanSS(__ComparisonOpSS): class GreaterThanSS(__ComparisonOpSS):
""" """
...@@ -2457,27 +2408,13 @@ class GreaterThanSD(__ComparisonOpSD): ...@@ -2457,27 +2408,13 @@ class GreaterThanSD(__ComparisonOpSD):
greater_than_s_d = GreaterThanSD() greater_than_s_d = GreaterThanSD()
def gt(x, y):
"""
:param x: A matrix variable.
:param y: A matrix variable.
:return: `x` > `y`
:note: At least one of `x` and `y` must be a sparse matrix.
"""
fG = __ComparisonSwitch(greater_than_s_s, greater_than_s_d, less_than_s_d)
return fG(x, y)
class LessEqualSS(__ComparisonOpSS): class LessEqualSS(__ComparisonOpSS):
""" """
:param x:first compared sparse matrix :param x:first compared sparse matrix
:param y:second compared sparse matrix :param y:second compared sparse matrix
:return: x>y :return: x<=y
""" """
def comparison(self, x, y): def comparison(self, x, y):
...@@ -2492,7 +2429,7 @@ class LessEqualSD(__ComparisonOpSD): ...@@ -2492,7 +2429,7 @@ class LessEqualSD(__ComparisonOpSD):
:param x:sparse matrix :param x:sparse matrix
:param y:dense matrix :param y:dense matrix
:return: x>y :return: x<=y
""" """
def comparison(self, x, y): def comparison(self, x, y):
...@@ -2501,27 +2438,13 @@ class LessEqualSD(__ComparisonOpSD): ...@@ -2501,27 +2438,13 @@ class LessEqualSD(__ComparisonOpSD):
less_equal_s_d = LessEqualSD() less_equal_s_d = LessEqualSD()
def le(x, y):
"""
:param x: A matrix variable.
:param y: A matrix variable.
:return: `x` >= `y`
:note: At least one of `x` and `y` must be a sparse matrix.
"""
fLE = __ComparisonSwitch(less_equal_s_s, less_equal_s_d, greater_equal_s_d)
return fLE(x, y)
class GreaterEqualSS(__ComparisonOpSS): class GreaterEqualSS(__ComparisonOpSS):
""" """
:param x:first compared sparse matrix :param x:first compared sparse matrix
:param y:second compared sparse matrix :param y:second compared sparse matrix
:return: x>y :return: x>=y
""" """
def comparison(self, x, y): def comparison(self, x, y):
...@@ -2536,7 +2459,7 @@ class GreaterEqualSD(__ComparisonOpSD): ...@@ -2536,7 +2459,7 @@ class GreaterEqualSD(__ComparisonOpSD):
:param x:sparse matrix :param x:sparse matrix
:param y:dense matrix :param y:dense matrix
:return: x>y :return: x>=y
""" """
def comparison(self, x, y): def comparison(self, x, y):
...@@ -2544,20 +2467,71 @@ class GreaterEqualSD(__ComparisonOpSD): ...@@ -2544,20 +2467,71 @@ class GreaterEqualSD(__ComparisonOpSD):
greater_equal_s_d = GreaterEqualSD() greater_equal_s_d = GreaterEqualSD()
"""
:param x: A matrix variable.
:param y: A matrix variable.
def ge(x, y): :return: `x` == `y`
"""
:param x: A matrix variable.
:param y: A matrix variable.
:return: `x` >= `y` :note: At least one of `x` and `y` must be a sparse matrix.
"""
eq = __ComparisonSwitch(equal_s_s, equal_s_d, equal_s_d)
:note: At least one of `x` and `y` must be a sparse matrix.
"""
fGE = __ComparisonSwitch(greater_equal_s_s, greater_equal_s_d, """
:param x: A matrix variable.
:param y: A matrix variable.
:return: `x` != `y`
:note: At least one of `x` and `y` must be a sparse matrix.
"""
neq = __ComparisonSwitch(not_equal_s_s, not_equal_s_d, not_equal_s_d)
"""
:param x: A matrix variable.
:param y: A matrix variable.
:return: `x` < `y`
:note: At least one of `x` and `y` must be a sparse matrix.
"""
lt = __ComparisonSwitch(less_than_s_s, less_than_s_d, greater_than_s_d)
"""
:param x: A matrix variable.
:param y: A matrix variable.
:return: `x` > `y`
:note: At least one of `x` and `y` must be a sparse matrix.
"""
gt = __ComparisonSwitch(greater_than_s_s, greater_than_s_d, less_than_s_d)
"""
:param x: A matrix variable.
:param y: A matrix variable.
:return: `x` <= `y`
:note: At least one of `x` and `y` must be a sparse matrix.
"""
le = __ComparisonSwitch(less_equal_s_s, less_equal_s_d, greater_equal_s_d)
"""
:param x: A matrix variable.
:param y: A matrix variable.
:return: `x` >= `y`
:note: At least one of `x` and `y` must be a sparse matrix.
"""
ge = __ComparisonSwitch(greater_equal_s_s, greater_equal_s_d,
less_equal_s_d) less_equal_s_d)
return fGE(x, y)
class HStack(gof.op.Op): class HStack(gof.op.Op):
......
...@@ -941,6 +941,34 @@ class test_comparison(unittest.TestCase): ...@@ -941,6 +941,34 @@ class test_comparison(unittest.TestCase):
self.__generalized_ds_test(gt, sparse.csc_matrix, self.__generalized_ds_test(gt, sparse.csc_matrix,
opT, sp.csc_matrix) opT, sp.csc_matrix)
def test_equality_case(self):
"""
Test assuring normal behaviour when values
in the matrices are equal
"""
scipy_ver = [int(n) for n in scipy.__version__.split('.')[:2]]
if (bool(scipy_ver < [0, 14])):
raise SkipTest("comparison operators need newer release of scipy")
x = sparse.csc_matrix()
y = theano.tensor.matrix()
m1 = sp.csc_matrix((2, 2), dtype=theano.config.floatX)
m2 = numpy.asarray([[0, 0], [0, 0]])
test = {gt: lambda x, y: x > y, lt: lambda x, y: x < y,
ge: lambda x, y: x >= y, le: lambda x, y: x <= y}
for func in test:
op = func(y, x)
f = theano.function([y, x], op)
self.assertTrue(numpy.array_equal(f(m2, m1),
test[func](m2, m1)))
class T_conversion(unittest.TestCase): class T_conversion(unittest.TestCase):
def setUp(self): def setUp(self):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论