提交 908c8d51 authored 作者: Frederic's avatar Frederic

Fix optimization warning due as we didn't returned the right dtype for empty matrix in ExtractDiag.

上级 9989c1e1
......@@ -700,7 +700,7 @@ class ExtractDiag(Op):
# zero-dimensional matrices ...
if x.shape[0] == 0 or x.shape[1] == 0:
z[0] = numpy.zeros(0)
z[0] = numpy.zeros(0, dtype=x.dtype)
return
if x.shape[0] < x.shape[1]:
......
......@@ -332,6 +332,7 @@ def test_diag():
assert ok
# not testing the view=True case since it is not used anywhere.
def test_extract_diag():
rng = numpy.random.RandomState(utt.fetch_seed())
x = theano.tensor.matrix()
......@@ -370,7 +371,10 @@ def test_extract_diag_grad():
tensor.verify_grad(extract_diag, [x], rng=rng)
# not testing the view=True case since it is not used anywhere.
def test_extract_diag_empty():
c = theano.tensor.constant(numpy.array([[], []], 'int32'))
extract_diag(c).eval()
def test_trace():
rng = numpy.random.RandomState(utt.fetch_seed())
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论