提交 81b55123 authored 作者: Frédéric Bastien's avatar Frédéric Bastien 提交者: GitHub

Merge pull request #5318 from nouiz/local_reshape_dimshuffle

Fix opt crash
...@@ -6105,6 +6105,8 @@ class ExtractDiag(Op): ...@@ -6105,6 +6105,8 @@ class ExtractDiag(Op):
(x,) = inputs (x,) = inputs
(z,) = outputs (z,) = outputs
z[0] = x.diagonal(self.offset, self.axis1, self.axis2) z[0] = x.diagonal(self.offset, self.axis1, self.axis2)
if not self.view:
z[0] = z[0].copy()
def grad(self, inputs, gout): def grad(self, inputs, gout):
(x,) = inputs (x,) = inputs
......
...@@ -133,7 +133,8 @@ def local_reshape_dimshuffle(node): ...@@ -133,7 +133,8 @@ def local_reshape_dimshuffle(node):
return False return False
else: else:
offset += 1 offset += 1
return [T.reshape(input_.owner.inputs[0], node.inputs[1])] return [T.reshape(input_.owner.inputs[0], node.inputs[1],
ndim=node.outputs[0].ndim)]
return False return False
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论