提交 d7918e81 authored 作者: goodfeli's avatar goodfeli

Merge pull request #1099 from nouiz/fix_crash_extract_diag

Fix optimization warning due as we didn't returned the right dtype for e...
......@@ -129,7 +129,7 @@ shown.
3) The function :func:`theano.printing.pydotprint` will print a compiled theano function to a png file.
In the image, Apply nodes (the applications of ops) are shown as boxes and variables are shown as ovals.
In the image, Apply nodes (the applications of ops) are shown as ellipses and variables are shown as boxes.
The number at the end of each label indicates graph position.
Boxes and ovals have their own set of positions, so you can have apply #1 and also a
variable #1.
......
......@@ -703,7 +703,7 @@ class ExtractDiag(Op):
# zero-dimensional matrices ...
if x.shape[0] == 0 or x.shape[1] == 0:
z[0] = numpy.zeros(0)
z[0] = numpy.zeros(0, dtype=x.dtype)
return
if x.shape[0] < x.shape[1]:
......
......@@ -332,6 +332,7 @@ def test_diag():
assert ok
# not testing the view=True case since it is not used anywhere.
def test_extract_diag():
rng = numpy.random.RandomState(utt.fetch_seed())
x = theano.tensor.matrix()
......@@ -370,7 +371,10 @@ def test_extract_diag_grad():
tensor.verify_grad(extract_diag, [x], rng=rng)
# not testing the view=True case since it is not used anywhere.
def test_extract_diag_empty():
c = theano.tensor.constant(numpy.array([[], []], 'int32'))
extract_diag(c).eval()
def test_trace():
rng = numpy.random.RandomState(utt.fetch_seed())
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论