提交 8c414497 authored 作者: Arnaud Bergeron's avatar Arnaud Bergeron

Fix handling of errors in the Set case of GpuIncSubtensor.

上级 589b5926
...@@ -336,12 +336,23 @@ class GpuIncSubtensor(IncSubtensor): ...@@ -336,12 +336,23 @@ class GpuIncSubtensor(IncSubtensor):
C code expression to copy source into view, and 0 on success. C code expression to copy source into view, and 0 on success.
""" """
return """GpuArray_setarray(&%(view)s->ga, &%(source)s->ga)""" % locals() return """sub_setarray(&%(view)s->ga, &%(source)s->ga)""" % locals()
def c_headers(self): def c_headers(self):
return ['<numpy_compat.h>', '<gpuarray/error.h>', '<gpuarray/array.h>', return ['<numpy_compat.h>', '<gpuarray/error.h>', '<gpuarray/array.h>',
'<gpuarray/elemwise.h>'] '<gpuarray/elemwise.h>']
def c_support_code(self):
return """
int sub_setarray(GpuArray *dst, GpuArray *src) {
int err;
err = GpuArray_setarray(dst, src);
if (err != GA_NO_ERROR)
PyErr_SetString(PyExc_RuntimeError, "setarray failed");
return err;
}
"""
def c_support_code_struct(self, node, nodename): def c_support_code_struct(self, node, nodename):
return "\nGpuElemwise *iadd;\n" return "\nGpuElemwise *iadd;\n"
...@@ -383,7 +394,7 @@ class GpuIncSubtensor(IncSubtensor): ...@@ -383,7 +394,7 @@ class GpuIncSubtensor(IncSubtensor):
parent_version = super(GpuIncSubtensor, self).c_code_cache_version() parent_version = super(GpuIncSubtensor, self).c_code_cache_version()
if not parent_version: if not parent_version:
return return
return parent_version + (7,) return parent_version + (8,)
class GpuAdvancedSubtensor1(HideC, tensor.AdvancedSubtensor1): class GpuAdvancedSubtensor1(HideC, tensor.AdvancedSubtensor1):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论