提交 b53610f8 authored 作者: Tanjay94's avatar Tanjay94

Fixed norm test for every mode and fixed test for old numpy version.

上级 8bae59f3
......@@ -905,9 +905,7 @@ eig = Eig()
class SVD(Op):
"""
Singular Value Decomposition.
Factors the matrix a as u * np.diag(s) * v, where u and v are unitary
and s is a 1-d array of a's singular values.
See doc in the docstring of the function just after this class.
"""
_numop = staticmethod(numpy.linalg.svd)
......@@ -1097,7 +1095,6 @@ def qr(a, mode="full"):
x = [[2, 1], [3, 4]]
if isinstance(numpy.linalg.qr(x,mode), tuple):
return QRFull(mode)(a)
else:
return QRIncomplete(mode)(a)
......
......@@ -176,60 +176,32 @@ def test_matrix_dot():
assert _allclose(numpy_sol, theano_sol)
def test_qr_default():
rng = numpy.random.RandomState(utt.fetch_seed())
A = tensor.matrix("A", dtype=theano.config.floatX)
Q = qr(A)
fn = function([A], Q)
a = rng.rand(4, 4).astype(theano.config.floatX)
n_q, n_r = numpy.linalg.qr(a)
t_q, t_r = fn(a)
assert _allclose(n_q, t_q)
assert _allclose(n_r, t_r)
def test_qr_modes():
rng = numpy.random.RandomState(utt.fetch_seed())
A = tensor.matrix("A", dtype=theano.config.floatX)
Q = qr(A, mode="reduced")
R = qr(A, mode="complete")
S = qr(A, mode="r")
T = qr(A, mode="raw")
U = qr(A, mode="full")
V = qr(A, mode="economic")
fq = function([A], Q)
fr = function([A], R)
fs = function([A], S)
ft = function([A], T)
fu = function([A], U)
fv = function([A], V)
A = tensor.matrix("A", dtype=theano.config.floatX)
a = rng.rand(4, 4).astype(theano.config.floatX)
n_q = numpy.linalg.qr(a, mode="reduced")
t_q = fq(a)
n_r = numpy.linalg.qr(a, mode="complete")
t_r = fr(a)
n_s = numpy.linalg.qr(a, mode="r")
t_s = fs(a)
n_t = numpy.linalg.qr(a, mode="raw")
t_t = ft(a)
n_u = numpy.linalg.qr(a, mode="full")
t_u = fu(a)
n_v = numpy.linalg.qr(a, mode="economic")
t_v = fv(a)
assert _allclose(n_q, t_q)
assert _allclose(n_r, t_r)
assert _allclose(n_s, t_s)
assert _allclose(n_u, t_u)
assert _allclose(n_v, t_v)
try:
numpy.linalg.qr(a, "complete")
except TypeError, e:
assert "name 'complete' is not defined" in str(e)
raise SkipTest
f = function([A], qr(A))
t_qr = f(a)
n_qr = numpy.linalg.qr(a)
assert _allclose(n_qr, t_qr)
for mode in ["reduced", "complete", "r", "raw", "full", "economic"]:
f = function([A], qr(A, mode))
t_qr = f(a)
n_qr = numpy.linalg.qr(a, mode)
if isinstance(n_qr, (list, tuple)):
assert _allclose(n_qr[0], t_qr[0])
assert _allclose(n_qr[1], t_qr[1])
else:
assert _allclose(n_qr, t_qr)
def test_svd():
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论