提交 b6ee7ac3 authored 作者: Arnaud Bergeron's avatar Arnaud Bergeron

Make GpuEye ok with offsets.

上级 6e921145
...@@ -1630,7 +1630,9 @@ class GpuEye(GpuKernelBase, Op): ...@@ -1630,7 +1630,9 @@ class GpuEye(GpuKernelBase, Op):
def gpu_kernels(self, node, name): def gpu_kernels(self, node, name):
code = """ code = """
KERNEL void eye(GLOBAL_MEM %(ctype)s *a, ga_size n, ga_size m, ga_ssize k) { KERNEL void eye(GLOBAL_MEM %(ctype)s *a, ga_size a_off,
ga_size n, ga_size m, ga_ssize k) {
a = (GLOBAL_MEM %(ctype)s *)(((char *)a) + a_off);
ga_ssize coff = max(k, (ga_ssize) 0); ga_ssize coff = max(k, (ga_ssize) 0);
ga_ssize roff = -min(k, (ga_ssize) 0); ga_ssize roff = -min(k, (ga_ssize) 0);
ga_size nb = (ga_size) min(n - roff, m - coff); ga_size nb = (ga_size) min(n - roff, m - coff);
...@@ -1641,7 +1643,8 @@ KERNEL void eye(GLOBAL_MEM %(ctype)s *a, ga_size n, ga_size m, ga_ssize k) { ...@@ -1641,7 +1643,8 @@ KERNEL void eye(GLOBAL_MEM %(ctype)s *a, ga_size n, ga_size m, ga_ssize k) {
name=name, write_a=write_w(self.dtype)) name=name, write_a=write_w(self.dtype))
return [Kernel( return [Kernel(
code=code, name="eye", code=code, name="eye",
params=[gpuarray.GpuArray, gpuarray.SIZE, gpuarray.SIZE, gpuarray.SSIZE], params=[gpuarray.GpuArray, gpuarray.SIZE, gpuarray.SIZE,
gpuarray.SIZE, gpuarray.SSIZE],
flags=Kernel.get_flags(self.dtype), flags=Kernel.get_flags(self.dtype),
objvar='k_eye_' + name)] objvar='k_eye_' + name)]
...@@ -1685,7 +1688,8 @@ KERNEL void eye(GLOBAL_MEM %(ctype)s *a, ga_size n, ga_size m, ga_ssize k) { ...@@ -1685,7 +1688,8 @@ KERNEL void eye(GLOBAL_MEM %(ctype)s *a, ga_size n, ga_size m, ga_ssize k) {
col_off = (size_t) (k > 0?k:0); col_off = (size_t) (k > 0?k:0);
row_off = (size_t) (k < 0?-k:0); row_off = (size_t) (k < 0?-k:0);
if (row_off < dims[0] && col_off < dims[1]) { if (row_off < dims[0] && col_off < dims[1]) {
err = eye_call(1, &gs, &ls, 0, %(z)s->ga.data, dims[0], dims[1], k); err = eye_call(1, &gs, &ls, 0, %(z)s->ga.data, %(z)s->ga.offset,
dims[0], dims[1], k);
if (err != GA_NO_ERROR) { if (err != GA_NO_ERROR) {
PyErr_Format(PyExc_RuntimeError, PyErr_Format(PyExc_RuntimeError,
"gpuarray error: kEye: %%s. n%%lu, m=%%lu.", "gpuarray error: kEye: %%s. n%%lu, m=%%lu.",
...@@ -1702,4 +1706,4 @@ KERNEL void eye(GLOBAL_MEM %(ctype)s *a, ga_size n, ga_size m, ga_ssize k) { ...@@ -1702,4 +1706,4 @@ KERNEL void eye(GLOBAL_MEM %(ctype)s *a, ga_size n, ga_size m, ga_ssize k) {
return s return s
def c_code_cache_version(self): def c_code_cache_version(self):
return (7,) return (8,)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论