提交 ef8d7053 authored 作者: Pascal Lamblin's avatar Pascal Lamblin

More accurate equality check for RandomState's state.

上级 900146e4
...@@ -34,13 +34,23 @@ class RandomStateType(gof.Type): ...@@ -34,13 +34,23 @@ class RandomStateType(gof.Type):
def values_eq(self, a, b): def values_eq(self, a, b):
sa = a.get_state() sa = a.get_state()
sb = b.get_state() sb = b.get_state()
for aa, bb in zip(sa, sb): # Should always be the string 'MT19937'
if isinstance(aa, numpy.ndarray): if sa[0] != sb[0]:
if not numpy.all(aa == bb): return False
return False # 1-D array of 624 unsigned integer keys
else: if not numpy.all(sa[1] == sb[1]):
if not aa == bb: return False
return False # integer "pos" representing the position in the array
if sa[2] != sb[2]:
return False
# integer "has_gauss"
if sa[3] != sb[3]:
return False
# float "cached_gaussian".
# /!\ It is not initialized if has_gauss == 0
if sa[3] != 0:
if sa[4] != sb[4]:
return False
return True return True
random_state_type = RandomStateType() random_state_type = RandomStateType()
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论