提交 213f1f02 authored 作者: Laurent Dinh's avatar Laurent Dinh

Fix arg order and optimization

上级 6469a825
...@@ -368,6 +368,7 @@ KERNEL void k_multi_warp_multinomial_wor( ...@@ -368,6 +368,7 @@ KERNEL void k_multi_warp_multinomial_wor(
def c_code(self, node, name, inp, outputs, sub): def c_code(self, node, name, inp, outputs, sub):
pvals, unis, n = inp pvals, unis, n = inp
out, = outputs out, = outputs
replace = int(self.replace)
fail = sub['fail'] fail = sub['fail']
ctx = sub['params'] ctx = sub['params']
sync = bool(config.gpuarray.sync) sync = bool(config.gpuarray.sync)
...@@ -470,7 +471,7 @@ KERNEL void k_multi_warp_multinomial_wor( ...@@ -470,7 +471,7 @@ KERNEL void k_multi_warp_multinomial_wor(
nb_threads2[1] = 1; nb_threads2[1] = 1;
// If we can't schedule enough threads parallelize the renormalization. // If we can't schedule enough threads parallelize the renormalization.
// I do this because we don't always use those extra threads. // I do this because we don't always use those extra threads.
if (nb_threads * nb_blocks < 2048) if ((nb_threads * nb_blocks < 2048) && %(replace)d )
nb_threads2[1] = 1024 / nb_threads; nb_threads2[1] = 1024 / nb_threads;
nb_blocks2[0] = nb_blocks; nb_blocks2[0] = nb_blocks;
......
...@@ -221,9 +221,9 @@ class ChoiceFromUniform(MultinomialFromUniform): ...@@ -221,9 +221,9 @@ class ChoiceFromUniform(MultinomialFromUniform):
__props__ = ("replace",) __props__ = ("replace",)
def __init__(self, replace=False, *args, **kwargs): def __init__(self, odtype, replace=False, *args, **kwargs):
self.replace = replace self.replace = replace
super(ChoiceFromUniform, self).__init__(*args, **kwargs) super(ChoiceFromUniform, self).__init__(odtype=odtype, *args, **kwargs)
def __setstate__(self, state): def __setstate__(self, state):
self.__dict__.update(state) self.__dict__.update(state)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论