提交 77629a27 authored 作者: Hengjean's avatar Hengjean

Added scipy version check

上级 cb90af99
...@@ -2259,6 +2259,8 @@ def equal(x, y): ...@@ -2259,6 +2259,8 @@ def equal(x, y):
:note: At least one of `x` and `y` must be a sparse matrix. :note: At least one of `x` and `y` must be a sparse matrix.
""" """
assert scipy.__version__ >= '0.14.0'
if hasattr(x, 'getnnz'): if hasattr(x, 'getnnz'):
x = as_sparse_variable(x) x = as_sparse_variable(x)
if hasattr(y, 'getnnz'): if hasattr(y, 'getnnz'):
......
...@@ -657,6 +657,10 @@ class test_comparison(unittest.TestCase): ...@@ -657,6 +657,10 @@ class test_comparison(unittest.TestCase):
dtype=config.floatX) dtype=config.floatX)
def test_equalss_csr(self): def test_equalss_csr(self):
if (scipy.__version__ < '0.14.0'):
raise SkipTest("comparison operators need newer release of scipy")
x = sparse.csr_matrix() x = sparse.csr_matrix()
y = sparse.csr_matrix() y = sparse.csr_matrix()
...@@ -670,6 +674,10 @@ class test_comparison(unittest.TestCase): ...@@ -670,6 +674,10 @@ class test_comparison(unittest.TestCase):
self.assertTrue(numpy.array_equal(f(m1, m2).data, (m1 == m2).data)) self.assertTrue(numpy.array_equal(f(m1, m2).data, (m1 == m2).data))
def test_equalss_csc(self): def test_equalss_csc(self):
if (scipy.__version__ < '0.14.0'):
raise SkipTest("comparison operators need newer release of scipy")
x = sparse.csc_matrix() x = sparse.csc_matrix()
y = sparse.csc_matrix() y = sparse.csc_matrix()
...@@ -683,6 +691,10 @@ class test_comparison(unittest.TestCase): ...@@ -683,6 +691,10 @@ class test_comparison(unittest.TestCase):
self.assertTrue(numpy.array_equal(f(m1, m2).data, (m1 == m2).data)) self.assertTrue(numpy.array_equal(f(m1, m2).data, (m1 == m2).data))
def test_not_equalss_csr(self): def test_not_equalss_csr(self):
if (scipy.__version__ < '0.14.0'):
raise SkipTest("comparison operators need newer release of scipy")
x = sparse.csr_matrix() x = sparse.csr_matrix()
y = sparse.csr_matrix() y = sparse.csr_matrix()
...@@ -696,6 +708,10 @@ class test_comparison(unittest.TestCase): ...@@ -696,6 +708,10 @@ class test_comparison(unittest.TestCase):
self.assertTrue(numpy.array_equal(f(m1, m2).data, (m1 != m2).data)) self.assertTrue(numpy.array_equal(f(m1, m2).data, (m1 != m2).data))
def test_not_equalss_csc(self): def test_not_equalss_csc(self):
if (scipy.__version__ < '0.14.0'):
raise SkipTest("comparison operators need newer release of scipy")
x = sparse.csc_matrix() x = sparse.csc_matrix()
y = sparse.csc_matrix() y = sparse.csc_matrix()
...@@ -709,6 +725,10 @@ class test_comparison(unittest.TestCase): ...@@ -709,6 +725,10 @@ class test_comparison(unittest.TestCase):
self.assertTrue(numpy.array_equal(f(m1, m2).data, (m1 != m2).data)) self.assertTrue(numpy.array_equal(f(m1, m2).data, (m1 != m2).data))
def test_equalsd_csr(self): def test_equalsd_csr(self):
if (scipy.__version__ < '0.14.0'):
raise SkipTest("comparison operators need newer release of scipy")
x = sparse.csr_matrix() x = sparse.csr_matrix()
y = theano.tensor.matrix() y = theano.tensor.matrix()
...@@ -722,6 +742,10 @@ class test_comparison(unittest.TestCase): ...@@ -722,6 +742,10 @@ class test_comparison(unittest.TestCase):
self.assertTrue(numpy.array_equal(f(m1, m2).data, (m1 == m2).data)) self.assertTrue(numpy.array_equal(f(m1, m2).data, (m1 == m2).data))
def test_equalsd_csc(self): def test_equalsd_csc(self):
if (scipy.__version__ < '0.14.0'):
raise SkipTest("comparison operators need newer release of scipy")
x = sparse.csc_matrix() x = sparse.csc_matrix()
y = theano.tensor.matrix() y = theano.tensor.matrix()
...@@ -735,6 +759,10 @@ class test_comparison(unittest.TestCase): ...@@ -735,6 +759,10 @@ class test_comparison(unittest.TestCase):
self.assertTrue(numpy.array_equal(f(m1, m2).data, (m1 == m2).data)) self.assertTrue(numpy.array_equal(f(m1, m2).data, (m1 == m2).data))
def test_not_equalsd_csr(self): def test_not_equalsd_csr(self):
if (scipy.__version__ < '0.14.0'):
raise SkipTest("comparison operators need newer release of scipy")
x = sparse.csr_matrix() x = sparse.csr_matrix()
y = theano.tensor.matrix() y = theano.tensor.matrix()
...@@ -748,6 +776,10 @@ class test_comparison(unittest.TestCase): ...@@ -748,6 +776,10 @@ class test_comparison(unittest.TestCase):
self.assertTrue(numpy.array_equal(f(m1, m2).data, (m1 != m2).data)) self.assertTrue(numpy.array_equal(f(m1, m2).data, (m1 != m2).data))
def test_not_equalsd_csc(self): def test_not_equalsd_csc(self):
if (scipy.__version__ < '0.14.0'):
raise SkipTest("comparison operators need newer release of scipy")
x = sparse.csc_matrix() x = sparse.csc_matrix()
y = theano.tensor.matrix() y = theano.tensor.matrix()
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论