提交 ef8d7053 authored 作者: Pascal Lamblin's avatar Pascal Lamblin

More accurate equality check for RandomState's state.

上级 900146e4
......@@ -34,13 +34,23 @@ class RandomStateType(gof.Type):
def values_eq(self, a, b):
sa = a.get_state()
sb = b.get_state()
for aa, bb in zip(sa, sb):
if isinstance(aa, numpy.ndarray):
if not numpy.all(aa == bb):
return False
else:
if not aa == bb:
return False
# Should always be the string 'MT19937'
if sa[0] != sb[0]:
return False
# 1-D array of 624 unsigned integer keys
if not numpy.all(sa[1] == sb[1]):
return False
# integer "pos" representing the position in the array
if sa[2] != sb[2]:
return False
# integer "has_gauss"
if sa[3] != sb[3]:
return False
# float "cached_gaussian".
# /!\ It is not initialized if has_gauss == 0
if sa[3] != 0:
if sa[4] != sb[4]:
return False
return True
random_state_type = RandomStateType()
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论