提交 83f39106 authored 作者: David Warde-Farley's avatar David Warde-Farley

Deduplicate shape testing code.

Better living through generator tests.
上级 6873a47e
...@@ -76,31 +76,24 @@ def test_cholesky_grad(): ...@@ -76,31 +76,24 @@ def test_cholesky_grad():
yield utt.verify_grad, Cholesky(lower=False), [pd], 3, rng yield utt.verify_grad, Cholesky(lower=False), [pd], 3, rng
def test_cholesky_shape(): def test_cholesky_and_cholesky_grad_shape():
rng = numpy.random.RandomState(utt.fetch_seed()) rng = numpy.random.RandomState(utt.fetch_seed())
x = tensor.matrix() x = tensor.matrix()
l = cholesky(x) for l in (cholesky(x), Cholesky(lower=True)(x), Cholesky(lower=False)(x)):
f = theano.function([x], l.shape) f_chol = theano.function([x], l.shape)
topo = f.maker.env.toposort() g = tensor.grad(l.sum(), x)
if config.mode != 'FAST_COMPILE': f_cholgrad = theano.function([x], g.shape)
assert sum([node.op.__class__ == Cholesky for node in topo]) == 0 topo_chol = f_chol.maker.env.toposort()
for shp in [2, 3, 5]: topo_cholgrad = f_cholgrad.maker.env.toposort()
m = numpy.cov(rng.randn(shp, shp + 10)).astype(config.floatX) if config.mode != 'FAST_COMPILE':
assert numpy.all(f(m) == (shp, shp)) assert sum([node.op.__class__ == Cholesky
for node in topo_chol]) == 0
assert sum([node.op.__class__ == CholeskyGrad
def test_cholesky_grad_shape(): for node in topo_cholgrad]) == 0
rng = numpy.random.RandomState(utt.fetch_seed()) for shp in [2, 3, 5]:
x = tensor.matrix() m = numpy.cov(rng.randn(shp, shp + 10)).astype(config.floatX)
l = cholesky(x) yield numpy.testing.assert_equal, f_chol(m), (shp, shp)
g = tensor.grad(l.sum(), x) yield numpy.testing.assert_equal, f_cholgrad(m), (shp, shp)
f = theano.function([x], g.shape)
topo = f.maker.env.toposort()
if config.mode != 'FAST_COMPILE':
assert sum([node.op.__class__ == CholeskyGrad for node in topo]) == 0
for shp in [2, 3, 5]:
m = numpy.cov(rng.randn(shp, shp + 10)).astype(config.floatX)
assert numpy.all(f(m) == (shp, shp))
def test_inverse_correctness(): def test_inverse_correctness():
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论