Commit 2bd44678 authored by Arnaud Bergeron

Add a contexts config variable to specify the context map.

上级 8c9b612b
...@@ -112,7 +112,8 @@ if config.device.startswith('gpu') or config.init_gpu_device.startswith('gpu'): ...@@ -112,7 +112,8 @@ if config.device.startswith('gpu') or config.init_gpu_device.startswith('gpu'):
if (config.device.startswith('cuda') or if (config.device.startswith('cuda') or
config.device.startswith('opencl') or config.device.startswith('opencl') or
config.init_gpu_device.startswith('cuda') or config.init_gpu_device.startswith('cuda') or
config.init_gpu_device.startswith('opencl')): config.init_gpu_device.startswith('opencl') or
config.contexts != ''):
import theano.sandbox.gpuarray import theano.sandbox.gpuarray
# Use config.numpy to call numpy.seterr # Use config.numpy to call numpy.seterr
......
...@@ -111,6 +111,29 @@ AddConfigVar( ...@@ -111,6 +111,29 @@ AddConfigVar(
BoolParam(False, allow_override=False), BoolParam(False, allow_override=False),
in_c_key=False) in_c_key=False)
def _validate_contexts(val):
    """Validate a context-map string and return it unchanged.

    The accepted format is the empty string (no context map) or a
    semicolon-separated list of ``name->device`` pairs, e.g.
    ``"test->cuda0;test2->opencl0:0"``.

    :param val: candidate value for the ``contexts`` config flag.
    :returns: ``val`` unchanged when well-formed.
    :raises ValueError: if any ';'-separated entry does not contain
        exactly one ``->`` separator.
    """
    if val == '':
        return val
    for entry in val.split(';'):
        # Each entry must split into exactly (name, device).
        if len(entry.split('->')) != 2:
            raise ValueError("Malformed context map: %s" % (entry,))
    return val


class ContextsParam(ConfigParam):
    """Config parameter holding the multi-gpu context map.

    Wraps :func:`_validate_contexts` as the value filter so malformed
    maps are rejected as soon as the flag is assigned.
    """

    def __init__(self):
        # '' default, validated by _validate_contexts; the third
        # positional argument is presumably allow_override=False
        # (per ConfigParam's signature) -- TODO confirm.
        ConfigParam.__init__(self, '', _validate_contexts, False)
# Register the 'contexts' Theano config flag. Its value is validated by
# ContextsParam (defined above) and is excluded from the C cache key
# because the context map does not affect generated C code.
AddConfigVar(
    'contexts',
    """
Context map for multi-gpu operation. Format is a
semicolon-separated list of names and device names in the
'name->dev_name' format. An example that would map name 'test' to
device 'cuda0' and name 'test2' to device 'opencl0:0' follows:
"test->cuda0;test2->opencl0:0".
""", ContextsParam(), in_c_key=False)
AddConfigVar( AddConfigVar(
'print_active_device', 'print_active_device',
"Print active device at when the GPU device is initialized.", "Print active device at when the GPU device is initialized.",
......
...@@ -51,6 +51,12 @@ if pygpu: ...@@ -51,6 +51,12 @@ if pygpu:
elif (config.init_gpu_device.startswith('cuda') or elif (config.init_gpu_device.startswith('cuda') or
config.init_gpu_device.startswith('opencl')): config.init_gpu_device.startswith('opencl')):
init_dev(config.init_gpu_device) init_dev(config.init_gpu_device)
# Initialize every additional GPU context requested via the
# config.contexts map ("name->device;name2->device2;...").  The format
# was already validated when the config value was set, so splitting on
# ';' and '->' here is safe.
if config.contexts != '':
    for n, d in (c.split('->') for c in config.contexts.split(';')):
        init_dev(d, n)

import theano.compile
# Register the gpuarray shared-variable constructor so that
# theano.shared() can create GPU-backed variables, and enable the
# gpuarray optimizations in the fast_run/fast_compile modes.
theano.compile.shared_constructor(gpuarray_shared_constructor)
optdb.add_tags('gpuarray_opt', 'fast_run', 'fast_compile')
from .basic_ops import (GpuAlloc, GpuContiguous, GpuEye, GpuFromHost, from .basic_ops import (GpuAlloc, GpuContiguous, GpuEye, GpuFromHost,
GpuJoin, GpuReshape, GpuSplit, HostFromGpu) GpuJoin, GpuReshape, GpuSplit, HostFromGpu)
...@@ -65,5 +71,6 @@ else: ...@@ -65,5 +71,6 @@ else:
if (config.init_gpu_device.startswith('cuda') or if (config.init_gpu_device.startswith('cuda') or
config.init_gpu_device.startswith('opencl') or config.init_gpu_device.startswith('opencl') or
config.device.startswith('opencl') or config.device.startswith('opencl') or
config.device.startswith('cuda')): config.device.startswith('cuda') or
config.contexts != ''):
error("pygpu was configured but could not be imported", exc_info=True) error("pygpu was configured but could not be imported", exc_info=True)
...@@ -170,8 +170,8 @@ class InputToGpuOptimizer(Optimizer): ...@@ -170,8 +170,8 @@ class InputToGpuOptimizer(Optimizer):
continue continue
try: try:
ctx = getattr(input.tag, 'context_name', None) ctx_name = getattr(input.tag, 'context_name', None)
new_input = host_from_gpu(GpuFromHost(ctx)(input)) new_input = host_from_gpu(GpuFromHost(ctx_name)(input))
fgraph.replace_validate(input, new_input, fgraph.replace_validate(input, new_input,
"InputToGpuOptimizer") "InputToGpuOptimizer")
except TypeError: except TypeError:
...@@ -180,7 +180,8 @@ class InputToGpuOptimizer(Optimizer): ...@@ -180,7 +180,8 @@ class InputToGpuOptimizer(Optimizer):
except ValueError: except ValueError:
# If there is no context tag and no default context # If there is no context tag and no default context
# then it stays on the CPU # then it stays on the CPU
assert ctx is None if ctx is not None:
raise
pass pass
......
...@@ -469,6 +469,15 @@ def gpuarray_shared_constructor(value, name=None, strict=False, ...@@ -469,6 +469,15 @@ def gpuarray_shared_constructor(value, name=None, strict=False,
if not isinstance(value, (numpy.ndarray, pygpu.gpuarray.GpuArray)): if not isinstance(value, (numpy.ndarray, pygpu.gpuarray.GpuArray)):
raise TypeError('ndarray or GpuArray required') raise TypeError('ndarray or GpuArray required')
# Verify that the requested context exists before building the shared
# variable.  get_context raises ValueError for an unknown context name.
try:
    get_context(context_name)
except ValueError:
    # context_name is None means the caller relied on a default
    # context that was never registered: report that case as a
    # TypeError with an explicit message rather than a bare ValueError.
    if context_name is None:
        raise TypeError('No default context and no context specified')
    # An explicit but unknown context name is a hard error: propagate
    # the original ValueError.
    raise
if broadcastable is None: if broadcastable is None:
broadcastable = (False,) * value.ndim broadcastable = (False,) * value.ndim
type = GpuArrayType(value.dtype, broadcastable) type = GpuArrayType(value.dtype, broadcastable)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论