提交 6c3d8b43 作者: Frédéric Bastien 提交者: GitHub

Merge pull request #5789 from shawntan/issue-2763

Changes to GpuEye to enable `k` offset parameter.
...@@ -1583,9 +1583,7 @@ class GpuEye(GpuKernelBase, Op): ...@@ -1583,9 +1583,7 @@ class GpuEye(GpuKernelBase, Op):
broadcastable=(False, False), broadcastable=(False, False),
context_name=self.context_name) context_name=self.context_name)
# k != 0 isn't implemented on the GPU yet. return Apply(self, [n, m, k], [otype()])
assert tensor.get_scalar_constant_value(k) == 0
return Apply(self, [n, m], [otype()])
def infer_shape(self, node, in_shapes): def infer_shape(self, node, in_shapes):
out_shape = [node.inputs[0], node.inputs[1]] out_shape = [node.inputs[0], node.inputs[1]]
...@@ -1597,21 +1595,28 @@ class GpuEye(GpuKernelBase, Op): ...@@ -1597,21 +1595,28 @@ class GpuEye(GpuKernelBase, Op):
def gpu_kernels(self, node, name): def gpu_kernels(self, node, name):
code = """ code = """
KERNEL void eye(GLOBAL_MEM %(ctype)s *a, ga_size n, ga_size m) { KERNEL void eye(GLOBAL_MEM %(ctype)s *a, ga_size n, ga_size m, ga_ssize k) {
ga_size nb = n < m ? n : m; ga_ssize coff = max(k, (ga_ssize) 0);
ga_ssize roff = -min(k, (ga_ssize) 0);
ga_size nb = (ga_size) min(n - roff, m - coff);
for (ga_size i = LID_0; i < nb; i += LDIM_0) { for (ga_size i = LID_0; i < nb; i += LDIM_0) {
a[i*m + i] = %(write_a)s(1); a[(i + roff)*m + i + coff] = %(write_a)s(1);
} }
}""" % dict(ctype=pygpu.gpuarray.dtype_to_ctype(self.dtype), }""" % dict(ctype=pygpu.gpuarray.dtype_to_ctype(self.dtype),
name=name, write_a=write_w(self.dtype)) name=name, write_a=write_w(self.dtype))
return [Kernel( return [Kernel(
code=code, name="eye", code=code, name="eye",
params=[gpuarray.GpuArray, gpuarray.SIZE, gpuarray.SIZE], params=[gpuarray.GpuArray, gpuarray.SIZE, gpuarray.SIZE, gpuarray.SSIZE],
flags=Kernel.get_flags(self.dtype), flags=Kernel.get_flags(self.dtype),
objvar='k_eye_' + name)] objvar='k_eye_' + name)]
def c_code(self, node, name, inp, out, sub): def c_code(self, node, name, inp, out, sub):
n, m = inp if len(inp) == 2:
n, m = inp
k = 0
elif len(inp) == 3:
n, m, k = inp
z, = out z, = out
fail = sub['fail'] fail = sub['fail']
ctx = sub['params'] ctx = sub['params']
...@@ -1621,10 +1626,15 @@ KERNEL void eye(GLOBAL_MEM %(ctype)s *a, ga_size n, ga_size m) { ...@@ -1621,10 +1626,15 @@ KERNEL void eye(GLOBAL_MEM %(ctype)s *a, ga_size n, ga_size m) {
s = """ s = """
size_t dims[2] = {0, 0}; size_t dims[2] = {0, 0};
size_t ls, gs; size_t ls, gs;
ssize_t k;
size_t col_off;
size_t row_off;
int err; int err;
dims[0] = ((dtype_%(n)s*)PyArray_DATA(%(n)s))[0]; dims[0] = ((dtype_%(n)s*)PyArray_DATA(%(n)s))[0];
dims[1] = ((dtype_%(m)s*)PyArray_DATA(%(m)s))[0]; dims[1] = ((dtype_%(m)s*)PyArray_DATA(%(m)s))[0];
k = ((dtype_%(k)s*)PyArray_DATA(%(k)s))[0];
Py_CLEAR(%(z)s); Py_CLEAR(%(z)s);
%(z)s = pygpu_zeros(2, dims, %(z)s = pygpu_zeros(2, dims,
...@@ -1637,13 +1647,17 @@ KERNEL void eye(GLOBAL_MEM %(ctype)s *a, ga_size n, ga_size m) { ...@@ -1637,13 +1647,17 @@ KERNEL void eye(GLOBAL_MEM %(ctype)s *a, ga_size n, ga_size m) {
ls = 1; ls = 1;
gs = 256; gs = 256;
err = eye_call(1, &gs, &ls, 0, %(z)s->ga.data, dims[0], dims[1]); col_off = (size_t) (k > 0?k:0);
if (err != GA_NO_ERROR) { row_off = (size_t) (k < 0?-k:0);
PyErr_Format(PyExc_RuntimeError, if (row_off < dims[0] && col_off < dims[1]) {
"gpuarray error: kEye: %%s. n%%lu, m=%%lu.", err = eye_call(1, &gs, &ls, 0, %(z)s->ga.data, dims[0], dims[1], k);
GpuKernel_error(&%(kname)s, err), if (err != GA_NO_ERROR) {
(unsigned long)dims[0], (unsigned long)dims[1]); PyErr_Format(PyExc_RuntimeError,
%(fail)s; "gpuarray error: kEye: %%s. n%%lu, m=%%lu.",
GpuKernel_error(&%(kname)s, err),
(unsigned long)dims[0], (unsigned long)dims[1]);
%(fail)s;
}
} }
if(%(sync)d) if(%(sync)d)
...@@ -1653,4 +1667,4 @@ KERNEL void eye(GLOBAL_MEM %(ctype)s *a, ga_size n, ga_size m) { ...@@ -1653,4 +1667,4 @@ KERNEL void eye(GLOBAL_MEM %(ctype)s *a, ga_size n, ga_size m) {
return s return s
def c_code_cache_version(self): def c_code_cache_version(self):
return (6,) return (7,)
...@@ -392,7 +392,7 @@ def test_gpujoin_gpualloc(): ...@@ -392,7 +392,7 @@ def test_gpujoin_gpualloc():
def test_gpueye(): def test_gpueye():
def check(dtype, N, M_=None): def check(dtype, N, M_=None, k=0):
# Theano does not accept None as a tensor. # Theano does not accept None as a tensor.
# So we must use a real value. # So we must use a real value.
M = M_ M = M_
...@@ -402,13 +402,14 @@ def test_gpueye(): ...@@ -402,13 +402,14 @@ def test_gpueye():
M = N M = N
N_symb = T.iscalar() N_symb = T.iscalar()
M_symb = T.iscalar() M_symb = T.iscalar()
k_symb = np.asarray(0) k_symb = T.iscalar()
out = T.eye(N_symb, M_symb, k_symb, dtype=dtype) out = T.eye(N_symb, M_symb, k_symb, dtype=dtype) + np.array(1).astype(dtype)
f = theano.function([N_symb, M_symb], f = theano.function([N_symb, M_symb, k_symb],
T.stack(out), out,
mode=mode_with_gpu) mode=mode_with_gpu)
result = np.asarray(f(N, M))
assert np.allclose(result, np.eye(N, M_, dtype=dtype)) result = np.asarray(f(N, M, k)) - np.array(1).astype(dtype)
assert np.allclose(result, np.eye(N, M_, k, dtype=dtype))
assert result.dtype == np.dtype(dtype) assert result.dtype == np.dtype(dtype)
assert any([isinstance(node.op, GpuEye) assert any([isinstance(node.op, GpuEye)
for node in f.maker.fgraph.toposort()]) for node in f.maker.fgraph.toposort()])
...@@ -418,6 +419,22 @@ def test_gpueye(): ...@@ -418,6 +419,22 @@ def test_gpueye():
# M != N, k = 0 # M != N, k = 0
yield check, dtype, 3, 5 yield check, dtype, 3, 5
yield check, dtype, 5, 3 yield check, dtype, 5, 3
# N == M, k != 0
yield check, dtype, 3, 3, 1
yield check, dtype, 3, 3, -1
# N < M, k != 0
yield check, dtype, 3, 5, 1
yield check, dtype, 3, 5, -1
# N > M, k != 0
yield check, dtype, 5, 3, 1
yield check, dtype, 5, 3, -1
# |k| >= N or |k| >= M: the diagonal is clipped or lies entirely outside the matrix
yield check, dtype, 5, 3, 3
yield check, dtype, 3, 5, 3
yield check, dtype, 5, 3, -3
yield check, dtype, 3, 5, -3
yield check, dtype, 5, 3, 6
yield check, dtype, 3, 5, -6
def test_hostfromgpu_shape_i(): def test_hostfromgpu_shape_i():
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论