提交 b3d65298 authored 作者: Matthew Rocklin's avatar Matthew Rocklin

fix infer_shapes for mpi ops

上级 b4cd9a55
......@@ -130,8 +130,8 @@ class MPIRecv(Op):
def __str__(self):
return "MPIRecv{source: %d, tag: %d, shape: %s, dtype: %s}"%self._info
#def infer_shape(self, node, shapes):
# return [self.shape]
def infer_shape(self, node, shapes):
return [None, self.shape]
class MPIRecvWait(Op):
"""
......@@ -168,8 +168,8 @@ class MPIRecvWait(Op):
def __str__(self):
return "MPIRecvWait"
# def infer_shape(self, node, shapes):
# return shapes
def infer_shape(self, node, shapes):
return [shapes[1]]
class MPISend(Op):
"""
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论