提交 8002842c 作者: Pascal Lamblin

Make GPU multinomial work on strided output

上级 488cef77
...@@ -147,19 +147,21 @@ class GpuMultinomialFromUniform(MultinomialFromUniform, GpuOp): ...@@ -147,19 +147,21 @@ class GpuMultinomialFromUniform(MultinomialFromUniform, GpuOp):
return Apply(self, [pvals, unis], [pvals.type()]) return Apply(self, [pvals, unis], [pvals.type()])
def c_code_cache_version(self): def c_code_cache_version(self):
return (7,) return (8,)
def c_support_code_apply(self, node, nodename): def c_support_code_apply(self, node, nodename):
return """ return """
static __global__ void k_multi_warp_%(nodename)s( static __global__ void k_multi_warp_%(nodename)s(
const int nb_multi, const int nb_multi,
const int nb_outcomes, const int nb_outcomes,
const int pvals_row_strides,
const int pvals_col_strides,
const int unis_stride,
float * global_pvals, float * global_pvals,
const int pvals_row_stride,
const int pvals_col_stride,
float * global_unis, float * global_unis,
float * global_outs const int unis_stride,
float * global_outs,
const int outs_row_stride,
const int outs_col_stride
) )
{ {
// each thread takes care of one multinomial draw // each thread takes care of one multinomial draw
...@@ -174,7 +176,7 @@ class GpuMultinomialFromUniform(MultinomialFromUniform, GpuOp): ...@@ -174,7 +176,7 @@ class GpuMultinomialFromUniform(MultinomialFromUniform, GpuOp):
float current_out = 0.; float current_out = 0.;
if (!done) if (!done)
{ {
cummul += global_pvals[m * pvals_col_strides + n * pvals_row_strides]; cummul += global_pvals[m * pvals_col_stride + n * pvals_row_stride];
if (unis_n < cummul) if (unis_n < cummul)
{ {
current_out = 1.; current_out = 1.;
...@@ -182,7 +184,7 @@ class GpuMultinomialFromUniform(MultinomialFromUniform, GpuOp): ...@@ -182,7 +184,7 @@ class GpuMultinomialFromUniform(MultinomialFromUniform, GpuOp):
} }
} }
//write out transposed for speed. //write out transposed for speed.
global_outs[n + m * nb_multi] = current_out; global_outs[n * outs_col_stride + m * outs_row_stride] = current_out;
} }
} }
} }
...@@ -262,12 +264,14 @@ class GpuMultinomialFromUniform(MultinomialFromUniform, GpuOp): ...@@ -262,12 +264,14 @@ class GpuMultinomialFromUniform(MultinomialFromUniform, GpuOp):
k_multi_warp_%(name)s<<<n_blocks, n_threads, n_shared>>>( k_multi_warp_%(name)s<<<n_blocks, n_threads, n_shared>>>(
CudaNdarray_HOST_DIMS(%(z)s)[1], CudaNdarray_HOST_DIMS(%(z)s)[1],
CudaNdarray_HOST_DIMS(%(z)s)[0], CudaNdarray_HOST_DIMS(%(z)s)[0],
CudaNdarray_DEV_DATA(%(pvals)s),
CudaNdarray_HOST_STRIDES(%(pvals)s)[0], CudaNdarray_HOST_STRIDES(%(pvals)s)[0],
CudaNdarray_HOST_STRIDES(%(pvals)s)[1], CudaNdarray_HOST_STRIDES(%(pvals)s)[1],
CudaNdarray_HOST_STRIDES(%(unis)s)[0],
CudaNdarray_DEV_DATA(%(pvals)s),
CudaNdarray_DEV_DATA(%(unis)s), CudaNdarray_DEV_DATA(%(unis)s),
CudaNdarray_DEV_DATA(%(z)s) CudaNdarray_HOST_STRIDES(%(unis)s)[0],
CudaNdarray_DEV_DATA(%(z)s),
CudaNdarray_HOST_STRIDES(%(z)s)[0],
CudaNdarray_HOST_STRIDES(%(z)s)[1]
); );
CNDA_THREAD_SYNC; CNDA_THREAD_SYNC;
cudaError_t sts = cudaGetLastError(); cudaError_t sts = cudaGetLastError();
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论