提交 270a4a7a authored 作者: Matthew Rocklin's avatar Matthew Rocklin

rename mpi_cmp to many cmps

上级 12586a61
......@@ -271,5 +271,8 @@ def mpi_tag_key(a):
else:
return 0
mpi_cmp_keys = (mpi_send_wait_key, mpi_tag_key)
mpi_cmps = map(key_to_cmp, mpi_cmp_keys)
mpi_send_wait_cmp = key_to_cmp(mpi_send_wait_key)
mpi_tag_cmp = key_to_cmp(mpi_tag_key)
mpi_keys = (mpi_send_wait_key, mpi_tag_key)
mpi_cmps = (mpi_send_wait_cmp, mpi_tag_cmp)
from theano.tensor.io import send, recv, mpi_cmps, MPISend, MPISendWait
from theano.tensor.io import (send, recv, mpi_cmps, MPISend, MPISendWait,
mpi_send_wait_cmp, mpi_tag_cmp)
import theano
import subprocess
import os
......@@ -39,15 +40,15 @@ def test_mpi_roundtrip():
"theano/tensor/tests/_test_mpi_roundtrip.py").read()
assert result == "True"
def test_mpi_cmp():
def test_mpi_send_wait_cmp():
x = theano.tensor.matrix('x')
y = send(x, 1, 11)
z = x + x
waitnode = y.owner
sendnode = y.owner.inputs[0].owner
addnode = z.owner
assert mpi_cmp(sendnode, addnode) < 0 # send happens first
assert mpi_cmp(waitnode, addnode) > 0 # wait happens last
assert mpi_send_wait_cmp(sendnode, addnode) < 0 # send happens first
assert mpi_send_wait_cmp(waitnode, addnode) > 0 # wait happens last
def test_mpi_tag_ordering():
x = recv((2,2), 'float32', 1, 12)
......@@ -72,4 +73,3 @@ def test_mpi_schedule():
optypes = [MPISend, theano.tensor.Elemwise, MPISendWait]
assert all(isinstance(node.op, optype)
for node, optype in zip(nodes, optypes))
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论