提交 cafe03af authored 作者: Arnaud Bergeron's avatar Arnaud Bergeron

Add support for contexts in GpuArrayType.

Still in transitory state, this should work with ops that rely on type context or default context (as long as they are equal). Will gradually move to type context only.
上级 71dea2cf
......@@ -21,11 +21,12 @@ except ImportError:
# This is for documentation not to depend on the availability of pygpu
from .type import (GpuArrayType, GpuArrayVariable, GpuArrayConstant,
GpuArraySharedVariable, gpuarray_shared_constructor)
GpuArraySharedVariable, gpuarray_shared_constructor,
reg_context)
from . import opt, nerv
def init_dev(dev):
def init_dev(dev, name=None):
if pygpu.gpuarray.api_version() != (-10000, 0):
raise RuntimeError("Wrong API version for gpuarray:",
pygpu.gpuarray.api_version(),
......@@ -33,14 +34,11 @@ def init_dev(dev):
"are in sync.")
global pygpu_activated
context = pygpu.init(dev)
pygpu.set_default_context(context)
reg_context(name, context)
pygpu_activated = True
if config.print_active_device:
print("Using device %s: %s" % (dev, context.devname), file=sys.stderr)
# remember the active device
init_dev.device = dev
init_dev.device = None
print("Mapped name %s to device %s: %s" % (name, dev, context.devname),
file=sys.stderr)
if pygpu:
try:
......
......@@ -19,7 +19,7 @@ try:
except ImportError:
pass
from .type import GpuArrayType
from .type import GpuArrayType, gpu_context_type, get_context
from .fp16_help import write_w
......
......@@ -14,14 +14,73 @@ try:
except ImportError:
pass
_context_reg = {}
def reg_context(name, ctx):
"""
Register a context by mapping it to a name.
The context must be of type `GpuContext` and the name can be
anything hashable (but is usually a string). Only one context can
be registered per name and the second registration for a given
name will raise an error.
Parameters
----------
name : hashable object
Name to associate the context with (usually a string)
ctx : GpuContext
Context instance
"""
if name in _context_reg:
raise ValueError("context name %s is already defined" % (name,))
if not isinstance(ctx, gpuarray.GpuContext):
raise TypeError("context is not GpuContext")
_context_reg[name] = ctx
def get_context(name):
"""
Retrive the context associated with a name.
Return the context object mapped to `ref` that was previously
register through :func:`reg_context`. Trying to get the context
for an unregistered `ref` will raise a exception.
Parameters
----------
name : hashable object
Name associated with the context we want (usually a string)
"""
if not name in _context_reg:
raise ValueError("context name %s not defined" % (name,))
return _context_reg[name]
# Private method
def _name_for_ctx(ctx):
for k, v in _context_reg:
if v == ctx:
return k
raise ValueError('context is not registered')
# This is a private method for use by the tests only
def _unreg_context(name):
del _context_reg[name]
class GpuArrayType(Type):
def __init__(self, dtype, broadcastable, name=None):
def __init__(self, dtype, broadcastable, context_name=None, name=None):
# In case this was not provided and no global value is available
self.dtype = str(dtype)
self.broadcastable = tuple(bool(b) for b in broadcastable)
self.ndim = len(self.broadcastable)
self.name = name
self.context_name = context_name
try:
self.typecode = gpuarray.dtype_to_typecode(self.dtype)
except gpuarray.GpuArrayException:
......@@ -34,10 +93,16 @@ class GpuArrayType(Type):
if broadcastable is None:
broadcastable = self.broadcastable
return self.__class__(dtype=dtype, broadcastable=broadcastable,
name=self.name)
context_name=self.context_name, name=self.name)
# This is a property to keep the type pickleable
@property
def context(self):
return get_context(self.context_name)
def __repr__(self):
return "GpuArrayType(%s, %s)" % (self.dtype, self.broadcastable)
return "GpuArrayType<%s>(%s, %s)" % (self.context_name, self.dtype,
self.broadcastable)
def filter(self, data, strict=False, allow_downcast=None):
if (isinstance(data, gpuarray.GpuArray) and
......@@ -54,25 +119,28 @@ class GpuArrayType(Type):
"got %d (dtype %s)." %
(self, self.typecode, self.dtype,
data.typecode, str(data.dtype)))
if self.context != data.context:
raise TypeError("data context does not match type context")
# fallthrough to ndim check
elif (allow_downcast or
(allow_downcast is None and
type(data) == float and
self.dtype == config.floatX)):
data = gpuarray.array(data, dtype=self.typecode, copy=False,
ndmin=len(self.broadcastable))
ndmin=len(self.broadcastable),
context=self.context)
else:
if not hasattr(data, 'dtype'):
# This is to convert objects that don't have a dtype
# (like lists). We anticipate that the type below
# will match and we pass copy=False so it won't make a
# second object on the GPU.
data = gpuarray.array(data, copy=False)
data = gpuarray.array(data, copy=False, context=self.context)
up_dtype = scalar.upcast(self.dtype, data.dtype)
if up_dtype == self.dtype:
data = gpuarray.array(data, dtype=self.dtype,
copy=False)
data = gpuarray.array(data, dtype=self.dtype, copy=False,
context=self.context)
else:
raise TypeError("%s cannot store a value of dtype %s "
"without risking loss of precision." %
......@@ -189,7 +257,8 @@ class GpuArrayType(Type):
return pygpu.gpuarray.may_share_memory(a, b)
def value_zeros(self, shape):
return pygpu.gpuarray.zeros(shape, dtype=self.typecode)
return pygpu.gpuarray.zeros(shape, dtype=self.typecode,
context=self.context)
def make_variable(self, name=None):
return self.Variable(self, name=name)
......@@ -197,19 +266,22 @@ class GpuArrayType(Type):
def __eq__(self, other):
return (type(self) == type(other) and
self.typecode == other.typecode and
self.broadcastable == other.broadcastable)
self.broadcastable == other.broadcastable and
self.context_name == other.context_name)
def convert_variable(self, var):
vt = var.type
if (type(self) == type(vt) and
self.typecode == vt.typecode and
self.ndim == vt.ndim and
self.context_name == vt.context_name and
all(sb == ob or ob for sb, ob in zip(self.broadcastable,
vt.broadcastable))):
return theano.tensor.patternbroadcast(var, self.broadcastable)
def __hash__(self):
return (hash(self.typecode) ^ hash(self.broadcastable))
return hash((type(self), self.typecode, self.broadcastable,
self.context_name))
def dtype_specs(self):
"""
......@@ -370,7 +442,8 @@ class GpuArraySharedVariable(_operators, SharedVariable):
def set_value(self, value, borrow=False):
if isinstance(value, pygpu.gpuarray.GpuArray):
value = pygpu.gpuarray.array(value, copy=(not borrow))
value = pygpu.gpuarray.array(value, copy=(not borrow),
context=self.type.context)
self.container.value = value
def __getitem__(self, *args):
......@@ -393,7 +466,8 @@ def gpuarray_shared_constructor(value, name=None, strict=False,
if broadcastable is None:
broadcastable = (False,) * value.ndim
type = GpuArrayType(value.dtype, broadcastable)
deviceval = pygpu.gpuarray.array(value, copy=(not borrow))
deviceval = pygpu.gpuarray.array(value, copy=(not borrow),
context=type.context)
return GpuArraySharedVariable(type=type, value=deviceval, name=name,
strict=strict)
......@@ -485,3 +559,63 @@ theano.compile.register_specify_shape_c_code(
""",
version=1,
c_support_code_apply='#include <numpy_compat.h>')
class GpuContextType(Type):
def filter(self, data, strict=False, allow_downcast=None):
if not isinstance(data, gpuarray.GpuContext):
raise TypeError('context is not a GpuContext')
return data
def __eq__(self, other):
return type(self) == type(other)
def __hash__(self):
return hash(type(self))
@staticmethod
def values_eq(a, b):
return a == b
def c_declare(self, name, sub, check_input=True):
return "PyGpuContextObject *%s;" % (name,)
def c_init(self, name, sub):
return "%s = NULL;" % (name,)
def c_extract(self, name, sub, check_input=True):
if check_input:
res = """
if (!PyObject_TypeCheck(py_%(name)s, &PyGpuContextType)) {
PyErr_SetString(PyExc_TypeError, "expected a GpuContext");
%(fail)s
}
""" % dict(name=name, fail=sub['fail'])
else:
res = ""
return res + """
%(name)s = (PyGpuContextObject *)py_%(name)s;
Py_INCREF(%(name)s);
""" % dict(name=name)
def c_cleanup(self, name, sub):
return "Py_XDECREF(%(name)s); %(name)s = NULL;" % dict(name=name)
# c_sync is intentionally not declared to prevent normal usage
def c_init_code(self):
return ['import_pygpu__gpuarray();']
def c_headers(self):
return ['<gpuarray_api.h>']
def c_header_dirs(self):
return [pygpu.get_include()]
def c_code_cache_version(self):
ver = pygpu.gpuarray.api_version()
return (0, ver[0])
# Variable, Contstant, ... not declared
gpu_context_type = GpuContextType()
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论