提交 bf10c56c authored 作者: David Warde-Farley's avatar David Warde-Farley

Revised tests for the Cholesky op.

上级 5ff2be4e
...@@ -11,6 +11,7 @@ from theano import config ...@@ -11,6 +11,7 @@ from theano import config
# The one in comment are not tested... # The one in comment are not tested...
from theano.sandbox.linalg.ops import (cholesky, from theano.sandbox.linalg.ops import (cholesky,
Cholesky, # op class
matrix_inverse, matrix_inverse,
#solve, #solve,
#diag, #diag,
...@@ -27,29 +28,36 @@ from theano.sandbox.linalg.ops import (cholesky, ...@@ -27,29 +28,36 @@ from theano.sandbox.linalg.ops import (cholesky,
from nose.plugins.skip import SkipTest from nose.plugins.skip import SkipTest
if 0: def check_lower_triangular(pd, ch_f):
def test_cholesky(): ch = ch_f(pd)
#TODO: test upper and lower triangular assert ch[0, pd.shape[1] - 1] == 0
#todo: unittest randomseed assert ch[pd.shape[0] - 1, 0] != 0
rng = numpy.random.RandomState(utt.fetch_seed()) assert numpy.allclose(numpy.dot(ch, ch.T), pd)
assert not numpy.allclose(numpy.dot(ch.T, ch), pd)
r = rng.randn(5,5)
pd = numpy.dot(r,r.T) def check_upper_triangular(pd, ch_f):
ch = ch_f(pd)
assert ch[4, 0] == 0
assert ch[0, 4] != 0
assert numpy.allclose(numpy.dot(ch.T, ch), pd)
assert not numpy.allclose(numpy.dot(ch, ch.T), pd)
x = tensor.matrix()
chol = cholesky(x)
f = function([x], tensor.dot(chol, chol.T)) # an optimization could remove this
ch_f = function([x], chol) def test_cholesky():
rng = numpy.random.RandomState(utt.fetch_seed())
# quick check that chol is upper-triangular r = rng.randn(5, 5)
ch = ch_f(pd) pd = numpy.dot(r,r.T)
print ch x = tensor.matrix()
assert ch[0,4] != 0 chol = cholesky(x)
assert ch[4,0] == 0 ch_f = function([x], chol)
assert numpy.allclose(numpy.dot(ch.T,ch),pd) yield check_lower_triangular, pd, ch_f
assert not numpy.allclose(numpy.dot(ch,ch.T),pd) chol = Cholesky(lower=True)(x)
ch_f = function([x], chol)
yield check_lower_triangular, pd, ch_f
chol = Cholesky(lower=False)(x)
ch_f = function([x], chol)
yield check_upper_triangular, pd, ch_f
def test_inverse_correctness(): def test_inverse_correctness():
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论