提交 5e80c182 authored 作者: Arnaud Bergeron's avatar Arnaud Bergeron

Add test for trace.

上级 ea99ec0b
......@@ -25,7 +25,7 @@ from theano.sandbox.linalg.ops import (cholesky,
#alloc_diag,
det,
#PSD_hint,
#trace,
trace,
#spectral_radius_bound
)
......@@ -121,3 +121,23 @@ def test_extract_diag():
ok = True
assert ok
# not testing the view=True case since it is not used anywhere.
def test_trace():
rng = numpy.random.RandomState(utt.fetch_seed())
x = theano.tensor.matrix()
g = trace(x)
f = theano.function([x], g)
m = rng.rand(4, 4).astype(config.floatX)
v = numpy.trace(m)
assert v == f(m)
xx = theano.tensor.vector()
ok = False
try:
trace(xx)
except TypeError:
ok = True
assert ok
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论