提交 ba438daf authored 作者: Frederic's avatar Frederic

Add doc and trace test for not square matrix.

上级 c787e96e
......@@ -395,6 +395,7 @@ solve = Solve() # general solve
#TODO: Optimizations to replace multiplication by matrix inverse with solve() Op (still unwritten)
class ExtractDiag(Op):
""" Return the diagonal of matrix """
def __init__(self, view=False):
self.view = view
if self.view:
......
......@@ -137,9 +137,10 @@ def test_trace():
g = trace(x)
f = theano.function([x], g)
m = rng.rand(4, 4).astype(config.floatX)
v = numpy.trace(m)
assert v == f(m)
for shp in [(2, 3), (3, 2), (3, 3)]:
m = rng.rand(*shp).astype(config.floatX)
v = numpy.trace(m)
assert v == f(m)
xx = theano.tensor.vector()
ok = False
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论