提交 106ed858 作者: Shawn Tan

Changes to GpuEye to enable `k` offset parameter.

上级 0784ab7b
......@@ -1600,8 +1600,8 @@ class GpuEye(GpuKernelBase, Op):
context_name=self.context_name)
# k != 0 isn't implemented on the GPU yet.
assert tensor.get_scalar_constant_value(k) == 0
return Apply(self, [n, m], [otype()])
# assert tensor.get_scalar_constant_value(k) == 0
return Apply(self, [n, m, k], [otype()])
def infer_shape(self, node, in_shapes):
out_shape = [node.inputs[0], node.inputs[1]]
......@@ -1613,21 +1613,24 @@ class GpuEye(GpuKernelBase, Op):
def gpu_kernels(self, node, name):
    """Build the GPU kernel that writes the ones of the `k`-th diagonal.

    The kernel treats `a` as a row-major ``n`` x ``m`` matrix with row
    stride ``m`` and stores 1 at ``(i + row_off, i + col_off)``; it never
    writes any other element.  A positive `k` shifts the diagonal to the
    right (column offset), a negative `k` shifts it down (row offset).

    Parameters
    ----------
    node
        Apply node being compiled (not used here).
    name : str
        Unique name of the node, used to build the kernel object variable.

    Returns
    -------
    list
        A single-element list holding the `Kernel` description.
    """
    # `n` and `m` are unsigned (ga_size).  Subtracting an offset larger than
    # the dimension would wrap around to a huge value instead of going
    # negative, and clamping an unsigned value with max(..., 0) can never
    # help, so the loop would write far outside the buffer.  The offsets are
    # therefore computed as unsigned magnitudes and the kernel returns early
    # when the requested diagonal lies completely outside the matrix.
    # `(ga_size) 0 - (ga_size) k` yields |k| for negative k without ever
    # negating the signed value itself.
    code = """
KERNEL void eye(GLOBAL_MEM %(ctype)s *a, ga_size n, ga_size m, ga_int k) {
    ga_size col_off = k > 0 ? (ga_size) k : (ga_size) 0;
    ga_size row_off = k < 0 ? (ga_size) 0 - (ga_size) k : (ga_size) 0;
    if (row_off >= n || col_off >= m) {
        return;
    }
    ga_size rows_left = n - row_off;
    ga_size cols_left = m - col_off;
    ga_size nb = rows_left < cols_left ? rows_left : cols_left;
    for (ga_size i = LID_0; i < nb; i += LDIM_0) {
        a[(i + row_off)*m + i + col_off] = %(write_a)s(1);
    }
}""" % dict(ctype=pygpu.gpuarray.dtype_to_ctype(self.dtype),
            name=name, write_a=write_w(self.dtype))
    return [Kernel(
        code=code, name="eye",
        # Buffer, n, m, and the signed 32-bit diagonal offset k.
        params=[gpuarray.GpuArray, gpuarray.SIZE, gpuarray.SIZE, np.int32],
        flags=Kernel.get_flags(self.dtype),
        objvar='k_eye_' + name)]
def c_code(self, node, name, inp, out, sub):
n, m = inp
n, m, k = inp
z, = out
fail = sub['fail']
ctx = sub['params']
......@@ -1637,10 +1640,13 @@ KERNEL void eye(GLOBAL_MEM %(ctype)s *a, ga_size n, ga_size m) {
s = """
size_t dims[2] = {0, 0};
size_t ls, gs;
int k;
int err;
dims[0] = ((dtype_%(n)s*)PyArray_DATA(%(n)s))[0];
dims[1] = ((dtype_%(m)s*)PyArray_DATA(%(m)s))[0];
k = ((dtype_%(k)s*)PyArray_DATA(%(k)s))[0];
Py_CLEAR(%(z)s);
%(z)s = pygpu_zeros(2, dims,
......@@ -1653,7 +1659,7 @@ KERNEL void eye(GLOBAL_MEM %(ctype)s *a, ga_size n, ga_size m) {
ls = 1;
gs = 256;
err = eye_call(1, &gs, &ls, 0, %(z)s->ga.data, dims[0], dims[1]);
err = eye_call(1, &gs, &ls, 0, %(z)s->ga.data, dims[0], dims[1], k);
if (err != GA_NO_ERROR) {
PyErr_Format(PyExc_RuntimeError,
"gpuarray error: kEye: %%s. n%%lu, m=%%lu.",
......@@ -1669,4 +1675,4 @@ KERNEL void eye(GLOBAL_MEM %(ctype)s *a, ga_size n, ga_size m) {
return s
def c_code_cache_version(self):
    """Return the version tuple of the generated C code.

    The value is 7 after the `k` argument was added to the kernel and to
    its call; the earlier value was 6.
    """
    # NOTE(review): presumably this tuple keys the compiled-code cache, so
    # it must change whenever the generated C/kernel source changes.
    return (7,)
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论