提交 cb167f3b authored 作者: Matthew Rocklin's avatar Matthew Rocklin

add mpi comparator function

上级 a0748414
...@@ -251,3 +251,20 @@ def irecv(shape, dtype, source, tag): ...@@ -251,3 +251,20 @@ def irecv(shape, dtype, source, tag):
return MPIRecv(source, tag, shape, dtype)() return MPIRecv(source, tag, shape, dtype)()
def recv(shape, dtype, source, tag): def recv(shape, dtype, source, tag):
return MPIRecvWait()(*irecv(shape, dtype, source, tag)) return MPIRecvWait()(*irecv(shape, dtype, source, tag))
def mpi_key(a):
# Wait as long as possible on Waits
if isinstance(a.op, (MPIRecvWait, MPISendWait)):
return (1,)
# Start async communication as soon as possible
# Break ties by the variable tag
if isinstance(a.op, (MPIRecv, MPISend)):
return (-1, a.op.tag)
# Everything else is normal
return (0,)
def mpi_cmp(a, b):
"""
A comparator function to optimize MPI computation/communicaiton overlap
"""
return cmp(mpi_key(a), mpi_key(b))
from theano.tensor.io import send, recv from theano.tensor.io import send, recv, mpi_cmp, MPISend, MPISendWait
import theano import theano
import subprocess import subprocess
import os import os
...@@ -34,3 +34,33 @@ def test_mpi_roundtrip(): ...@@ -34,3 +34,33 @@ def test_mpi_roundtrip():
result = os.popen("mpiexec -np 2 python " result = os.popen("mpiexec -np 2 python "
"theano/tensor/tests/_test_mpi_roundtrip.py").read() "theano/tensor/tests/_test_mpi_roundtrip.py").read()
assert result == "True" assert result == "True"
def test_mpi_cmp():
x = theano.tensor.matrix('x')
y = send(x, 1, 11)
z = x + x
waitnode = y.owner
sendnode = y.owner.inputs[0].owner
addnode = z.owner
assert mpi_cmp(sendnode, addnode) < 0 # send happens first
assert mpi_cmp(waitnode, addnode) > 0 # wait happens last
def test_mpi_schedule():
from theano.gof.graph import sort_schedule_fn
scheduler = sort_schedule_fn(mpi_cmp)
linker = theano.OpWiseCLinker(schedule=scheduler)
mode = theano.Mode(linker=linker)
x = theano.tensor.matrix('x')
y = send(x, 1, 11)
z = x + x
waitnode = y.owner
sendnode = y.owner.inputs[0].owner
addnode = z.owner
f = theano.function([x], [y, z], mode=mode)
nodes = f.maker.linker.make_all()[-1]
optypes = [MPISend, theano.tensor.Elemwise, MPISendWait]
assert all(isinstance(node.op, optype)
for node, optype in zip(nodes, optypes))
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论