提交 7121c2fb authored 作者: Shawn Tan's avatar Shawn Tan

Moved checks out of GPU, and additional tests.

上级 1d2aa9a4
......@@ -1612,13 +1612,11 @@ class GpuEye(GpuKernelBase, Op):
def gpu_kernels(self, node, name):
code = """
KERNEL void eye(GLOBAL_MEM %(ctype)s *a, ga_size n, ga_size m, ga_ssize k) {
ga_ssize col_off = max(k, (ga_ssize) 0);
ga_ssize row_off = -min(k, (ga_ssize) 0);
if (row_off < n && col_off < m) {
ga_size nb = (ga_size) min(n - row_off, m - col_off);
for (ga_size i = LID_0; i < nb; i += LDIM_0) {
a[(i + row_off)*m + i + col_off] = %(write_a)s(1);
}
ga_ssize coff = max(k, (ga_ssize) 0);
ga_ssize roff = -min(k, (ga_ssize) 0);
ga_size nb = (ga_size) min(n - roff, m - coff);
for (ga_size i = LID_0; i < nb; i += LDIM_0) {
a[(i + roff)*m + i + coff] = %(write_a)s(1);
}
}""" % dict(ctype=pygpu.gpuarray.dtype_to_ctype(self.dtype),
name=name, write_a=write_w(self.dtype))
......@@ -1640,6 +1638,8 @@ KERNEL void eye(GLOBAL_MEM %(ctype)s *a, ga_size n, ga_size m, ga_ssize k) {
size_t dims[2] = {0, 0};
size_t ls, gs;
ssize_t k;
size_t col_off;
size_t row_off;
int err;
dims[0] = ((dtype_%(n)s*)PyArray_DATA(%(n)s))[0];
......@@ -1658,7 +1658,12 @@ KERNEL void eye(GLOBAL_MEM %(ctype)s *a, ga_size n, ga_size m, ga_ssize k) {
ls = 1;
gs = 256;
err = eye_call(1, &gs, &ls, 0, %(z)s->ga.data, dims[0], dims[1], k);
col_off = (size_t) (k > 0?k:0);
row_off = (size_t) (k < 0?-k:0);
if (row_off < dims[0] && col_off < dims[1]) {
err = eye_call(1, &gs, &ls, 0, %(z)s->ga.data, dims[0], dims[1], k);
}
if (err != GA_NO_ERROR) {
PyErr_Format(PyExc_RuntimeError,
"gpuarray error: kEye: %%s. n%%lu, m=%%lu.",
......@@ -1674,4 +1679,5 @@ KERNEL void eye(GLOBAL_MEM %(ctype)s *a, ga_size n, ga_size m, ga_ssize k) {
return s
def c_code_cache_version(self):
return (7,)
# return (7,)
return
......@@ -428,6 +428,13 @@ def test_gpueye():
# N > M, k != 0
yield check, dtype, 5, 3, 1
yield check, dtype, 5, 3, -1
# k > M, -k > N, k > M, k > N
yield check, dtype, 5, 3, 3
yield check, dtype, 3, 5, 3
yield check, dtype, 5, 3, -3
yield check, dtype, 3, 5, -3
yield check, dtype, 5, 3, 6
yield check, dtype, 3, 5, -6
def test_hostfromgpu_shape_i():
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论