提交 f0e5c329 作者: Matthew Rocklin

add send/recv functions, tests.

Also removed infer_shape for now
上级 1f59b11e
...@@ -101,13 +101,13 @@ class MPIRecv(Op): ...@@ -101,13 +101,13 @@ class MPIRecv(Op):
@note: Non-differentiable. @note: Non-differentiable.
""" """
def __init__(self, rank, tag, dtype, shape): def __init__(self, source, tag, shape, dtype):
self.rank = rank self.source = source
self.tag = tag self.tag = tag
self.shape = shape self.shape = shape
self.dtype = numpy.dtype(dtype) # turn "float64" into numpy.float64 self.dtype = numpy.dtype(dtype) # turn "float64" into numpy.float64
self.broadcastable = (False,)*len(shape) self.broadcastable = (False,)*len(shape)
self._info = (rank, tag, dtype, shape) self._info = (source, tag, shape, dtype)
def __eq__(self, other): def __eq__(self, other):
return (type(self) == type(other) and self._info == other._info) return (type(self) == type(other) and self._info == other._info)
...@@ -116,19 +116,22 @@ class MPIRecv(Op): ...@@ -116,19 +116,22 @@ class MPIRecv(Op):
return hash(self._info) return hash(self._info)
def make_node(self): def make_node(self):
return gof.Apply(self, [], [theano.Generic(), return gof.Apply(self, [], [theano.Variable(Generic()),
tensor(self.dtype, tensor(self.dtype,
broadcastable=self.broadcastable)]) broadcastable=self.broadcastable)])
def perform(self, node, inp, out): def perform(self, node, inp, out):
data = numpy.empty(self.shape, dtype=self.dtype) data = numpy.empty(self.shape, dtype=self.dtype)
request = comm.Irecv(data, self.rank, self.tag) request = comm.Irecv(data, self.source, self.tag)
out[0][0] = request out[0][0] = request
out[0][1] = data out[1][0] = data
def __str__(self): def __str__(self):
return "MPIRecv{source: %d, tag: %d, dtype:%s, shape:%s, :%s}"%self._info return "MPIRecv{source: %d, tag: %d, shape: %s, dtype: %s}"%self._info
#def infer_shape(self, node, shapes):
# return [self.shape]
class MPIRecvWait(Op): class MPIRecvWait(Op):
""" """
...@@ -147,18 +150,16 @@ class MPIRecvWait(Op): ...@@ -147,18 +150,16 @@ class MPIRecvWait(Op):
return type(self) == type(other) return type(self) == type(other)
def __hash__(self): def __hash__(self):
return hash(self.type) return hash(type(self))
def make_node(self): def make_node(self, request, data):
return gof.Apply(self, [theano.Generic(), return gof.Apply(self, [request, data],
tensor(self.dtype, [tensor(data.dtype,
broadcastable=self.broadcastable)], broadcastable=data.broadcastable)])
[tensor(self.dtype,
broadcastable=self.broadcastable)])
def perform(self, node, inp, out): def perform(self, node, inp, out):
request = inp[0][0] request = inp[0]
data = inp[0][1] data = inp[1]
request.wait() request.wait()
...@@ -167,6 +168,9 @@ class MPIRecvWait(Op): ...@@ -167,6 +168,9 @@ class MPIRecvWait(Op):
def __str__(self): def __str__(self):
return "MPIRecvWait" return "MPIRecvWait"
# def infer_shape(self, node, shapes):
# return shapes
class MPISend(Op): class MPISend(Op):
""" """
An operation to asynchronously Send an array to a remote host using MPI An operation to asynchronously Send an array to a remote host using MPI
...@@ -178,10 +182,10 @@ class MPISend(Op): ...@@ -178,10 +182,10 @@ class MPISend(Op):
@note: Non-differentiable. @note: Non-differentiable.
""" """
def __init__(self, rank, tag): def __init__(self, dest, tag):
self.rank = rank self.dest = dest
self.tag = tag self.tag = tag
self._info = (rank, tag) self._info = (dest, tag)
def __eq__(self, other): def __eq__(self, other):
return (type(self) == type(other) and self._info == other._info) return (type(self) == type(other) and self._info == other._info)
...@@ -189,15 +193,15 @@ class MPISend(Op): ...@@ -189,15 +193,15 @@ class MPISend(Op):
def __hash__(self): def __hash__(self):
return hash(self._info) return hash(self._info)
def make_node(self): def make_node(self, data):
return gof.Apply(self, [tensor(self.dtype, broadcastable=self.broadcastable)], return gof.Apply(self, [data],
[theano.Generic()]) [theano.Variable(Generic())])
def perform(self, node, inp, out): def perform(self, node, inp, out):
data = inp[0][0] data = inp[0]
request = comm.Isend(data, self.rank, self.tag) request = comm.Isend(data, self.dest, self.tag)
out[0][0] = request out[0][0] = request
...@@ -221,15 +225,26 @@ class MPISendWait(Op): ...@@ -221,15 +225,26 @@ class MPISendWait(Op):
return type(self) == type(other) return type(self) == type(other)
def __hash__(self): def __hash__(self):
return hash(self.type) return hash(type(self))
def make_node(self): def make_node(self, request):
return gof.Apply(self, [theano.Generic()], [theano.Generic()]) return gof.Apply(self, [request],
[theano.Variable(Generic())])
def perform(self, node, inp, out): def perform(self, node, inp, out):
request = inp[0][0] request = inp[0]
request.wait() request.wait()
out[0][0] = True out[0][0] = True
def __str__(self): def __str__(self):
return "MPISendWait" return "MPISendWait"
def isend(var, dest, tag):
    """Build an asynchronous MPI send of ``var``.

    Constructs an ``MPISend`` op from ``dest`` and ``tag`` and applies it to
    ``var``.

    :param var: the variable whose value is to be sent.
    :param dest: destination handed to ``MPISend`` (presumably the MPI rank
        of the receiving process -- confirm against the caller).
    :param tag: message tag handed to ``MPISend``.
    :return: the output of the ``MPISend`` op (the send request variable),
        suitable as the input of ``MPISendWait``.
    """
    send_op = MPISend(dest, tag)
    return send_op(var)
def send(var, dest, tag):
    """Build an MPI send of ``var`` followed by a wait on its request.

    Chains ``isend`` with an ``MPISendWait`` op: the request produced by
    ``isend(var, dest, tag)`` becomes the input of ``MPISendWait``.

    :param var: the variable whose value is to be sent.
    :param dest: destination forwarded to ``isend``.
    :param tag: message tag forwarded to ``isend``.
    :return: the output of the ``MPISendWait`` op.
    """
    wait_op = MPISendWait()
    pending_request = isend(var, dest, tag)
    return wait_op(pending_request)
def irecv(shape, dtype, source, tag):
    """Build an asynchronous MPI receive.

    Constructs an ``MPIRecv`` op and applies it with no inputs. Note that
    ``MPIRecv`` takes its arguments in the order
    ``(source, tag, shape, dtype)``, which differs from this function's
    signature.

    :param shape: shape of the array to receive.
    :param dtype: dtype of the array to receive.
    :param source: source handed to ``MPIRecv`` (presumably the MPI rank of
        the sending process -- confirm against the caller).
    :param tag: message tag handed to ``MPIRecv``.
    :return: the outputs of the ``MPIRecv`` op; ``recv`` unpacks them as the
        arguments of ``MPIRecvWait``.
    """
    recv_op = MPIRecv(source, tag, shape, dtype)
    return recv_op()
def recv(shape, dtype, source, tag):
    """Build an MPI receive followed by a wait on its request.

    Chains ``irecv`` with an ``MPIRecvWait`` op: every output of
    ``irecv(shape, dtype, source, tag)`` is passed, in order, as a
    positional input of ``MPIRecvWait``.

    :param shape: shape of the array to receive.
    :param dtype: dtype of the array to receive.
    :param source: source forwarded to ``irecv``.
    :param tag: message tag forwarded to ``irecv``.
    :return: the output of the ``MPIRecvWait`` op.
    """
    wait_op = MPIRecvWait()
    pending = irecv(shape, dtype, source, tag)
    return wait_op(*pending)
from mpi4py import MPI
comm = MPI.COMM_WORLD
import theano
from theano.tensor.io import send, recv
import numpy as np
from sys import stdout
# Round-trip check for the MPI send/recv ops between two processes:
# rank 0 sends ``x + 1`` to rank 1, rank 1 doubles what it receives and sends
# it back, and rank 0 verifies that the reply equals ``(x + 1) * 2``.
# NOTE(review): presumably launched with two MPI processes (e.g. via
# mpiexec) -- other ranks fall through both branches and do nothing.
rank = comm.Get_rank()
size = comm.Get_size()

# Single-argument print(...) behaves the same on Python 2 and Python 3.
print(size)
print(rank)

shape = (10, 10)
dtype = 'float32'

if rank == 0:
    x = theano.tensor.matrix('x', dtype=dtype)
    y = x + 1
    # Send y, not x: the expected value below includes the "+ 1", and y was
    # otherwise computed but never used.
    send_request = send(y, 1, 11)
    z = recv(shape, dtype, 1, 12)
    f = theano.function([x], [send_request, z])

    xx = np.random.rand(*shape).astype(dtype)
    # f was compiled with two outputs; only the second (the received array)
    # takes part in the numerical comparison.
    _, zz = f(xx)

    same = np.linalg.norm(zz - (xx + 1) * 2) < .001
    stdout.write(str(same))

if rank == 1:
    # Mirror of rank 0: receive on tag 11, double, reply on tag 12.
    y = recv(shape, dtype, 0, 11)
    z = y * 2
    send_request = send(z, 0, 12)
    f = theano.function([], send_request)
    f()
from theano.tensor.io import send, recv
import theano
def test_recv():
    """Check the graph built by ``recv``.

    The returned variable must carry the requested dtype and be
    non-broadcastable in both dimensions, and the node that produced its
    owner's first input must hold the requested ``source`` and ``tag``.
    """
    received = recv((10,10), 'float64', 0, 11)
    assert received.dtype == 'float64'
    assert received.broadcastable == (False, False)

    # Walk back from the wait node to the node that produced its first input.
    wait_node = received.owner
    recv_op = wait_node.inputs[0].owner.op
    assert recv_op.source == 0
    assert recv_op.tag == 11
def test_send():
    """Check the graph built by ``send``.

    The node that produced the first input of the returned variable's owner
    must hold the requested ``dest`` and ``tag``.
    """
    payload = theano.tensor.matrix('x')
    done = send(payload, 1, 11)

    # Walk back from the wait node to the node that produced its first input.
    wait_node = done.owner
    send_op = wait_node.inputs[0].owner.op
    assert send_op.dest == 1
    assert send_op.tag == 11
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论