提交 6b2b5132 authored 作者: Frederic Bastien's avatar Frederic Bastien

make Reshape.infer_shape work in more case.

上级 2da94091
...@@ -3083,7 +3083,9 @@ class Reshape(Op): ...@@ -3083,7 +3083,9 @@ class Reshape(Op):
def grad(self, (x, shp), (g_out,)): def grad(self, (x, shp), (g_out,)):
return [reshape(g_out, shape(x), ndim=x.ndim), None] return [reshape(g_out, shape(x), ndim=x.ndim), None]
def infer_shape(self, node, ishapes): def infer_shape(self, node, ishapes):
return [node.inputs[1]] #we can't just put node.inputs[1] as not all op support interation
#and this is needed in the ShapeOptimizer
return [tuple([node.inputs[1][i] for i in range(self.ndim)])]
def reshape(x, newshape, ndim=None, name=None): def reshape(x, newshape, ndim=None, name=None):
if ndim is None: if ndim is None:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论