提交 cd18dd6b authored 作者: Arnaud Bergeron's avatar Arnaud Bergeron

Register the new type with OutputGuard and DeepCopy.

上级 ea33f222
...@@ -14,6 +14,7 @@ try: ...@@ -14,6 +14,7 @@ try:
except ImportError: except ImportError:
pass pass
class GpuArrayType(Type): class GpuArrayType(Type):
def __init__(self, dtype, broadcastable, name=None): def __init__(self, dtype, broadcastable, name=None):
# In case this was not provided and no global value is available # In case this was not provided and no global value is available
...@@ -220,3 +221,24 @@ def gpuarray_shared_constructor(value, name=None, strict=False, ...@@ -220,3 +221,24 @@ def gpuarray_shared_constructor(value, name=None, strict=False,
deviceval = pygpu.gpuarray.array(value, copy=(not borrow)) deviceval = pygpu.gpuarray.array(value, copy=(not borrow))
return GpuArraySharedVariable(type=type, value=deviceval, name=name, return GpuArraySharedVariable(type=type, value=deviceval, name=name,
strict=strict) strict=strict)
# Register GpuArrayType with the OutputGuard and DeepCopyOp C-code registries
# so both ops have a C implementation for variables of this type.
theano.compile.mode.register_OutputGuard_c_code(GpuArrayType)
# C template for DeepCopyOp.  The %(oname)s / %(iname)s / %(fail)s placeholders
# are presumably substituted by the op with the output variable, the input
# variable and the error-exit code -- confirm against DeepCopyOp.c_code.
#
# Both variables are used as pointers in this template (NULL check,
# Py_XDECREF, `->context`), so the embedded array struct must be reached with
# `->ga`; `.ga` on a pointer does not compile.
#
# NOTE(review): GpuArray_empty is called with only two arguments here, and
# new_GpuArray receives GpuArrayType / GpuArray_default_context as bare names.
# Verify both against the pygpu/libgpuarray C API, and check whether
# GpuArray_copy already allocates its destination (which would make the
# GpuArray_empty call unnecessary).
theano.compile.function_module.register_DeepCopyOp_c_code(GpuArrayType, """
Py_XDECREF(%(oname)s);
%(oname)s = new_GpuArray(GpuArrayType, GpuArray_default_context);
if (!%(oname)s) { %(fail)s }
int err;
err = GpuArray_empty(&%(oname)s->ga, %(oname)s->context->ops);
if (err != GA_NO_ERROR) {
PyErr_SetString(PyExc_MemoryError, "Could not allocate new array");
%(fail)s
}
err = GpuArray_copy(&%(oname)s->ga, &%(iname)s->ga, GA_ANY_ORDER);
if (err != GA_NO_ERROR) {
PyErr_SetString(PyExc_RuntimeError, "Error during copy");
%(fail)s
}
""")
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论