提交 f07172d1 authored 作者: Arnaud Bergeron's avatar Arnaud Bergeron

Add test for extract_diag.

上级 c0cfa4e3
...@@ -3,6 +3,9 @@ import numpy ...@@ -3,6 +3,9 @@ import numpy
import theano import theano
from theano import tensor, function from theano import tensor, function
from theano.tensor.basic import _allclose from theano.tensor.basic import _allclose
from theano.tests import unittest_tools as utt
utt.seed_rng()
try: try:
import scipy import scipy
...@@ -16,9 +19,9 @@ except ImportError: ...@@ -16,9 +19,9 @@ except ImportError:
from theano.sandbox.linalg.ops import (cholesky, from theano.sandbox.linalg.ops import (cholesky,
matrix_inverse, matrix_inverse,
#solve, #solve,
#diag, diag,
#extract_diag, extract_diag,
#alloc_diag, alloc_diag,
det, det,
#PSD_hint, #PSD_hint,
#trace, #trace,
...@@ -88,3 +91,34 @@ def test_det_grad(): ...@@ -88,3 +91,34 @@ def test_det_grad():
r = rng.randn(5,5) r = rng.randn(5,5)
tensor.verify_grad(det, [r], rng=numpy.random) tensor.verify_grad(det, [r], rng=numpy.random)
def test_extract_diag():
rng = numpy.random.RandomState(utt.fetch_seed())
x = theano.tensor.fmatrix()
g = extract_diag(x)
f = theano.function([x], g)
v = numpy.array([0.1, 0.2, 0.4]).astype('float32')
m = numpy.diag(v)
r = f(m)
# The right diagonal is extracted
assert (r == v).all()
# Make sure we don't have a view
assert r.base is None
m = rng.rand(2, 3).astype('float32')
ok = False
try:
r = f(m)
except Exception:
ok = True
assert ok
xx = theano.tensor.fvector()
ok = False
try:
g = extract_diag(xx)
except TypeError:
ok = True
assert ok
# not testing the view=True case since it is not used anywhere.
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论