提交 22623909 authored 作者: lamblin's avatar lamblin

Merge pull request #1449 from nouiz/mpi

Add documentaiton about the MPI and load ops
...@@ -24,3 +24,4 @@ They are grouped into the following sections: ...@@ -24,3 +24,4 @@ They are grouped into the following sections:
signal/index signal/index
utils utils
extra_ops extra_ops
io
===================================================================
:mod:`tensor.io` -- Tensor IO Ops
===================================================================
.. module:: tensor.io
:platform: Unix, Windows
:synopsis: Tensor IO Ops
.. moduleauthor:: LISA
File operation
==============
- Load from disk with the function :func:`load <theano.tensor.io.load>` and its associated op :class:`LoadFromDisk <theano.tensor.io.LoadFromDisk>`
MPI operation
=============
- Non-blocking transfer: :func:`isend <theano.tensor.io.isend>` and :func:`irecv <theano.tensor.io.irecv>`.
- Blocking transfer: :func:`send <theano.tensor.io.send>` and :func:`recv <theano.tensor.io.recv>`
Details
=======
.. automodule:: theano.tensor.io
:members:
...@@ -52,6 +52,7 @@ class LoadFromDisk(Op): ...@@ -52,6 +52,7 @@ class LoadFromDisk(Op):
def __str__(self): def __str__(self):
return "Load{dtype: %s, broadcastable: %s, mmep: %s}" % self._info return "Load{dtype: %s, broadcastable: %s, mmep: %s}" % self._info
def load(path, dtype, broadcastable, mmap_mode=None): def load(path, dtype, broadcastable, mmap_mode=None):
""" """
Load an array from an .npy file. Load an array from an .npy file.
...@@ -91,6 +92,7 @@ else: ...@@ -91,6 +92,7 @@ else:
comm = MPI.COMM_WORLD comm = MPI.COMM_WORLD
mpi_enabled = True mpi_enabled = True
class MPIRecv(Op): class MPIRecv(Op):
""" """
An operation to asynchronously receive an array to a remote host using MPI An operation to asynchronously receive an array to a remote host using MPI
...@@ -120,6 +122,7 @@ class MPIRecv(Op): ...@@ -120,6 +122,7 @@ class MPIRecv(Op):
return gof.Apply(self, [], [theano.Variable(Generic()), return gof.Apply(self, [], [theano.Variable(Generic()),
tensor(self.dtype, tensor(self.dtype,
broadcastable=self.broadcastable)]) broadcastable=self.broadcastable)])
def perform(self, node, inp, out): def perform(self, node, inp, out):
data = numpy.zeros(self.shape, dtype=self.dtype) data = numpy.zeros(self.shape, dtype=self.dtype)
...@@ -137,6 +140,7 @@ class MPIRecv(Op): ...@@ -137,6 +140,7 @@ class MPIRecv(Op):
def do_constant_folding(self, node): def do_constant_folding(self, node):
return False return False
class MPIRecvWait(Op): class MPIRecvWait(Op):
""" """
An operation to wait on a previously received array using MPI An operation to wait on a previously received array using MPI
...@@ -160,6 +164,7 @@ class MPIRecvWait(Op): ...@@ -160,6 +164,7 @@ class MPIRecvWait(Op):
return gof.Apply(self, [request, data], return gof.Apply(self, [request, data],
[tensor(data.dtype, [tensor(data.dtype,
broadcastable=data.broadcastable)]) broadcastable=data.broadcastable)])
def perform(self, node, inp, out): def perform(self, node, inp, out):
request = inp[0] request = inp[0]
...@@ -177,6 +182,7 @@ class MPIRecvWait(Op): ...@@ -177,6 +182,7 @@ class MPIRecvWait(Op):
view_map = {0: [1]} view_map = {0: [1]}
class MPISend(Op): class MPISend(Op):
""" """
An operation to asynchronously Send an array to a remote host using MPI An operation to asynchronously Send an array to a remote host using MPI
...@@ -216,6 +222,7 @@ class MPISend(Op): ...@@ -216,6 +222,7 @@ class MPISend(Op):
def __str__(self): def __str__(self):
return "MPISend{dest: %d, tag: %d}" % self._info return "MPISend{dest: %d, tag: %d}" % self._info
class MPISendWait(Op): class MPISendWait(Op):
""" """
An operation to wait on a previously sent array using MPI An operation to wait on a previously sent array using MPI
...@@ -247,18 +254,35 @@ class MPISendWait(Op): ...@@ -247,18 +254,35 @@ class MPISendWait(Op):
def __str__(self): def __str__(self):
return "MPISendWait" return "MPISendWait"
def isend(var, dest, tag): def isend(var, dest, tag):
"""
Non blocking send
"""
return MPISend(dest, tag)(var) return MPISend(dest, tag)(var)
def send(var, dest, tag): def send(var, dest, tag):
"""
blocking send
"""
return MPISendWait(tag)(*isend(var, dest, tag)) return MPISendWait(tag)(*isend(var, dest, tag))
def irecv(shape, dtype, source, tag): def irecv(shape, dtype, source, tag):
"""
non-blocking receive
"""
return MPIRecv(source, tag, shape, dtype)() return MPIRecv(source, tag, shape, dtype)()
def recv(shape, dtype, source, tag): def recv(shape, dtype, source, tag):
"""
blocking receive
"""
return MPIRecvWait(tag)(*irecv(shape, dtype, source, tag)) return MPIRecvWait(tag)(*irecv(shape, dtype, source, tag))
# Ordering keys for scheduling # Ordering keys for scheduling
def mpi_send_wait_key(a): def mpi_send_wait_key(a):
""" Wait as long as possible on Waits, Start Send/Recvs early """ """ Wait as long as possible on Waits, Start Send/Recvs early """
...@@ -268,6 +292,7 @@ def mpi_send_wait_key(a): ...@@ -268,6 +292,7 @@ def mpi_send_wait_key(a):
return -1 return -1
return 0 return 0
def mpi_tag_key(a): def mpi_tag_key(a):
""" Break MPI ties by using the variable tag - prefer lower tags first """ """ Break MPI ties by using the variable tag - prefer lower tags first """
if isinstance(a.op, (MPISend, MPIRecv, MPIRecvWait, MPISendWait)): if isinstance(a.op, (MPISend, MPIRecv, MPIRecvWait, MPISendWait)):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论