提交 b3d65298 authored 作者: Matthew Rocklin's avatar Matthew Rocklin

fix infer_shapes for mpi ops

上级 b4cd9a55
...@@ -130,8 +130,8 @@ class MPIRecv(Op): ...@@ -130,8 +130,8 @@ class MPIRecv(Op):
def __str__(self): def __str__(self):
return "MPIRecv{source: %d, tag: %d, shape: %s, dtype: %s}"%self._info return "MPIRecv{source: %d, tag: %d, shape: %s, dtype: %s}"%self._info
#def infer_shape(self, node, shapes): def infer_shape(self, node, shapes):
# return [self.shape] return [None, self.shape]
class MPIRecvWait(Op): class MPIRecvWait(Op):
""" """
...@@ -168,8 +168,8 @@ class MPIRecvWait(Op): ...@@ -168,8 +168,8 @@ class MPIRecvWait(Op):
def __str__(self): def __str__(self):
return "MPIRecvWait" return "MPIRecvWait"
# def infer_shape(self, node, shapes): def infer_shape(self, node, shapes):
# return shapes return [shapes[1]]
class MPISend(Op): class MPISend(Op):
""" """
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论