提交 3aaa820f authored 作者: Matthew Rocklin's avatar Matthew Rocklin

break up mpi_key into two keys

上级 616299c3
...@@ -170,3 +170,8 @@ def sort_schedule_fn(*cmps): ...@@ -170,3 +170,8 @@ def sort_schedule_fn(*cmps):
""" Order nodes in a FunctionGraph """ """ Order nodes in a FunctionGraph """
return sort_apply_nodes(fgraph.inputs, fgraph.outputs, cmps) return sort_apply_nodes(fgraph.inputs, fgraph.outputs, cmps)
return schedule return schedule
def key_to_cmp(key):
def key_cmp(a, b):
return cmp(key(a), key(b))
return key_cmp
import numpy import numpy
from theano import gof from theano import gof
from theano.gof import Constant, Generic, Op from theano.gof import Constant, Generic, Op
from theano.gof.sched import key_to_cmp
from basic import tensor from basic import tensor
########################## ##########################
# Disk Access # Disk Access
...@@ -254,21 +255,21 @@ def irecv(shape, dtype, source, tag): ...@@ -254,21 +255,21 @@ def irecv(shape, dtype, source, tag):
def recv(shape, dtype, source, tag): def recv(shape, dtype, source, tag):
return MPIRecvWait(tag)(*irecv(shape, dtype, source, tag)) return MPIRecvWait(tag)(*irecv(shape, dtype, source, tag))
def mpi_key(a): # Ordering keys for scheduling
# Wait as long as possible on Waits def mpi_send_wait_key(a):
""" Wait as long as possible on Waits, Start Send/Recvs early """
if isinstance(a.op, (MPIRecvWait, MPISendWait)): if isinstance(a.op, (MPIRecvWait, MPISendWait)):
return ( 1, a.op.tag) return 1
# Start async communication as soon as possible
if isinstance(a.op, (MPIRecv, MPISend)): if isinstance(a.op, (MPIRecv, MPISend)):
return (-1, a.op.tag) return -1
return 0
# Break ties by the tag. Earlier tags first.
def mpi_tag_key(a):
# Everything else is normal """ Break MPI ties by using the variable tag - prefer lower tags first """
return (0,) if isinstance(a.op, (MPISend, MPIRecv, MPIRecvWait, MPISendWait)):
return a.op.tag
def mpi_cmp(a, b): else:
""" return 0
A comparator function to optimize MPI computation/communicaiton overlap
""" mpi_cmp_keys = (mpi_send_wait_key, mpi_tag_key)
return cmp(mpi_key(a), mpi_key(b)) mpi_cmps = map(key_to_cmp, mpi_cmp_keys)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论