提交 7f1c7cd8 authored 作者: Arnaud Bergeron's avatar Arnaud Bergeron

Add a target argument to the shared variable constructors to give explicit

control over where a shared variable is allocated. This allows allocating shared variables on the CPU while the default device is a GPU, or choosing which GPU to allocate on in a multi-GPU setup.
上级 5bc35bc2
...@@ -159,7 +159,8 @@ CudaNdarrayType.SharedVariable = CudaNdarraySharedVariable ...@@ -159,7 +159,8 @@ CudaNdarrayType.SharedVariable = CudaNdarraySharedVariable
def cuda_shared_constructor(value, name=None, strict=False, def cuda_shared_constructor(value, name=None, strict=False,
allow_downcast=None, borrow=False, broadcastable=None): allow_downcast=None, borrow=False,
broadcastable=None):
""" """
SharedVariable Constructor for CudaNdarrayType. SharedVariable Constructor for CudaNdarrayType.
...@@ -193,12 +194,15 @@ def cuda_shared_constructor(value, name=None, strict=False, ...@@ -193,12 +194,15 @@ def cuda_shared_constructor(value, name=None, strict=False,
def float32_shared_constructor(value, name=None, strict=False, def float32_shared_constructor(value, name=None, strict=False,
allow_downcast=None, borrow=False, broadcastable=None): allow_downcast=None, borrow=False,
broadcastable=None, target='gpu'):
""" """
SharedVariable Constructor for CudaNdarrayType from numpy.ndarray or SharedVariable Constructor for CudaNdarrayType from numpy.ndarray or
CudaNdarray. CudaNdarray.
""" """
if target != 'gpu':
raise TypeError('not for gpu')
if theano.sandbox.cuda.use.device_number is None: if theano.sandbox.cuda.use.device_number is None:
theano.sandbox.cuda.use("gpu", theano.sandbox.cuda.use("gpu",
force=True, force=True,
......
...@@ -24,8 +24,8 @@ except ImportError: ...@@ -24,8 +24,8 @@ except ImportError:
# This is for documentation not to depend on the availability of pygpu # This is for documentation not to depend on the availability of pygpu
from .type import (GpuArrayType, GpuArrayVariable, GpuArrayConstant, from .type import (GpuArrayType, GpuArrayVariable, GpuArrayConstant,
GpuArraySharedVariable, gpuarray_shared_constructor, GpuArraySharedVariable, gpuarray_shared_constructor,
reg_context) reg_context, get_context, ContextNotDefined)
from .basic import as_gpuarray_variable from .basic_ops import as_gpuarray_variable
from . import opt, nerv from . import opt, nerv
def transfer(x, target): def transfer(x, target):
......
...@@ -472,27 +472,29 @@ GpuArrayType.SharedVariable = GpuArraySharedVariable ...@@ -472,27 +472,29 @@ GpuArrayType.SharedVariable = GpuArraySharedVariable
def gpuarray_shared_constructor(value, name=None, strict=False, def gpuarray_shared_constructor(value, name=None, strict=False,
allow_downcast=None, borrow=False, allow_downcast=None, borrow=False,
broadcastable=None, broadcastable=None, target=None):
context_name=None):
""" """
SharedVariable constructor for GpuArrayType. SharedVariable constructor for GpuArrayType.
""" """
if target == 'gpu' or target == 'cpu':
raise TypeError('not for me')
if not isinstance(value, (numpy.ndarray, pygpu.gpuarray.GpuArray)): if not isinstance(value, (numpy.ndarray, pygpu.gpuarray.GpuArray)):
raise TypeError('ndarray or GpuArray required') raise TypeError('ndarray or GpuArray required')
try: try:
get_context(context_name) get_context(target)
except ContextNotDefined: except ContextNotDefined:
# Don't make this a hard error if we attempt to make a shared # Don't make this a hard error if we attempt to make a shared
# variable while there is no default context. # variable while there is no default context.
if context_name is None: if target is None:
raise TypeError('No default context and no context specified') raise TypeError('No default context and no context specified')
raise raise
if broadcastable is None: if broadcastable is None:
broadcastable = (False,) * value.ndim broadcastable = (False,) * value.ndim
type = GpuArrayType(value.dtype, broadcastable, context_name=context_name) type = GpuArrayType(value.dtype, broadcastable, context_name=target)
deviceval = pygpu.gpuarray.array(value, copy=(not borrow), deviceval = pygpu.gpuarray.array(value, copy=(not borrow),
context=type.context) context=type.context)
return GpuArraySharedVariable(type=type, value=deviceval, name=name, return GpuArraySharedVariable(type=type, value=deviceval, name=name,
......
...@@ -24,7 +24,7 @@ class TensorSharedVariable(_tensor_py_operators, SharedVariable): ...@@ -24,7 +24,7 @@ class TensorSharedVariable(_tensor_py_operators, SharedVariable):
@shared_constructor @shared_constructor
def tensor_constructor(value, name=None, strict=False, allow_downcast=None, def tensor_constructor(value, name=None, strict=False, allow_downcast=None,
borrow=False, broadcastable=None): borrow=False, broadcastable=None, target='cpu'):
""" """
SharedVariable Constructor for TensorType. SharedVariable Constructor for TensorType.
...@@ -36,6 +36,9 @@ def tensor_constructor(value, name=None, strict=False, allow_downcast=None, ...@@ -36,6 +36,9 @@ def tensor_constructor(value, name=None, strict=False, allow_downcast=None,
The optional `broadcastable` argument will override this default. The optional `broadcastable` argument will override this default.
""" """
if target != 'cpu':
raise TypeError('not for cpu')
if not isinstance(value, numpy.ndarray): if not isinstance(value, numpy.ndarray):
raise TypeError() raise TypeError()
...@@ -65,7 +68,7 @@ class ScalarSharedVariable(_tensor_py_operators, SharedVariable): ...@@ -65,7 +68,7 @@ class ScalarSharedVariable(_tensor_py_operators, SharedVariable):
@shared_constructor @shared_constructor
def scalar_constructor(value, name=None, strict=False, allow_downcast=None, def scalar_constructor(value, name=None, strict=False, allow_downcast=None,
borrow=False): borrow=False, target='cpu'):
""" """
SharedVariable constructor for scalar values. Default: int64 or float64. SharedVariable constructor for scalar values. Default: int64 or float64.
...@@ -78,6 +81,9 @@ def scalar_constructor(value, name=None, strict=False, allow_downcast=None, ...@@ -78,6 +81,9 @@ def scalar_constructor(value, name=None, strict=False, allow_downcast=None,
borrow, as it is a hint to Theano that we can reuse it. borrow, as it is a hint to Theano that we can reuse it.
""" """
if target != 'cpu':
raise TypeError('not for cpu')
if not isinstance(value, (numpy.number, float, int, complex)): if not isinstance(value, (numpy.number, float, int, complex)):
raise TypeError() raise TypeError()
try: try:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论