提交 0d89c1f6 authored 作者: Arnaud Bergeron's avatar Arnaud Bergeron

infer_shape for ExtractDiag and associated test.

上级 f934d22e
......@@ -426,6 +426,10 @@ class ExtractDiag(Op):
def grad(self, inputs, g_outputs):
return [alloc_diag(g_outputs[0])]
def infer_shape(self, node, shapes):
x_s, = shapes
return [(x_s[0],)]
extract_diag = ExtractDiag()
#TODO: optimization to insert ExtractDiag with view=True
......
......@@ -21,6 +21,7 @@ from theano.sandbox.linalg.ops import (cholesky,
matrix_inverse,
#solve,
#diag,
ExtractDiag,
extract_diag,
#alloc_diag,
det,
......@@ -121,6 +122,13 @@ def test_extract_diag():
except TypeError:
ok = True
assert ok
f = theano.function([x], g.shape)
topo = f.maker.env.toposort()
assert sum([node.op.__class__ == ExtractDiag for node in topo]) == 0
m = rng.rand(3,3).astype(config.floatX)
assert f(m) == 3
# not testing the view=True case since it is not used anywhere.
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论