提交 c0d3bff3 authored 作者: Colin Raffel's avatar Colin Raffel

Adding GpuSplit test

上级 e30ffc3c
...@@ -291,6 +291,27 @@ def test_local_gpu_subtensor(): ...@@ -291,6 +291,27 @@ def test_local_gpu_subtensor():
assert any([isinstance(node.op, cuda.GpuElemwise) for node in topo]) assert any([isinstance(node.op, cuda.GpuElemwise) for node in topo])
def test_local_split():
    """Check that the GPU optimizer rewrites tensor.Split into GpuSplit
    and that CPU and GPU backends produce identical split results."""
    # Symbolic three-way split of a vector along axis 0.
    x = tensor.vector()
    splits = tensor.lvector()
    ra, rb, rc = tensor.split(x, splits, n_splits=3, axis=0)

    # CPU compilation: the graph must still contain a tensor.Split node.
    f = theano.function([x, splits], [ra, rb, rc], mode=mode_without_gpu)
    cpu_res = f([0, 1, 2, 3, 4, 5], [3, 2, 1])
    assert any(isinstance(node.op, theano.tensor.Split)
               for node in f.maker.fgraph.toposort())

    # GPU compilation: the optimizer must have introduced a GpuSplit node.
    f = theano.function([x, splits], [ra, rb, rc], mode=mode_with_gpu)
    gpu_res = f([0, 1, 2, 3, 4, 5], [3, 2, 1])
    assert any(isinstance(node.op, theano.sandbox.cuda.GpuSplit)
               for node in f.maker.fgraph.toposort())

    # Both backends must agree element-wise on every output piece.
    for cpu_piece, gpu_piece in zip(cpu_res, gpu_res):
        assert (cpu_piece == gpu_piece).all()
def test_print_op(): def test_print_op():
""" Test that print ops don't block gpu optimization""" """ Test that print ops don't block gpu optimization"""
b = tensor.fmatrix() b = tensor.fmatrix()
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论