提交 3acd26ae authored 作者: Frederic Bastien's avatar Frederic Bastien

Make a gpu test move computation to the GPU.

上级 6f0addc8
...@@ -13,7 +13,7 @@ from .config import mode_with_gpu, mode_without_gpu, test_ctx_name ...@@ -13,7 +13,7 @@ from .config import mode_with_gpu, mode_without_gpu, test_ctx_name
from .test_basic_ops import rand_gpuarray from .test_basic_ops import rand_gpuarray
from ..elemwise import (GpuElemwise, GpuDimShuffle, from ..elemwise import (GpuElemwise, GpuDimShuffle,
GpuCAReduceCuda, GpuCAReduceCPY, GpuErfinv, GpuErfcinv) GpuCAReduceCuda, GpuCAReduceCPY, GpuErfinv, GpuErfcinv)
from ..type import GpuArrayType, get_context from ..type import GpuArrayType, get_context, gpuarray_shared_constructor
from pygpu import ndgpuarray as gpuarray from pygpu import ndgpuarray as gpuarray
...@@ -40,16 +40,21 @@ def test_elemwise_pow(): ...@@ -40,16 +40,21 @@ def test_elemwise_pow():
for dtype_exp in dtypes: for dtype_exp in dtypes:
# Compile a gpu function with the specified dtypes # Compile a gpu function with the specified dtypes
base = theano.tensor.vector(dtype=dtype_base)
exp = theano.tensor.vector(dtype=dtype_exp)
output = base ** exp
f = theano.function([base, exp], output)
base_val = np.random.randint(0, 5, size=10).astype(dtype_base) base_val = np.random.randint(0, 5, size=10).astype(dtype_base)
exp_val = np.random.randint(0, 3, size=10).astype(dtype_exp) exp_val = np.random.randint(0, 3, size=10).astype(dtype_exp)
base = theano.tensor.vector(dtype=dtype_base)
exp = gpuarray_shared_constructor(exp_val)
output = base ** exp
f = theano.function([base], output, mode=mode_with_gpu)
theano.printing.debugprint(f)
# We don't transfer to the GPU when the output dtype is int*
n = len([n for n in f.maker.fgraph.apply_nodes
if isinstance(n.op, GpuElemwise)])
assert n == int("float" in output.dtype)
# Call the function to make sure the output is valid # Call the function to make sure the output is valid
out = f(base_val, exp_val) out = f(base_val)
expected_out = base_val ** exp_val expected_out = base_val ** exp_val
assert_allclose(out, expected_out) assert_allclose(out, expected_out)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论