提交 11f98cbb 作者: Frederic Bastien 提交者: notoraptor

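Use the new kernel calling convention: call the `k_topk_dense*_call` functions with direct arguments instead of `GpuKernel_call` with a `void* args[]` array.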

上级 c5194587
...@@ -185,12 +185,12 @@ class GpuTopKOp(GpuKernelBase, TopKOp): ...@@ -185,12 +185,12 @@ class GpuTopKOp(GpuKernelBase, TopKOp):
axis = self.axis % ndim axis = self.axis % ndim
del(reordered_axes[axis]) del(reordered_axes[axis])
reordered_axes = [axis] + reordered_axes reordered_axes = [axis] + reordered_axes
dims = ''.join('(void*)(dims+%d), ' % i for i in reordered_axes[1:]) dims = ''.join('dims[%d], ' % i for i in reordered_axes[1:])
prep_output = '' prep_output = ''
if self.return_values: if self.return_values:
def_dvstrides = 'const ssize_t *dvstrides = PyGpuArray_STRIDES(%s)' % yv def_dvstrides = 'const ssize_t *dvstrides = PyGpuArray_STRIDES(%s)' % yv
params_dv = '(void*)((char*)(%s->ga.data), (%s->ga.offset)),\n' % (yv, yv) params_dv = '%s->ga.data, %s->ga.offset,\n' % (yv, yv)
params_dv += ''.join('(void*)(dvstrides+%d), ' % i for i in reordered_axes) params_dv += ''.join('dvstrides[%d], ' % i for i in reordered_axes)
prep_output += ''' prep_output += '''
if (0 != theano_prep_output( if (0 != theano_prep_output(
&%(yv)s, %(ndim)d, odims, &%(yv)s, %(ndim)d, odims,
...@@ -202,8 +202,8 @@ class GpuTopKOp(GpuKernelBase, TopKOp): ...@@ -202,8 +202,8 @@ class GpuTopKOp(GpuKernelBase, TopKOp):
if self.return_indices: if self.return_indices:
def_distrides = 'const ssize_t *distrides = PyGpuArray_STRIDES(%s)' % yi def_distrides = 'const ssize_t *distrides = PyGpuArray_STRIDES(%s)' % yi
params_di = '(void*)((char*)(%s->ga.data), (%s->ga.offset)),\n' % (yi, yi) params_di = '%s->ga.data, %s->ga.offset,\n' % (yi, yi)
params_di += ''.join('(void*)(distrides+%d), ' % i for i in reordered_axes) params_di += ''.join('distrides[%d], ' % i for i in reordered_axes)
prep_output += ''' prep_output += '''
if (0 != theano_prep_output( if (0 != theano_prep_output(
&%(yi)s, %(ndim)d, odims, &%(yi)s, %(ndim)d, odims,
...@@ -212,7 +212,7 @@ class GpuTopKOp(GpuKernelBase, TopKOp): ...@@ -212,7 +212,7 @@ class GpuTopKOp(GpuKernelBase, TopKOp):
}\n''' % locals() }\n''' % locals()
else: else:
def_distrides = params_di = '' def_distrides = params_di = ''
sstrides = ', '.join('(void*)(sstrides+%d)' % i for i in reordered_axes) sstrides = ', '.join('sstrides[%d]' % i for i in reordered_axes)
code = ''' code = '''
{ {
const ssize_t k_ = ((%(k_dtype)s*)(PyArray_DATA(%(k)s)))[0]; const ssize_t k_ = ((%(k_dtype)s*)(PyArray_DATA(%(k)s)))[0];
...@@ -256,31 +256,46 @@ class GpuTopKOp(GpuKernelBase, TopKOp): ...@@ -256,31 +256,46 @@ class GpuTopKOp(GpuKernelBase, TopKOp):
%(def_dvstrides)s; %(def_dvstrides)s;
%(def_distrides)s; %(def_distrides)s;
const ssize_t *sstrides = PyGpuArray_STRIDES(%(x)s); const ssize_t *sstrides = PyGpuArray_STRIDES(%(x)s);
void* args[] = {
%(dims)s
%(params_dv)s
%(params_di)s
(void*)(&k_),
(void*)((char*)(%(x)s->ga.data), (%(x)s->ga.offset)),
%(sstrides)s,
(void*)(dims+%(axis)d),
};
int err; int err;
if (dims[%(axis)d] > (1u << 31)) { if (dims[%(axis)d] > (1u << 31)) {
block_size = %(MAX_TPB)d; block_size = %(MAX_TPB)d;
err = GpuKernel_call( err = k_topk_dense_xlarge_call(
&k_topk_dense_xlarge%(nodename)s, 1, 1, &grid_size, &block_size, 0,
&grid_size, &block_size, 0, args); %(dims)s
%(params_dv)s
%(params_di)s
k_,
%(x)s->ga.data,
%(x)s->ga.offset,
%(sstrides)s,
dims[%(axis)d]
);
} else if (block_size > %(MAX_TPB)d) { } else if (block_size > %(MAX_TPB)d) {
block_size = %(MAX_TPB)d; block_size = %(MAX_TPB)d;
err = GpuKernel_call( err = k_topk_dense_large_call(
&k_topk_dense_large%(nodename)s, 1, 1, &grid_size, &block_size, 0,
&grid_size, &block_size, 0, args); %(dims)s
%(params_dv)s
%(params_di)s
k_,
%(x)s->ga.data,
%(x)s->ga.offset,
%(sstrides)s,
dims[%(axis)d]
);
} else { } else {
err = GpuKernel_call( err = k_topk_dense_call(
&k_topk_dense%(nodename)s, 1, 1, &grid_size, &block_size, 0,
&grid_size, &block_size, 0, args); %(dims)s
%(params_dv)s
%(params_di)s
k_,
%(x)s->ga.data,
%(x)s->ga.offset,
%(sstrides)s,
dims[%(axis)d]
);
} }
if (err != GA_NO_ERROR) { if (err != GA_NO_ERROR) {
PyErr_SetString( PyErr_SetString(
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论