提交 8b17154a authored 作者: Arnaud Bergeron's avatar Arnaud Bergeron

Follow the change in call protocol in libgpuarray.

上级 6a8fa46f
...@@ -944,6 +944,7 @@ KERNEL void k(GLOBAL_MEM %(ctype)s *a, ga_size n, ga_size m) { ...@@ -944,6 +944,7 @@ KERNEL void k(GLOBAL_MEM %(ctype)s *a, ga_size n, ga_size m) {
kname = self.gpu_kernels(node, name)[0].objvar kname = self.gpu_kernels(node, name)[0].objvar
s = """ s = """
size_t dims[2] = {0, 0}; size_t dims[2] = {0, 0};
size_t ls, gs;
void *args[3]; void *args[3];
int err; int err;
...@@ -959,10 +960,12 @@ KERNEL void k(GLOBAL_MEM %(ctype)s *a, ga_size n, ga_size m) { ...@@ -959,10 +960,12 @@ KERNEL void k(GLOBAL_MEM %(ctype)s *a, ga_size n, ga_size m) {
%(fail)s %(fail)s
} }
args[0] = &%(z)s->ga; args[0] = %(z)s->ga.data;
args[1] = &dims[0]; args[1] = &dims[0];
args[2] = &dims[1]; args[2] = &dims[1];
err = GpuKernel_call(&%(kname)s, 0, 1, 256, args); ls = 1;
gs = 256;
err = GpuKernel_call(&%(kname)s, 1, &ls, &gs, 0, args);
if (err != GA_NO_ERROR) { if (err != GA_NO_ERROR) {
PyErr_Format(PyExc_RuntimeError, PyErr_Format(PyExc_RuntimeError,
"gpuarray error: kEye: %%s. n%%lu, m=%%lu.", "gpuarray error: kEye: %%s. n%%lu, m=%%lu.",
...@@ -978,4 +981,4 @@ KERNEL void k(GLOBAL_MEM %(ctype)s *a, ga_size n, ga_size m) { ...@@ -978,4 +981,4 @@ KERNEL void k(GLOBAL_MEM %(ctype)s *a, ga_size n, ga_size m) {
return s return s
def c_code_cache_version(self): def c_code_cache_version(self):
return (3, self.GpuKernelBase_version) return (4, self.GpuKernelBase_version)
...@@ -2664,6 +2664,7 @@ class GpuCAReduceCPY(GpuKernelBase, HideC, CAReduceDtype): ...@@ -2664,6 +2664,7 @@ class GpuCAReduceCPY(GpuKernelBase, HideC, CAReduceDtype):
nd_out = node.outputs[0].ndim nd_out = node.outputs[0].ndim
code = """ code = """
size_t gs = 1; size_t gs = 1;
size_t ls;
unsigned int n = 1; unsigned int n = 1;
unsigned int proxy_dim[%(nd_in)s]; unsigned int proxy_dim[%(nd_in)s];
unsigned int proxy_off; unsigned int proxy_off;
...@@ -2727,7 +2728,7 @@ class GpuCAReduceCPY(GpuKernelBase, HideC, CAReduceDtype): ...@@ -2727,7 +2728,7 @@ class GpuCAReduceCPY(GpuKernelBase, HideC, CAReduceDtype):
# data in the proper type. # data in the proper type.
code += """ code += """
args[0] = &n; args[0] = &n;
args[1] = &tmp->ga; args[1] = tmp->ga.data;
""" % dict(output=output) """ % dict(output=output)
p = 2 p = 2
...@@ -2742,7 +2743,7 @@ class GpuCAReduceCPY(GpuKernelBase, HideC, CAReduceDtype): ...@@ -2742,7 +2743,7 @@ class GpuCAReduceCPY(GpuKernelBase, HideC, CAReduceDtype):
code += "gs *= %(input)s->ga.dimensions[%(i)s];" % dict(input=input, i=i) code += "gs *= %(input)s->ga.dimensions[%(i)s];" % dict(input=input, i=i)
code += """ code += """
args[%(p)s] = &%(input)s->ga; args[%(p)s] = %(input)s->ga.data;
proxy_off = %(input)s->ga.offset; proxy_off = %(input)s->ga.offset;
args[%(p)s+1] = &proxy_off; args[%(p)s+1] = &proxy_off;
""" % dict(p=p, input=input) """ % dict(p=p, input=input)
...@@ -2758,7 +2759,8 @@ class GpuCAReduceCPY(GpuKernelBase, HideC, CAReduceDtype): ...@@ -2758,7 +2759,8 @@ class GpuCAReduceCPY(GpuKernelBase, HideC, CAReduceDtype):
code += """ code += """
if (gs == 0) gs = 1; if (gs == 0) gs = 1;
n /= gs; n /= gs;
err = GpuKernel_call(&%(k_var)s, 0, %(ls)s, gs, args); ls = %(ls)s;
err = GpuKernel_call(&%(k_var)s, 1, &ls, &gs, 0, args);
if (err != GA_NO_ERROR) { if (err != GA_NO_ERROR) {
PyErr_Format(PyExc_RuntimeError, PyErr_Format(PyExc_RuntimeError,
"gpuarray error: GpuCAReduceCPY: %%s.", "gpuarray error: GpuCAReduceCPY: %%s.",
...@@ -2788,7 +2790,7 @@ class GpuCAReduceCPY(GpuKernelBase, HideC, CAReduceDtype): ...@@ -2788,7 +2790,7 @@ class GpuCAReduceCPY(GpuKernelBase, HideC, CAReduceDtype):
return code return code
def c_code_cache_version(self): def c_code_cache_version(self):
return (0, self.GpuKernelBase_version) return (1, self.GpuKernelBase_version)
def generate_kernel(self, node, odtype, redux): def generate_kernel(self, node, odtype, redux):
if isinstance(self.scalar_op, scalar.basic.Add): if isinstance(self.scalar_op, scalar.basic.Add):
......
...@@ -994,11 +994,18 @@ class GPUA_mrg_uniform(GpuKernelBase, mrg_uniform_base): ...@@ -994,11 +994,18 @@ class GPUA_mrg_uniform(GpuKernelBase, mrg_uniform_base):
{ {
void *args[4]; void *args[4];
args[0] = &%(o_sample)s->ga; size_t ls = 0, gs = 0;
args[1] = &%(o_rstate)s->ga; args[0] = %(o_sample)s->ga.data;
args[1] = %(o_rstate)s->ga.data;
args[2] = &n_elements; args[2] = &n_elements;
args[3] = &n_streams; args[3] = &n_streams;
int err = GpuKernel_call(&%(kname)s, n_elements, 0, 0, args); int err = GpuKernel_sched(&%(kname)s, n_elements, &ls, &gs);
if (err != GA_NO_ERROR) {
PyErr_Format(PyExc_RuntimeError, "GpuKernel_sched: %%s\\n",
GpuKernel_error(&%(kname)s, err));
%(fail)s
}
err = GpuKernel_call(&%(kname)s, 1, &ls, &gs, 0, args);
if (err != GA_NO_ERROR) { if (err != GA_NO_ERROR) {
PyErr_Format(PyExc_RuntimeError, "GpuKernel_call: %%s\\n", PyErr_Format(PyExc_RuntimeError, "GpuKernel_call: %%s\\n",
GpuKernel_error(&%(kname)s, err)); GpuKernel_error(&%(kname)s, err));
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论