提交 625e75cd authored 作者: Alexander Matyasko's avatar Alexander Matyasko

Update tests for numpy cpu svd

上级 b3e5fa52
...@@ -21,7 +21,7 @@ from theano.tensor.nlinalg import ( ...@@ -21,7 +21,7 @@ from theano.tensor.nlinalg import (
AllocDiag, alloc_diag, ExtractDiag, extract_diag, diag, AllocDiag, alloc_diag, ExtractDiag, extract_diag, diag,
trace, Det, det, Eig, eig, Eigh, EighGrad, eigh, trace, Det, det, Eig, eig, Eigh, EighGrad, eigh,
matrix_dot, _zero_disconnected, qr, matrix_power, matrix_dot, _zero_disconnected, qr, matrix_power,
norm, svd, TensorInv, tensorinv, tensorsolve) norm, svd, SVD, TensorInv, tensorinv, tensorsolve)
from nose.plugins.attrib import attr from nose.plugins.attrib import attr
from nose.plugins.skip import SkipTest from nose.plugins.skip import SkipTest
...@@ -129,45 +129,47 @@ def test_qr_modes(): ...@@ -129,45 +129,47 @@ def test_qr_modes():
assert "name 'complete' is not defined" in str(e) assert "name 'complete' is not defined" in str(e)
def test_svd(): class test_SVD(utt.InferShapeTester):
rng = np.random.RandomState(utt.fetch_seed()) op_class = SVD
A = tensor.matrix("A", dtype=theano.config.floatX) dtype = 'float32'
U, S, VT = svd(A)
fn = function([A], [U, S, VT])
a = rng.rand(4, 4).astype(theano.config.floatX)
n_u, n_s, n_vt = np.linalg.svd(a)
t_u, t_s, t_vt = fn(a)
assert _allclose(n_u, t_u)
assert _allclose(n_s, t_s)
assert _allclose(n_vt, t_vt)
fn = function([A], svd(A, compute_uv=False))
t_s = fn(a)
assert _allclose(n_s, t_s)
def test_svd_infer_shape(): def setUp(self):
rng = np.random.RandomState(utt.fetch_seed()) super(test_SVD, self).setUp()
A = tensor.matrix("A", dtype=theano.config.floatX) self.rng = np.random.RandomState(utt.fetch_seed())
self.A = theano.tensor.matrix(dtype=self.dtype)
self.op = svd
for shp, full_matrices in itertools.product([(4, 4), (2, 4), (4, 2)], def test_svd(self):
[True, False]): rng = np.random.RandomState(utt.fetch_seed())
U, S, VT = svd(A, full_matrices=full_matrices) A = tensor.matrix("A", dtype=self.dtype)
U, S, VT = svd(A)
fn = function([A], [U, S, VT]) fn = function([A], [U, S, VT])
fn_shp = function([A], [U.shape, S.shape, VT.shape]) a = rng.rand(4, 4).astype(self.dtype)
a = rng.rand(*shp).astype(theano.config.floatX) n_u, n_s, n_vt = np.linalg.svd(a)
t_u, t_s, t_vt = fn(a) t_u, t_s, t_vt = fn(a)
t_u_shp, t_s_shp, t_vt_shp = fn_shp(a)
assert _allclose(t_u.shape, t_u_shp) assert _allclose(n_u, t_u)
assert _allclose(t_s.shape, t_s_shp) assert _allclose(n_s, t_s)
assert _allclose(t_vt.shape, t_vt_shp) assert _allclose(n_vt, t_vt)
fn = function([A], svd(A, compute_uv=False))
t_s = fn(a)
assert _allclose(n_s, t_s)
fn = function([A], svd(A, compute_uv=False)) def test_svd_infer_shape(self):
fn_shp = function([A], svd(A, compute_uv=False).shape) self.validate_shape((4, 4), full_matrices=True, compute_uv=True)
a = rng.rand(4, 2).astype(theano.config.floatX) self.validate_shape((4, 4), full_matrices=False, compute_uv=True)
assert _allclose(fn(a).shape, fn_shp(a)) self.validate_shape((2, 4), full_matrices=False, compute_uv=True)
self.validate_shape((4, 2), full_matrices=False, compute_uv=True)
self.validate_shape((4, 4), compute_uv=False)
def validate_shape(self, shape, compute_uv=True, full_matrices=True):
A = self.A
A_v = self.rng.rand(*shape).astype(self.dtype)
outputs = self.op(A, full_matrices=full_matrices, compute_uv=compute_uv)
if not compute_uv:
outputs = [outputs]
self._compile_and_check([A], outputs, [A_v], self.op_class, warn=False)
def test_tensorsolve(): def test_tensorsolve():
...@@ -431,8 +433,7 @@ class test_Eig(utt.InferShapeTester): ...@@ -431,8 +433,7 @@ class test_Eig(utt.InferShapeTester):
super(test_Eig, self).setUp() super(test_Eig, self).setUp()
self.rng = np.random.RandomState(utt.fetch_seed()) self.rng = np.random.RandomState(utt.fetch_seed())
self.A = theano.tensor.matrix(dtype=self.dtype) self.A = theano.tensor.matrix(dtype=self.dtype)
self.X = np.asarray(self.rng.rand(5, 5), self.X = np.asarray(self.rng.rand(5, 5), dtype=self.dtype)
dtype=self.dtype)
self.S = self.X.dot(self.X.T) self.S = self.X.dot(self.X.T)
def test_infer_shape(self): def test_infer_shape(self):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论