提交 7a6a2712 authored 作者: Matthew Rocklin's avatar Matthew Rocklin

clean up _test_mpi_roundtrip

mpi_cmp -> mpi_cmps fix incorrect test `y = x + x` -> `y = x + 1` to match expected output remove old comments
上级 270a4a7a
# Run using # Run using
# mpiexec -np 2 python _test_mpi_roundtrip.py # mpiexec -np 2 python _test_mpi_roundtrip.py
from mpi4py import MPI from mpi4py import MPI
comm = MPI.COMM_WORLD comm = MPI.COMM_WORLD
import theano import theano
from theano.tensor.io import send, recv, mpi_cmp from theano.tensor.io import send, recv, mpi_cmps
from theano.gof.sched import sort_schedule_fn from theano.gof.sched import sort_schedule_fn
import numpy as np import numpy as np
from sys import stdout from sys import stdout
...@@ -16,15 +15,14 @@ size = comm.Get_size() ...@@ -16,15 +15,14 @@ size = comm.Get_size()
shape = (2, 2) shape = (2, 2)
dtype = 'float32' dtype = 'float32'
scheduler = sort_schedule_fn(mpi_cmp) scheduler = sort_schedule_fn(*mpi_cmps)
mode = theano.Mode(optimizer=None, mode = theano.Mode(optimizer=None,
linker=theano.OpWiseCLinker(schedule=scheduler)) linker=theano.OpWiseCLinker(schedule=scheduler))
if rank == 0: if rank == 0:
x = theano.tensor.matrix('x', dtype=dtype) x = theano.tensor.matrix('x', dtype=dtype)
y = x + x y = x + 1
send_request = send(y, 1, 11) send_request = send(y, 1, 11)
# send_request = send(x, 1, 11)
z = recv(shape, dtype, 1, 12) z = recv(shape, dtype, 1, 12)
...@@ -32,7 +30,6 @@ if rank == 0: ...@@ -32,7 +30,6 @@ if rank == 0:
xx = np.random.rand(*shape).astype(dtype) xx = np.random.rand(*shape).astype(dtype)
expected = (xx + 1) * 2 expected = (xx + 1) * 2
# expected = xx * 2
_, zz = f(xx) _, zz = f(xx)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论