提交 7121c2fb authored 作者: Shawn Tan's avatar Shawn Tan

Moved checks out of GPU, and additional tests.

上级 1d2aa9a4
...@@ -1612,13 +1612,11 @@ class GpuEye(GpuKernelBase, Op): ...@@ -1612,13 +1612,11 @@ class GpuEye(GpuKernelBase, Op):
def gpu_kernels(self, node, name): def gpu_kernels(self, node, name):
code = """ code = """
KERNEL void eye(GLOBAL_MEM %(ctype)s *a, ga_size n, ga_size m, ga_ssize k) { KERNEL void eye(GLOBAL_MEM %(ctype)s *a, ga_size n, ga_size m, ga_ssize k) {
ga_ssize col_off = max(k, (ga_ssize) 0); ga_ssize coff = max(k, (ga_ssize) 0);
ga_ssize row_off = -min(k, (ga_ssize) 0); ga_ssize roff = -min(k, (ga_ssize) 0);
if (row_off < n && col_off < m) { ga_size nb = (ga_size) min(n - roff, m - coff);
ga_size nb = (ga_size) min(n - row_off, m - col_off); for (ga_size i = LID_0; i < nb; i += LDIM_0) {
for (ga_size i = LID_0; i < nb; i += LDIM_0) { a[(i + roff)*m + i + coff] = %(write_a)s(1);
a[(i + row_off)*m + i + col_off] = %(write_a)s(1);
}
} }
}""" % dict(ctype=pygpu.gpuarray.dtype_to_ctype(self.dtype), }""" % dict(ctype=pygpu.gpuarray.dtype_to_ctype(self.dtype),
name=name, write_a=write_w(self.dtype)) name=name, write_a=write_w(self.dtype))
...@@ -1640,6 +1638,8 @@ KERNEL void eye(GLOBAL_MEM %(ctype)s *a, ga_size n, ga_size m, ga_ssize k) { ...@@ -1640,6 +1638,8 @@ KERNEL void eye(GLOBAL_MEM %(ctype)s *a, ga_size n, ga_size m, ga_ssize k) {
size_t dims[2] = {0, 0}; size_t dims[2] = {0, 0};
size_t ls, gs; size_t ls, gs;
ssize_t k; ssize_t k;
size_t col_off;
size_t row_off;
int err; int err;
dims[0] = ((dtype_%(n)s*)PyArray_DATA(%(n)s))[0]; dims[0] = ((dtype_%(n)s*)PyArray_DATA(%(n)s))[0];
...@@ -1658,7 +1658,12 @@ KERNEL void eye(GLOBAL_MEM %(ctype)s *a, ga_size n, ga_size m, ga_ssize k) { ...@@ -1658,7 +1658,12 @@ KERNEL void eye(GLOBAL_MEM %(ctype)s *a, ga_size n, ga_size m, ga_ssize k) {
ls = 1; ls = 1;
gs = 256; gs = 256;
err = eye_call(1, &gs, &ls, 0, %(z)s->ga.data, dims[0], dims[1], k); col_off = (size_t) (k > 0?k:0);
row_off = (size_t) (k < 0?-k:0);
if (row_off < dims[0] && col_off < dims[1]) {
err = eye_call(1, &gs, &ls, 0, %(z)s->ga.data, dims[0], dims[1], k);
}
if (err != GA_NO_ERROR) { if (err != GA_NO_ERROR) {
PyErr_Format(PyExc_RuntimeError, PyErr_Format(PyExc_RuntimeError,
"gpuarray error: kEye: %%s. n%%lu, m=%%lu.", "gpuarray error: kEye: %%s. n%%lu, m=%%lu.",
...@@ -1674,4 +1679,5 @@ KERNEL void eye(GLOBAL_MEM %(ctype)s *a, ga_size n, ga_size m, ga_ssize k) { ...@@ -1674,4 +1679,5 @@ KERNEL void eye(GLOBAL_MEM %(ctype)s *a, ga_size n, ga_size m, ga_ssize k) {
return s return s
def c_code_cache_version(self): def c_code_cache_version(self):
return (7,) # return (7,)
return
...@@ -428,6 +428,13 @@ def test_gpueye(): ...@@ -428,6 +428,13 @@ def test_gpueye():
# N > M, k != 0 # N > M, k != 0
yield check, dtype, 5, 3, 1 yield check, dtype, 5, 3, 1
yield check, dtype, 5, 3, -1 yield check, dtype, 5, 3, -1
# k > M, -k > N, k > M, k > N
yield check, dtype, 5, 3, 3
yield check, dtype, 3, 5, 3
yield check, dtype, 5, 3, -3
yield check, dtype, 3, 5, -3
yield check, dtype, 5, 3, 6
yield check, dtype, 3, 5, -6
def test_hostfromgpu_shape_i(): def test_hostfromgpu_shape_i():
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论