提交 2288161b authored 作者: wonghang's avatar wonghang

Fix test_linalg.py for float32, also use it to compute log(det|A|) where A is positive-definite

上级 615c255c
...@@ -599,8 +599,7 @@ class TestMagma(unittest.TestCase): ...@@ -599,8 +599,7 @@ class TestMagma(unittest.TestCase):
for node in fn.maker.fgraph.toposort() for node in fn.maker.fgraph.toposort()
]) ])
# copied from theano/tensor/tests/test_slinalg.py # mostly copied from theano/tensor/tests/test_slinalg.py
def test_cholesky_grad(): def test_cholesky_grad():
rng = np.random.RandomState(utt.fetch_seed()) rng = np.random.RandomState(utt.fetch_seed())
r = rng.randn(5, 5).astype(config.floatX) r = rng.randn(5, 5).astype(config.floatX)
...@@ -630,16 +629,28 @@ def test_cholesky_grad_indef(): ...@@ -630,16 +629,28 @@ def test_cholesky_grad_indef():
# assert np.all(np.isnan(chol_f(matrix))) # assert np.all(np.isnan(chol_f(matrix)))
def test_lower_triangular_and_cholesky_grad(): def test_lower_triangular_and_cholesky_grad():
# Random lower triangular system is ill-conditioned.
#
# Reference
# -----------
# Viswanath, Divakar, and L. N. Trefethen. "Condition numbers of random triangular matrices."
# SIAM Journal on Matrix Analysis and Applications 19.2 (1998): 564-581.
#
# Use smaller number of N when using float32
if config.floatX == 'float64':
N = 100
else:
N = 5
rng = np.random.RandomState(utt.fetch_seed()) rng = np.random.RandomState(utt.fetch_seed())
r = rng.randn(10, 10).astype(config.floatX) r = rng.randn(N, N).astype(config.floatX)
y = rng.rand(10, 1).astype(config.floatX) y = rng.rand(N, 1).astype(config.floatX)
def f(r,y): def f(r,y):
PD = r.dot(r.T) PD = r.dot(r.T)
L = gpu_cholesky(PD) L = gpu_cholesky(PD)
A = gpu_solve_lower_triangular(L,y) A = gpu_solve_lower_triangular(L,y)
AAT = theano.tensor.dot(A,A.T) AAT = theano.tensor.dot(A,A.T)
B = AAT + theano.tensor.eye(10) B = AAT + theano.tensor.eye(N)
LB = gpu_cholesky(B) LB = gpu_cholesky(B)
return theano.tensor.sum(LB) return theano.tensor.sum(theano.tensor.log(theano.tensor.diag(LB)))
yield (lambda: utt.verify_grad(f, [r,y], 3, rng)) yield (lambda: utt.verify_grad(f, [r,y], 3, rng))
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论