提交 b6ee7ac3 authored 作者: Arnaud Bergeron's avatar Arnaud Bergeron

Make GpuEye ok with offsets.

上级 6e921145
......@@ -1630,7 +1630,9 @@ class GpuEye(GpuKernelBase, Op):
def gpu_kernels(self, node, name):
code = """
KERNEL void eye(GLOBAL_MEM %(ctype)s *a, ga_size n, ga_size m, ga_ssize k) {
KERNEL void eye(GLOBAL_MEM %(ctype)s *a, ga_size a_off,
ga_size n, ga_size m, ga_ssize k) {
a = (GLOBAL_MEM %(ctype)s *)(((char *)a) + a_off);
ga_ssize coff = max(k, (ga_ssize) 0);
ga_ssize roff = -min(k, (ga_ssize) 0);
ga_size nb = (ga_size) min(n - roff, m - coff);
......@@ -1641,7 +1643,8 @@ KERNEL void eye(GLOBAL_MEM %(ctype)s *a, ga_size n, ga_size m, ga_ssize k) {
name=name, write_a=write_w(self.dtype))
return [Kernel(
code=code, name="eye",
params=[gpuarray.GpuArray, gpuarray.SIZE, gpuarray.SIZE, gpuarray.SSIZE],
params=[gpuarray.GpuArray, gpuarray.SIZE, gpuarray.SIZE,
gpuarray.SIZE, gpuarray.SSIZE],
flags=Kernel.get_flags(self.dtype),
objvar='k_eye_' + name)]
......@@ -1685,7 +1688,8 @@ KERNEL void eye(GLOBAL_MEM %(ctype)s *a, ga_size n, ga_size m, ga_ssize k) {
col_off = (size_t) (k > 0?k:0);
row_off = (size_t) (k < 0?-k:0);
if (row_off < dims[0] && col_off < dims[1]) {
err = eye_call(1, &gs, &ls, 0, %(z)s->ga.data, dims[0], dims[1], k);
err = eye_call(1, &gs, &ls, 0, %(z)s->ga.data, %(z)s->ga.offset,
dims[0], dims[1], k);
if (err != GA_NO_ERROR) {
PyErr_Format(PyExc_RuntimeError,
"gpuarray error: kEye: %%s. n%%lu, m=%%lu.",
......@@ -1702,4 +1706,4 @@ KERNEL void eye(GLOBAL_MEM %(ctype)s *a, ga_size n, ga_size m, ga_ssize k) {
return s
def c_code_cache_version(self):
return (7,)
return (8,)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论