提交 e7204ea2 authored 作者: Tanjay94's avatar Tanjay94

Fixed norm test to include outdated numpy version.

上级 0c737bad
......@@ -4,6 +4,7 @@ import numpy
import numpy.linalg
from numpy.testing import assert_array_almost_equal
from numpy.testing import dec, assert_array_equal, assert_allclose
from nose.plugins.skip import SkipTest
import theano
from theano import tensor, function
......@@ -636,20 +637,42 @@ class Matrix_power():
self.assertRaises(ValueError, f, a)
class T_NormTests(unittest.TestCase):
try:
def test_wrong_type_of_ord_for_vector(self):
self.assertRaises(ValueError, norm, [2,1],'fro',0)
except TypeError:
raise SkipTest('Your numpy version is outdated.')
try:
def test_wrong_type_of_ord_for_vector_in_matrix(self):
self.assertRaises(ValueError, norm, [[2,1],[3,4]],'fro',0)
except TypeError:
raise SkipTest('Your numpy version is outdated.')
try:
def test_wrong_type_of_ord_for_vector_in_tensor(self):
self.assertRaises(ValueError, norm, [[[2,1],[3,4]],[[6,5],[7,8]]],'fro',0)
except TypeError:
raise SkipTest('Your numpy version is outdated.')
try:
def test_wrong_type_of_ord_for_matrix(self):
self.assertRaises(ValueError, norm, [[2,1],[3,4]],0,None)
except TypeError:
raise SkipTest('Your numpy version is outdated.')
try:
def test_wrong_type_of_ord_for_matrix_in_tensor(self):
self.assertRaises(ValueError, norm, [[[2,1],[3,4]],[[6,5],[7,8]]],0,None)
except TypeError:
raise SkipTest('Your numpy version is outdated.')
try:
def test_non_tensorial_input(self):
self.assertRaises(ValueError, norm, 3, None, None)
except TypeError:
raise SkipTest('Your numpy version is outdated.')
try:
def test_no_enough_dimensions(self):
self.assertRaises(ValueError, norm, [[2,1],[3,4]], None, 3)
except TypeError:
raise SkipTest('Your numpy version is outdated.')
try:
def test_numpy_compare(self):
rng = numpy.random.RandomState(utt.fetch_seed())
A = tensor.matrix("A", dtype=theano.config.floatX)
......@@ -660,6 +683,8 @@ class T_NormTests(unittest.TestCase):
n_n = numpy.linalg.norm(a, None, None)
t_n = fn(a)
assert _allclose(n_n, t_n)
except TypeError:
raise SkipTest('Your numpy version is outdated.')
class T_lstsq(unittest.TestCase):
def test_correct_solution(self):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论