提交 be375941 authored 作者: Frederic Bastien's avatar Frederic Bastien

Speed up choice with replace on the GPU.

上级 bb27565a
......@@ -392,9 +392,12 @@ KERNEL void k_multi_warp_multinomial_wor(
PyErr_Format(PyExc_ValueError, "unis.shape[0] != pvals.shape[0] * n");
%(fail)s
}
pvals_copy = pygpu_copy(pvals, GA_C_ORDER);
if (! %(replace)s) {
pvals_copy = pygpu_copy(pvals, GA_C_ORDER);
} else {
pvals_copy = pvals;
Py_INCREF(pvals_copy);
}
dims[0] = n_samples;
dims[1] = PyGpuArray_DIMS(pvals)[0];
......@@ -474,7 +477,7 @@ KERNEL void k_multi_warp_multinomial_wor(
return s
def c_code_cache_version(self):
return (6,)
return (7,)
@register_opt('fast_compile')
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论