提交 3aaa820f authored 作者: Matthew Rocklin's avatar Matthew Rocklin

break up mpi_key into two keys

上级 616299c3
......@@ -170,3 +170,8 @@ def sort_schedule_fn(*cmps):
""" Order nodes in a FunctionGraph """
return sort_apply_nodes(fgraph.inputs, fgraph.outputs, cmps)
return schedule
def key_to_cmp(key):
def key_cmp(a, b):
return cmp(key(a), key(b))
return key_cmp
import numpy
from theano import gof
from theano.gof import Constant, Generic, Op
from theano.gof.sched import key_to_cmp
from basic import tensor
##########################
# Disk Access
......@@ -254,21 +255,21 @@ def irecv(shape, dtype, source, tag):
def recv(shape, dtype, source, tag):
return MPIRecvWait(tag)(*irecv(shape, dtype, source, tag))
def mpi_key(a):
# Wait as long as possible on Waits
# Ordering keys for scheduling
def mpi_send_wait_key(a):
""" Wait as long as possible on Waits, Start Send/Recvs early """
if isinstance(a.op, (MPIRecvWait, MPISendWait)):
return ( 1, a.op.tag)
# Start async communication as soon as possible
return 1
if isinstance(a.op, (MPIRecv, MPISend)):
return (-1, a.op.tag)
# Break ties by the tag. Earlier tags first.
# Everything else is normal
return (0,)
def mpi_cmp(a, b):
"""
A comparator function to optimize MPI computation/communicaiton overlap
"""
return cmp(mpi_key(a), mpi_key(b))
return -1
return 0
def mpi_tag_key(a):
""" Break MPI ties by using the variable tag - prefer lower tags first """
if isinstance(a.op, (MPISend, MPIRecv, MPIRecvWait, MPISendWait)):
return a.op.tag
else:
return 0
mpi_cmp_keys = (mpi_send_wait_key, mpi_tag_key)
mpi_cmps = map(key_to_cmp, mpi_cmp_keys)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论