提交 b4cd9a55 authored 作者: Matthew Rocklin's avatar Matthew Rocklin

added live mpi test

上级 f0e5c329
# Run using
# mpiexec -np 2 python _test_mpi_roundtrip.py
from mpi4py import MPI
comm = MPI.COMM_WORLD
import theano
......@@ -8,26 +12,29 @@ from sys import stdout
rank = comm.Get_rank()
size = comm.Get_size()
print size
print rank
shape = (10, 10)
shape = (2, 2)
dtype = 'float32'
mode = theano.Mode(optimizer=None, linker='py')
if rank == 0:
x = theano.tensor.matrix('x', dtype=dtype)
y = x + 1
# y = x + x
# send_request = send(y, 1, 11)
send_request = send(x, 1, 11)
z = recv(shape, dtype, 1, 12)
f = theano.function([x], [send_request, z])
f = theano.function([x], [send_request, z], mode=mode)
xx = np.random.rand(*shape).astype(dtype)
# expected = (xx + 1) * 2
expected = xx * 2
zz = f(xx)
_, zz = f(xx)
same = np.linalg.norm(zz - (xx+1)*2) < .001
same = np.linalg.norm(zz - expected) < .001
stdout.write(str(same))
if rank == 1:
......@@ -36,6 +43,6 @@ if rank == 1:
z = y * 2
send_request = send(z, 0, 12)
f = theano.function([], send_request)
f = theano.function([], send_request, mode=mode)
f()
......@@ -16,4 +16,3 @@ def test_send():
sendnode = y.owner.inputs[0].owner
assert sendnode.op.dest == 1
assert sendnode.op.tag == 11
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论