提交 234f7460 authored 作者: Matthew Rocklin's avatar Matthew Rocklin

MpiSend passes array to SendWait to avoid gc

上级 aadd0e99
......@@ -200,7 +200,8 @@ class MPISend(Op):
def make_node(self, data):
return gof.Apply(self, [data],
[theano.Variable(Generic())])
[theano.Variable(Generic()), data.type()])
view_map = {1: [0]}
def perform(self, node, inp, out):
......@@ -209,6 +210,7 @@ class MPISend(Op):
request = comm.Isend(data, self.dest, self.tag)
out[0][0] = request
out[1][0] = data
def __str__(self):
return "MPISend{dest: %d, tag: %d}" % self._info
......@@ -232,8 +234,8 @@ class MPISendWait(Op):
def __hash__(self):
return hash((type(self), self.tag))
def make_node(self, request):
return gof.Apply(self, [request],
def make_node(self, request, data):
return gof.Apply(self, [request, data],
[theano.Variable(Generic())])
def perform(self, node, inp, out):
......@@ -248,7 +250,7 @@ def isend(var, dest, tag):
return MPISend(dest, tag)(var)
def send(var, dest, tag):
return MPISendWait(tag)(isend(var, dest, tag))
return MPISendWait(tag)(*isend(var, dest, tag))
def irecv(shape, dtype, source, tag):
return MPIRecv(source, tag, shape, dtype)()
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论