提交 8002842c authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Make GPU multinomial work on strided output

上级 488cef77
......@@ -147,19 +147,21 @@ class GpuMultinomialFromUniform(MultinomialFromUniform, GpuOp):
return Apply(self, [pvals, unis], [pvals.type()])
def c_code_cache_version(self):
return (7,)
return (8,)
def c_support_code_apply(self, node, nodename):
return """
static __global__ void k_multi_warp_%(nodename)s(
const int nb_multi,
const int nb_outcomes,
const int pvals_row_strides,
const int pvals_col_strides,
const int unis_stride,
float * global_pvals,
const int pvals_row_stride,
const int pvals_col_stride,
float * global_unis,
float * global_outs
const int unis_stride,
float * global_outs,
const int outs_row_stride,
const int outs_col_stride
)
{
// each thread takes care of one multinomial draw
......@@ -174,7 +176,7 @@ class GpuMultinomialFromUniform(MultinomialFromUniform, GpuOp):
float current_out = 0.;
if (!done)
{
cummul += global_pvals[m * pvals_col_strides + n * pvals_row_strides];
cummul += global_pvals[m * pvals_col_stride + n * pvals_row_stride];
if (unis_n < cummul)
{
current_out = 1.;
......@@ -182,7 +184,7 @@ class GpuMultinomialFromUniform(MultinomialFromUniform, GpuOp):
}
}
//write out transposed for speed.
global_outs[n + m * nb_multi] = current_out;
global_outs[n * outs_col_stride + m * outs_row_stride] = current_out;
}
}
}
......@@ -262,12 +264,14 @@ class GpuMultinomialFromUniform(MultinomialFromUniform, GpuOp):
k_multi_warp_%(name)s<<<n_blocks, n_threads, n_shared>>>(
CudaNdarray_HOST_DIMS(%(z)s)[1],
CudaNdarray_HOST_DIMS(%(z)s)[0],
CudaNdarray_DEV_DATA(%(pvals)s),
CudaNdarray_HOST_STRIDES(%(pvals)s)[0],
CudaNdarray_HOST_STRIDES(%(pvals)s)[1],
CudaNdarray_HOST_STRIDES(%(unis)s)[0],
CudaNdarray_DEV_DATA(%(pvals)s),
CudaNdarray_DEV_DATA(%(unis)s),
CudaNdarray_DEV_DATA(%(z)s)
CudaNdarray_HOST_STRIDES(%(unis)s)[0],
CudaNdarray_DEV_DATA(%(z)s),
CudaNdarray_HOST_STRIDES(%(z)s)[0],
CudaNdarray_HOST_STRIDES(%(z)s)[1]
);
CNDA_THREAD_SYNC;
cudaError_t sts = cudaGetLastError();
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论