提交 a4fedfed authored 作者: Arnaud Bergeron's avatar Arnaud Bergeron

Add checks for init_gpu_device.

上级 4a4e9710
...@@ -36,12 +36,14 @@ def init_dev(dev, name=None): ...@@ -36,12 +36,14 @@ def init_dev(dev, name=None):
if dev not in init_dev.devmap: if dev not in init_dev.devmap:
init_dev.devmap[dev] = pygpu.init(dev) init_dev.devmap[dev] = pygpu.init(dev)
context = init_dev.devmap[dev] context = init_dev.devmap[dev]
# This will map the context name to the real context object.
reg_context(name, context) reg_context(name, context)
pygpu_activated = True pygpu_activated = True
if config.print_active_device: if config.print_active_device:
print("Mapped name %s to device %s: %s" % (name, dev, context.devname), print("Mapped name %s to device %s: %s" % (name, dev, context.devname),
file=sys.stderr) file=sys.stderr)
# This maps things like 'cuda0' to the context object on that device.
init_dev.devmap = {} init_dev.devmap = {}
if pygpu: if pygpu:
...@@ -54,10 +56,14 @@ if pygpu: ...@@ -54,10 +56,14 @@ if pygpu:
optdb.add_tags('gpuarray_opt', 'fast_run', 'fast_compile') optdb.add_tags('gpuarray_opt', 'fast_run', 'fast_compile')
elif (config.init_gpu_device.startswith('cuda') or elif (config.init_gpu_device.startswith('cuda') or
config.init_gpu_device.startswith('opencl')): config.init_gpu_device.startswith('opencl')):
if config.device != 'gpu':
raise ValueError('you must set device=gpu to use init_gpu_device.')
if config.contexts != '':
print("Using contexts will make init_gpu_device act like device and move all computations by default, which might not be what you want.")
init_dev(config.init_gpu_device) init_dev(config.init_gpu_device)
if config.contexts != '': if config.contexts != '':
for n, d in (c.split('->') for c in config.contexts.split(';')): for n, d in (c.split('->') for c in config.contexts.split(';')):
init_dev(d, n) init_dev(d.strip(), n.strip())
import theano.compile import theano.compile
theano.compile.shared_constructor(gpuarray_shared_constructor) theano.compile.shared_constructor(gpuarray_shared_constructor)
optdb.add_tags('gpuarray_opt', 'fast_run', 'fast_compile') optdb.add_tags('gpuarray_opt', 'fast_run', 'fast_compile')
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论