提交 85b28c25 authored 作者: James Bergstra's avatar James Bergstra

rng_mrg - use numpy.seterr to suppress overflow warnings

上级 9f5955d0
...@@ -31,11 +31,15 @@ def mulmod(a, b, c, m): ...@@ -31,11 +31,15 @@ def mulmod(a, b, c, m):
def matVecModM(A, s, m): def matVecModM(A, s, m):
# return (A * s) % m # return (A * s) % m
x = numpy.zeros_like(s) err_orig = numpy.seterr(over='ignore')
for i in xrange(len(x)): try:
for j in xrange(len(s)): x = numpy.zeros_like(s)
x[i] = mulmod(A[i][j], s[j], x[i], m) for i in xrange(len(x)):
return x for j in xrange(len(s)):
x[i] = mulmod(A[i][j], s[j], x[i], m)
return x
finally:
numpy.seterr(**err_orig)
def multMatVect(v, A, m1, B, m2): def multMatVect(v, A, m1, B, m2):
#multiply the first half of v by A with a modulo of m1 #multiply the first half of v by A with a modulo of m1
...@@ -81,54 +85,58 @@ def ff_2p72(rstate): ...@@ -81,54 +85,58 @@ def ff_2p72(rstate):
return multMatVect(rstate, A1p72, M1, A2p72, M2) return multMatVect(rstate, A1p72, M1, A2p72, M2)
def mrg_next_value(rstate, new_rstate): def mrg_next_value(rstate, new_rstate):
x11, x12, x13, x21, x22, x23 = rstate err_orig = numpy.seterr(over='ignore')
assert type(x11) == numpy.int32 try:
x11, x12, x13, x21, x22, x23 = rstate
i0, i7, i9, i15, i16, i22, i24 = [numpy.int32(i) assert type(x11) == numpy.int32
for i in (0,7, 9, 15, 16, 22, 24)]
i0, i7, i9, i15, i16, i22, i24 = [numpy.int32(i)
#first component for i in (0,7, 9, 15, 16, 22, 24)]
y1 = (((x12 & MASK12) << i22) + (x12 >> i9)
+ ((x13 & MASK13) << i7) + (x13 >> i24)) #first component
y1 = (((x12 & MASK12) << i22) + (x12 >> i9)
assert type(y1) == numpy.int32 + ((x13 & MASK13) << i7) + (x13 >> i24))
if (y1 < 0 or y1 >= M1): #must also check overflow
y1 -= M1; assert type(y1) == numpy.int32
y1 += x13; if (y1 < 0 or y1 >= M1): #must also check overflow
if (y1 < 0 or y1 >= M1): y1 -= M1;
y1 -= M1; y1 += x13;
if (y1 < 0 or y1 >= M1):
x13 = x12; y1 -= M1;
x12 = x11;
x11 = y1; x13 = x12;
x12 = x11;
#second component x11 = y1;
y1 = ((x21 & MASK2) << i15) + (MULT2 * (x21 >> i16));
assert type(y1) == numpy.int32 #second component
if (y1 < 0 or y1 >= M2): y1 = ((x21 & MASK2) << i15) + (MULT2 * (x21 >> i16));
y1 -= M2; assert type(y1) == numpy.int32
y2 = ((x23 & MASK2) << i15) + (MULT2 * (x23 >> i16)); if (y1 < 0 or y1 >= M2):
assert type(y2) == numpy.int32 y1 -= M2;
if (y2 < 0 or y2 >= M2): y2 = ((x23 & MASK2) << i15) + (MULT2 * (x23 >> i16));
y2 -= M2; assert type(y2) == numpy.int32
y2 += x23; if (y2 < 0 or y2 >= M2):
if (y2 < 0 or y2 >= M2): y2 -= M2;
y2 -= M2; y2 += x23;
y2 += y1; if (y2 < 0 or y2 >= M2):
if (y2 < 0 or y2 >= M2): y2 -= M2;
y2 -= M2; y2 += y1;
if (y2 < 0 or y2 >= M2):
x23 = x22; y2 -= M2;
x22 = x21;
x21 = y2; x23 = x22;
x22 = x21;
# Must never return either 0 or M1+1 x21 = y2;
new_rstate[...] = [x11, x12, x13, x21, x22, x23]
assert new_rstate.dtype == numpy.int32 # Must never return either 0 or M1+1
if (x11 <= x21): new_rstate[...] = [x11, x12, x13, x21, x22, x23]
return (x11 - x21 + M1) * NORM assert new_rstate.dtype == numpy.int32
else: if (x11 <= x21):
return (x11 - x21) * NORM return (x11 - x21 + M1) * NORM
else:
return (x11 - x21) * NORM
finally:
numpy.seterr(**err_orig)
class mrg_uniform_base(Op): class mrg_uniform_base(Op):
def __init__(self, output_type, inplace=False): def __init__(self, output_type, inplace=False):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论