提交 4b00bd50 authored 作者: Frederic Bastien's avatar Frederic Bastien

fix mrg(use_cuda=True).binomial(dtype=float32) to work on the gpu event if floatX is not float32

上级 a51132f6
......@@ -666,7 +666,10 @@ class MRG_RandomStreams(object):
def binomial(self, size=None, n=1, p=0.5, ndim=None, dtype='int64'):
if n == 1:
return cast(self.uniform(size=size) < p, dtype)
if dtype=='float32' and self.use_cuda:
return cast(self.uniform(size=size, dtype=dtype) < p, dtype)
else:
return cast(self.uniform(size=size) < p, dtype)
else:
raise NotImplementedError("MRG_RandomStreams.binomial with n > 1")
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论