提交 6459b949 authored 作者: Tanjay94's avatar Tanjay94

Fixed norm test for outdated numpy.

上级 af7ea687
...@@ -4,7 +4,6 @@ import numpy ...@@ -4,7 +4,6 @@ import numpy
import numpy.linalg import numpy.linalg
from numpy.testing import assert_array_almost_equal from numpy.testing import assert_array_almost_equal
from numpy.testing import dec, assert_array_equal, assert_allclose from numpy.testing import dec, assert_array_equal, assert_allclose
from nose.plugins.skip import SkipTest
import theano import theano
from theano import tensor, function from theano import tensor, function
...@@ -39,7 +38,7 @@ from theano.sandbox.linalg.ops import (cholesky, ...@@ -39,7 +38,7 @@ from theano.sandbox.linalg.ops import (cholesky,
from theano.sandbox.linalg import eig, eigh, eigvalsh from theano.sandbox.linalg import eig, eigh, eigvalsh
from nose.plugins.skip import SkipTest from nose.plugins.skip import SkipTest
from nose.plugins.attrib import attr from nose.plugins.attrib import attr
from nose.plugins.skip import SkipTest
def check_lower_triangular(pd, ch_f): def check_lower_triangular(pd, ch_f):
ch = ch_f(pd) ch = ch_f(pd)
...@@ -651,8 +650,8 @@ class T_NormTests(unittest.TestCase): ...@@ -651,8 +650,8 @@ class T_NormTests(unittest.TestCase):
self.assertRaises(ValueError, norm, 3, None, None) self.assertRaises(ValueError, norm, 3, None, None)
def test_no_enough_dimensions(self): def test_no_enough_dimensions(self):
self.assertRaises(ValueError, norm, [[2,1],[3,4]], None, 3) self.assertRaises(ValueError, norm, [[2,1],[3,4]], None, 3)
try: def test_numpy_compare(self):
def test_numpy_compare(self): try:
rng = numpy.random.RandomState(utt.fetch_seed()) rng = numpy.random.RandomState(utt.fetch_seed())
A = tensor.matrix("A", dtype=theano.config.floatX) A = tensor.matrix("A", dtype=theano.config.floatX)
Q = norm(A, None, None) Q = norm(A, None, None)
...@@ -662,8 +661,8 @@ class T_NormTests(unittest.TestCase): ...@@ -662,8 +661,8 @@ class T_NormTests(unittest.TestCase):
n_n = numpy.linalg.norm(a, None, None) n_n = numpy.linalg.norm(a, None, None)
t_n = fn(a) t_n = fn(a)
assert _allclose(n_n, t_n) assert _allclose(n_n, t_n)
except TypeError: except TypeError:
raise SkipTest('Your numpy version is outdated.') raise SkipTest('Your numpy version is outdated.')
class T_lstsq(unittest.TestCase): class T_lstsq(unittest.TestCase):
def test_correct_solution(self): def test_correct_solution(self):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论