提交 7786f27e authored 作者: James Bergstra's avatar James Bergstra

multinomial - workaround for numpy sampling bug

上级 8c256357
......@@ -607,8 +607,28 @@ def multinomial_helper(random_state, n, pvals, size):
out = numpy.ndarray(out_size)
broadcast_ind = _generate_broadcasting_indices(size, n.shape, pvals.shape[:-1])
# Iterate over these indices, drawing from one multinomial at a time from numpy
assert pvals.min() >= 0
for mi, ni, pi in zip(*broadcast_ind):
out[mi] = random_state.multinomial(n=n[ni], pvals=pvals[pi])
pvi = pvals[pi]
# This might someday be fixed upstream
# Currently numpy raises an exception in this method if the sum
# of probabilities meets or exceeds 1.0.
# In perfect arithmetic this would be correct, but in float32 or
# float64 it is too strict.
pisum = numpy.sum(pvi)
if 1.0 < pisum < 1.0+1e-5:#correct if we went a little over
# because mtrand.pyx has a ValueError that will trigger if
# sum(pvals[:-1]) > 1.0
pvi = pvi * (1.0 - 5e-5)
#pvi = pvi * .9
pisum = numpy.sum(pvi)
elif pvi[-1]<5e-5: #will this even work?
pvi = pvi * (1.0 - 5e-5)
pisum = numpy.sum(pvi)
assert pisum<=1.0, pisum
out[mi] = random_state.multinomial(n=n[ni],
pvals=pvi.astype('float64'))
return out
def multinomial(random_state, size=None, n=1, pvals=[0.5, 0.5],
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论