提交 6dc21df7 authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Reorganize use()

上级 dae6a1ab
......@@ -13,21 +13,30 @@ import theano.compile.sandbox
import logging, os
def use():
def use(device=None):
if use.device_number is None:
# No successful call to use() has been made yet
if device is None:
device = int(os.getenv("THEANO_GPU",0))
try:
cuda_ndarray.gpu_init(device)
handle_shared_float32(True)
use.device_number = device
except RuntimeError, e:
logging.getLogger('theano_cuda_ndarray').warning("WARNING: Won't use the GPU as the initialisation of device %i failed. %s" %(device, e))
raise
elif use.device_number != device:
logging.getLogger('theano_cuda_ndarray').warning("WARNING: ignoring call to use(%s), GPU number %i is already in use." %(str(device), use.device_number))
use.device_number = None
def handle_shared_float32(tf):
"""Set the CudaNdarrayType as the default handler for shared float32 arrays
"""Set the CudaNdarrayType as the default handler for shared float32 arrays.
Use use(tf) instead as this is a bad name.
This function is intended to be called from use(gpu_index), not directly.
"""
if tf:
try:
v=os.getenv("THEANO_GPU",0)
cuda_ndarray.gpu_init(int(v))
theano.compile.sandbox.shared_constructor(shared_constructor)
except RuntimeError, e:
logging.getLogger('theano_cuda_ndarray').warning("WARNING: Won't use the GPU as the initialisation failed."+str(e))
else:
raise NotImplementedError('removing our handler')
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论