提交 63b17b09 authored 作者: Frederic Bastien's avatar Frederic Bastien

'use a new way to init pycuda to work correctly with newer version of Theano and Pycuda.'

上级 f608b307
......@@ -22,13 +22,15 @@ from theano.sandbox.cuda import GpuElemwise, CudaNdarrayType
from theano.sandbox.cuda.basic_ops import as_cuda_ndarray_variable, gpu_contiguous
from theano.sandbox.cuda.opt import gpu_seqopt
import pycuda_init
if not pycuda_init.pycuda_available:
raise Exception("No pycuda available. You can't load pycuda_example.py")
import pycuda
from pycuda.elementwise import ElementwiseKernel
from pycuda.compiler import SourceModule
from pycuda.gpuarray import splay
from pycuda.tools import VectorArg
import pycuda.autoinit
def theano_parse_c_arg(c_arg):
c_arg = c_arg.replace('npy_float32','float')
c_arg = c_arg.replace('npy_float64','double')
......
import os
import theano
def select_gpu_from_theano():
# Transfer the theano gpu binding to pycuda, for consistency
theano_to_pycuda_device_map = {"cpu": "0",
"gpu0": "0",
"gpu1": "1",
"gpu2": "2",
"gpu3": "3"}
os.environ["CUDA_DEVICE"] = theano_to_pycuda_device_map.get(theano.config.device, "0")
select_gpu_from_theano()
pycuda_available = False
try:
import pycuda
import pycuda.autoinit
pycuda_available = True
except ImportError:
pass
import numpy
try:
import pycuda
except ImportError:
import theano
import theano.misc.pycuda_init
if not theano.misc.pycuda_init.pycuda_available:
from nose.plugins.skip import SkipTest
raise SkipTest("Pycuda not installed. Skip test of theano op with pycuda code.")
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论