提交 5e80c182 authored 作者: Arnaud Bergeron's avatar Arnaud Bergeron

Add test for trace.

上级 ea99ec0b
...@@ -25,7 +25,7 @@ from theano.sandbox.linalg.ops import (cholesky, ...@@ -25,7 +25,7 @@ from theano.sandbox.linalg.ops import (cholesky,
#alloc_diag, #alloc_diag,
det, det,
#PSD_hint, #PSD_hint,
#trace, trace,
#spectral_radius_bound #spectral_radius_bound
) )
...@@ -121,3 +121,23 @@ def test_extract_diag(): ...@@ -121,3 +121,23 @@ def test_extract_diag():
ok = True ok = True
assert ok assert ok
# not testing the view=True case since it is not used anywhere. # not testing the view=True case since it is not used anywhere.
def test_trace():
rng = numpy.random.RandomState(utt.fetch_seed())
x = theano.tensor.matrix()
g = trace(x)
f = theano.function([x], g)
m = rng.rand(4, 4).astype(config.floatX)
v = numpy.trace(m)
assert v == f(m)
xx = theano.tensor.vector()
ok = False
try:
trace(xx)
except TypeError:
ok = True
assert ok
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论