提交 488df15e authored 作者: Frederic Bastien's avatar Frederic Bastien

pep8

上级 c0291c58
......@@ -52,6 +52,7 @@ class LoadFromDisk(Op):
def __str__(self):
return "Load{dtype: %s, broadcastable: %s, mmep: %s}" % self._info
def load(path, dtype, broadcastable, mmap_mode=None):
"""
Load an array from an .npy file.
......@@ -91,6 +92,7 @@ else:
comm = MPI.COMM_WORLD
mpi_enabled = True
class MPIRecv(Op):
"""
An operation to asynchronously receive an array to a remote host using MPI
......@@ -120,6 +122,7 @@ class MPIRecv(Op):
return gof.Apply(self, [], [theano.Variable(Generic()),
tensor(self.dtype,
broadcastable=self.broadcastable)])
def perform(self, node, inp, out):
data = numpy.zeros(self.shape, dtype=self.dtype)
......@@ -137,6 +140,7 @@ class MPIRecv(Op):
def do_constant_folding(self, node):
return False
class MPIRecvWait(Op):
"""
An operation to wait on a previously received array using MPI
......@@ -160,6 +164,7 @@ class MPIRecvWait(Op):
return gof.Apply(self, [request, data],
[tensor(data.dtype,
broadcastable=data.broadcastable)])
def perform(self, node, inp, out):
request = inp[0]
......@@ -177,6 +182,7 @@ class MPIRecvWait(Op):
view_map = {0: [1]}
class MPISend(Op):
"""
An operation to asynchronously Send an array to a remote host using MPI
......@@ -216,6 +222,7 @@ class MPISend(Op):
def __str__(self):
return "MPISend{dest: %d, tag: %d}" % self._info
class MPISendWait(Op):
"""
An operation to wait on a previously sent array using MPI
......@@ -247,18 +254,23 @@ class MPISendWait(Op):
def __str__(self):
return "MPISendWait"
def isend(var, dest, tag):
return MPISend(dest, tag)(var)
def send(var, dest, tag):
return MPISendWait(tag)(*isend(var, dest, tag))
def irecv(shape, dtype, source, tag):
return MPIRecv(source, tag, shape, dtype)()
def recv(shape, dtype, source, tag):
return MPIRecvWait(tag)(*irecv(shape, dtype, source, tag))
# Ordering keys for scheduling
def mpi_send_wait_key(a):
""" Wait as long as possible on Waits, Start Send/Recvs early """
......@@ -268,6 +280,7 @@ def mpi_send_wait_key(a):
return -1
return 0
def mpi_tag_key(a):
""" Break MPI ties by using the variable tag - prefer lower tags first """
if isinstance(a.op, (MPISend, MPIRecv, MPIRecvWait, MPISendWait)):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论