提交 ba438daf authored 作者: Frederic's avatar Frederic

Add doc and trace test for not square matrix.

上级 c787e96e
...@@ -395,6 +395,7 @@ solve = Solve() # general solve ...@@ -395,6 +395,7 @@ solve = Solve() # general solve
#TODO: Optimizations to replace multiplication by matrix inverse with solve() Op (still unwritten) #TODO: Optimizations to replace multiplication by matrix inverse with solve() Op (still unwritten)
class ExtractDiag(Op): class ExtractDiag(Op):
""" Return the diagonal of matrix """
def __init__(self, view=False): def __init__(self, view=False):
self.view = view self.view = view
if self.view: if self.view:
......
...@@ -137,7 +137,8 @@ def test_trace(): ...@@ -137,7 +137,8 @@ def test_trace():
g = trace(x) g = trace(x)
f = theano.function([x], g) f = theano.function([x], g)
m = rng.rand(4, 4).astype(config.floatX) for shp in [(2, 3), (3, 2), (3, 3)]:
m = rng.rand(*shp).astype(config.floatX)
v = numpy.trace(m) v = numpy.trace(m)
assert v == f(m) assert v == f(m)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论