提交 d81b1138 authored 作者: Matthew Rocklin's avatar Matthew Rocklin

add tags to MPIWait ops to break ties

上级 76cd9b63
......@@ -146,11 +146,11 @@ class MPIRecvWait(Op):
@note: Non-differentiable.
"""
def __init__(self):
pass
def __init__(self, tag):
self.tag = tag
def __eq__(self, other):
return type(self) == type(other)
return type(self) == type(other) and self.tag == other.tag
def __hash__(self):
return hash(type(self))
......@@ -223,11 +223,11 @@ class MPISendWait(Op):
@note: Non-differentiable.
"""
def __init__(self):
pass
def __init__(self, tag):
self.tag = tag
def __eq__(self, other):
return type(self) == type(other)
return type(self) == type(other) and self.tag == other.tag
def __hash__(self):
return hash(type(self))
......@@ -247,21 +247,23 @@ class MPISendWait(Op):
def isend(var, dest, tag):
return MPISend(dest, tag)(var)
def send(var, dest, tag):
return MPISendWait()(isend(var, dest, tag))
return MPISendWait(tag)(isend(var, dest, tag))
def irecv(shape, dtype, source, tag):
return MPIRecv(source, tag, shape, dtype)()
def recv(shape, dtype, source, tag):
return MPIRecvWait()(*irecv(shape, dtype, source, tag))
return MPIRecvWait(tag)(*irecv(shape, dtype, source, tag))
def mpi_key(a):
# Wait as long as possible on Waits
if isinstance(a.op, (MPIRecvWait, MPISendWait)):
return (1,)
return ( 1, a.op.tag)
# Start async communication as soon as possible
# Break ties by the variable tag
if isinstance(a.op, (MPIRecv, MPISend)):
return (-1, a.op.tag)
# Break ties by the tag. Earlier tags first.
# Everything else is normal
return (0,)
......
......@@ -2,6 +2,10 @@ from theano.tensor.io import send, recv, mpi_cmp, MPISend, MPISendWait
import theano
import subprocess
import os
from theano.gof.graph import sort_schedule_fn
mpi_scheduler = sort_schedule_fn(mpi_cmp)
mpi_linker = theano.OpWiseCLinker(schedule=mpi_scheduler)
mpi_mode = theano.Mode(linker=mpi_linker)
def test_recv():
x = recv((10,10), 'float64', 0, 11)
......@@ -45,12 +49,17 @@ def test_mpi_cmp():
assert mpi_cmp(sendnode, addnode) < 0 # send happens first
assert mpi_cmp(waitnode, addnode) > 0 # wait happens last
def test_mpi_schedule():
from theano.gof.graph import sort_schedule_fn
scheduler = sort_schedule_fn(mpi_cmp)
linker = theano.OpWiseCLinker(schedule=scheduler)
mode = theano.Mode(linker=linker)
def test_mpi_tag_ordering():
x = recv((2,2), 'float32', 1, 12)
y = recv((2,2), 'float32', 1, 11)
z = recv((2,2), 'float32', 1, 13)
f = theano.function([], [x,y,z], mode=mpi_mode)
nodes = f.maker.linker.make_all()[-1]
assert all(node.op.tag == tag
for node, tag in zip(nodes, (11,12,13,11,12,13)))
def test_mpi_schedule():
x = theano.tensor.matrix('x')
y = send(x, 1, 11)
z = x + x
......@@ -58,7 +67,7 @@ def test_mpi_schedule():
sendnode = y.owner.inputs[0].owner
addnode = z.owner
f = theano.function([x], [y, z], mode=mode)
f = theano.function([x], [y, z], mode=mpi_mode)
nodes = f.maker.linker.make_all()[-1]
optypes = [MPISend, theano.tensor.Elemwise, MPISendWait]
assert all(isinstance(node.op, optype)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论