提交 6c9e2e80 authored 作者: Hengjean's avatar Hengjean

made fixes and tweaks

上级 16dc7e64
......@@ -2164,7 +2164,6 @@ def mul(x, y):
class __ComparisonOpSS(gof.op.Op):
"""
Used as a superclass for all comparisons between
two sparses matrices
......@@ -2177,7 +2176,7 @@ class __ComparisonOpSS(gof.op.Op):
#Function to override
def comparison(self, x, y):
return x
raise NotImplementedError()
def __eq__(self, other):
return (type(self) == type(other))
......@@ -2209,7 +2208,6 @@ class __ComparisonOpSS(gof.op.Op):
class __ComparisonOpSD(gof.op.Op):
"""
Used as a superclass for all comparisons between
sparse and dense matrix
......@@ -2222,7 +2220,7 @@ class __ComparisonOpSD(gof.op.Op):
#Function to override
def comparison(self, x, y):
return x
raise NotImplementedError()
def __eq__(self, other):
return (type(self) == type(other))
......@@ -2261,6 +2259,7 @@ def __ComparisonSwitch(SS, SD, DS):
:return: switch function taking two matrices as input
:note: At least one of `x` and `y` must be a sparse matrix.
:note: DS swap input as a dense matrix cannot be a left operand.
"""
def helper(x, y):
......@@ -2295,7 +2294,6 @@ def __ComparisonSwitch(SS, SD, DS):
class EqualSS(__ComparisonOpSS):
"""
:param x:first compared sparse matrix
:param y:second compared sparse matrix
......@@ -2311,7 +2309,6 @@ equal_s_s = EqualSS()
class EqualSD(__ComparisonOpSD):
"""
:param x:sparse matrix
:param y:dense matrix
......@@ -2325,22 +2322,7 @@ class EqualSD(__ComparisonOpSD):
equal_s_d = EqualSD()
def eq(x, y):
"""
:param x: A matrix variable.
:param y: A matrix variable.
:return: `x` == `y`
:note: At least one of `x` and `y` must be a sparse matrix.
"""
fE = __ComparisonSwitch(equal_s_s, equal_s_d, equal_s_d)
return fE(x, y)
class NotEqualSS(__ComparisonOpSS):
"""
:param x:first compared sparse matrix
:param y:second compared sparse matrix
......@@ -2355,7 +2337,6 @@ not_equal_s_s = NotEqualSS()
class NotEqualSD(__ComparisonOpSD):
"""
:param x:sparse matrix
:param y:dense matrix
......@@ -2369,22 +2350,7 @@ class NotEqualSD(__ComparisonOpSD):
not_equal_s_d = NotEqualSD()
def neq(x, y):
"""
:param x: A matrix variable.
:param y: A matrix variable.
:return: `x` == `y`
:note: At least one of `x` and `y` must be a sparse matrix.
"""
fNE = __ComparisonSwitch(not_equal_s_s, not_equal_s_d, not_equal_s_d)
return fNE(x, y)
class LessThanSS(__ComparisonOpSS):
"""
:param x:first compared sparse matrix
:param y:second compared sparse matrix
......@@ -2399,7 +2365,6 @@ less_than_s_s = LessThanSS()
class LessThanSD(__ComparisonOpSD):
"""
:param x:sparse matrix
:param y:dense matrix
......@@ -2413,20 +2378,6 @@ class LessThanSD(__ComparisonOpSD):
less_than_s_d = LessThanSD()
def lt(x, y):
"""
:param x: A matrix variable.
:param y: A matrix variable.
:return: `x` < `y`
:note: At least one of `x` and `y` must be a sparse matrix.
"""
fL = __ComparisonSwitch(less_than_s_s, less_than_s_d, greater_than_s_d)
return fL(x, y)
class GreaterThanSS(__ComparisonOpSS):
"""
......@@ -2457,27 +2408,13 @@ class GreaterThanSD(__ComparisonOpSD):
greater_than_s_d = GreaterThanSD()
def gt(x, y):
"""
:param x: A matrix variable.
:param y: A matrix variable.
:return: `x` > `y`
:note: At least one of `x` and `y` must be a sparse matrix.
"""
fG = __ComparisonSwitch(greater_than_s_s, greater_than_s_d, less_than_s_d)
return fG(x, y)
class LessEqualSS(__ComparisonOpSS):
"""
:param x:first compared sparse matrix
:param y:second compared sparse matrix
:return: x>y
:return: x<=y
"""
def comparison(self, x, y):
......@@ -2492,7 +2429,7 @@ class LessEqualSD(__ComparisonOpSD):
:param x:sparse matrix
:param y:dense matrix
:return: x>y
:return: x<=y
"""
def comparison(self, x, y):
......@@ -2501,27 +2438,13 @@ class LessEqualSD(__ComparisonOpSD):
less_equal_s_d = LessEqualSD()
def le(x, y):
"""
:param x: A matrix variable.
:param y: A matrix variable.
:return: `x` >= `y`
:note: At least one of `x` and `y` must be a sparse matrix.
"""
fLE = __ComparisonSwitch(less_equal_s_s, less_equal_s_d, greater_equal_s_d)
return fLE(x, y)
class GreaterEqualSS(__ComparisonOpSS):
"""
:param x:first compared sparse matrix
:param y:second compared sparse matrix
:return: x>y
:return: x>=y
"""
def comparison(self, x, y):
......@@ -2536,7 +2459,7 @@ class GreaterEqualSD(__ComparisonOpSD):
:param x:sparse matrix
:param y:dense matrix
:return: x>y
:return: x>=y
"""
def comparison(self, x, y):
......@@ -2544,20 +2467,71 @@ class GreaterEqualSD(__ComparisonOpSD):
greater_equal_s_d = GreaterEqualSD()
"""
:param x: A matrix variable.
:param y: A matrix variable.
def ge(x, y):
"""
:param x: A matrix variable.
:param y: A matrix variable.
:return: `x` == `y`
:return: `x` >= `y`
:note: At least one of `x` and `y` must be a sparse matrix.
"""
eq = __ComparisonSwitch(equal_s_s, equal_s_d, equal_s_d)
:note: At least one of `x` and `y` must be a sparse matrix.
"""
fGE = __ComparisonSwitch(greater_equal_s_s, greater_equal_s_d,
"""
:param x: A matrix variable.
:param y: A matrix variable.
:return: `x` != `y`
:note: At least one of `x` and `y` must be a sparse matrix.
"""
neq = __ComparisonSwitch(not_equal_s_s, not_equal_s_d, not_equal_s_d)
"""
:param x: A matrix variable.
:param y: A matrix variable.
:return: `x` < `y`
:note: At least one of `x` and `y` must be a sparse matrix.
"""
lt = __ComparisonSwitch(less_than_s_s, less_than_s_d, greater_than_s_d)
"""
:param x: A matrix variable.
:param y: A matrix variable.
:return: `x` > `y`
:note: At least one of `x` and `y` must be a sparse matrix.
"""
gt = __ComparisonSwitch(greater_than_s_s, greater_than_s_d, less_than_s_d)
"""
:param x: A matrix variable.
:param y: A matrix variable.
:return: `x` <= `y`
:note: At least one of `x` and `y` must be a sparse matrix.
"""
le = __ComparisonSwitch(less_equal_s_s, less_equal_s_d, greater_equal_s_d)
"""
:param x: A matrix variable.
:param y: A matrix variable.
:return: `x` >= `y`
:note: At least one of `x` and `y` must be a sparse matrix.
"""
ge = __ComparisonSwitch(greater_equal_s_s, greater_equal_s_d,
less_equal_s_d)
return fGE(x, y)
class HStack(gof.op.Op):
......
......@@ -941,6 +941,34 @@ class test_comparison(unittest.TestCase):
self.__generalized_ds_test(gt, sparse.csc_matrix,
opT, sp.csc_matrix)
def test_equality_case(self):
"""
Test assuring normal behaviour when values
in the matrices are equal
"""
scipy_ver = [int(n) for n in scipy.__version__.split('.')[:2]]
if (bool(scipy_ver < [0, 14])):
raise SkipTest("comparison operators need newer release of scipy")
x = sparse.csc_matrix()
y = theano.tensor.matrix()
m1 = sp.csc_matrix((2, 2), dtype=theano.config.floatX)
m2 = numpy.asarray([[0, 0], [0, 0]])
test = {gt: lambda x, y: x > y, lt: lambda x, y: x < y,
ge: lambda x, y: x >= y, le: lambda x, y: x <= y}
for func in test:
op = func(y, x)
f = theano.function([y, x], op)
self.assertTrue(numpy.array_equal(f(m2, m1),
test[func](m2, m1)))
class T_conversion(unittest.TestCase):
def setUp(self):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论