提交 bd7a1d18 authored 作者: James Bergstra's avatar James Bergstra

using new config variable in sandbox/cuda

上级 e41e8502
import os, sys import os, sys
from theano.gof.compiledir import get_compiledir from theano.gof.compiledir import get_compiledir
from theano.compile import optdb from theano.compile import optdb
import theano.config as config from theano import config
import logging, copy import logging, copy
_logger_name = 'theano_cuda_ndarray' _logger_name = 'theano_cuda_ndarray'
...@@ -109,7 +109,13 @@ if enable_cuda: ...@@ -109,7 +109,13 @@ if enable_cuda:
import cuda_ndarray import cuda_ndarray
def use(device=config.THEANO_GPU): def use(device=config.device):
if device.startswith('gpu'):
device = int(device[3:])
elif device == 'cpu':
device = -1
else:
raise ValueError("Invalid device identifier", device)
if use.device_number is None: if use.device_number is None:
# No successful call to use() has been made yet # No successful call to use() has been made yet
if device=="-1" or device=="CPU": if device=="-1" or device=="CPU":
...@@ -144,5 +150,5 @@ def handle_shared_float32(tf): ...@@ -144,5 +150,5 @@ def handle_shared_float32(tf):
raise NotImplementedError('removing our handler') raise NotImplementedError('removing our handler')
if enable_cuda and config.THEANO_GPU not in [None, ""]: if enable_cuda and config.device.startswith('gpu'):
use() use()
import sys, os, subprocess, logging import sys, os, subprocess, logging
from theano.gof.cmodule import (std_libs, std_lib_dirs, std_include_dirs, dlimport, from theano.gof.cmodule import (std_libs, std_lib_dirs, std_include_dirs, dlimport,
get_lib_extension) get_lib_extension)
import theano.config as config from theano import config
_logger=logging.getLogger("theano_cuda_ndarray.nvcc_compiler") _logger=logging.getLogger("theano_cuda_ndarray.nvcc_compiler")
_logger.setLevel(logging.WARN) _logger.setLevel(logging.WARN)
...@@ -46,7 +46,7 @@ def nvcc_module_compile_str(module_name, src_code, location=None, include_dirs=[ ...@@ -46,7 +46,7 @@ def nvcc_module_compile_str(module_name, src_code, location=None, include_dirs=[
else: preargs = list(preargs) else: preargs = list(preargs)
preargs.append('-fPIC') preargs.append('-fPIC')
no_opt = False no_opt = False
cuda_root = config.CUDA_ROOT cuda_root = config.cuda.root
include_dirs = std_include_dirs() + include_dirs include_dirs = std_include_dirs() + include_dirs
libs = std_libs() + ['cudart'] + libs libs = std_libs() + ['cudart'] + libs
lib_dirs = std_lib_dirs() + lib_dirs lib_dirs = std_lib_dirs() + lib_dirs
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论