提交 fe17c745 authored 作者: James Bergstra's avatar James Bergstra

further maxint-related corrections

上级 3ac2f252
......@@ -212,7 +212,7 @@ class RandomKit(SymbolicInputKit):
return out
def distribute(self, value, indices, containers):
rg = partial(numpy.random.RandomState(value).randint, 2**32)
rg = partial(numpy.random.RandomState(value).randint, 2**30)
elems = deque(zip(indices, containers))
i = 0
while elems:
......@@ -270,7 +270,7 @@ class RModule(compile.Module):
# and a list of corresponding gof.Container instances. In this
# situation it will reseed all the rngs using the containers
# associated to them.
c._rkit.kit.distribute(seedgen.random_integers(2**31),
c._rkit.kit.distribute(seedgen.random_integers(2**30),
xrange(len(inst2._rkit)), inst2._rkit)
else:
self._rkit.kit.distribute(seedgen.random_integers(2**31), xrange(len(inst._rkit)), inst._rkit)
self._rkit.kit.distribute(seedgen.random_integers(2**30), xrange(len(inst._rkit)), inst._rkit)
......@@ -484,7 +484,7 @@ def test_naacl_model(optimizer='fast_run'):
m.pretraining_update(*inputs)
s0, s1 = [str(i) for i in m.pretraining_update(*inputs)]
print s0, s1
if s0 + ' ' + s1 != '0.315775007436 0.132479386981':
if s0 + ' ' + s1 != '0.402187608584 0.0744508017774':
raise ValueError('pretraining update values do not match')
print 'FINETUNING GRAPH'
print 'SUPERVISED PHASE COSTS (%s)'%optimizer
......@@ -493,7 +493,7 @@ def test_naacl_model(optimizer='fast_run'):
m.finetuning_update(*(inputs + [targets]))
s0 = str(m.finetuning_update(*(inputs + [targets])))
print s0
if s0 != '15.8609933666':
if s0 != '15.6512776369':
raise ValueError('finetuning values do not match')
if __name__ == '__main__':
......
......@@ -50,15 +50,15 @@ def test_B():
#print m.f(N.ones(5))
#print m.f(N.ones(5))
#print m.f(N.ones(5))
rvals = ["0.0655889727823 0.566937256035 0.486897708861 0.939594224804 0.731948448071",
"0.407174827663 0.450046718267 0.454825370073 0.874814293401 0.828759935744",
"0.573194634066 0.746015418896 0.864696705461 0.8405810785 0.540268740918",
"0.924477905238 0.96687901023 0.306490321744 0.654349923901 0.789402591813",
"0.513182053208 0.0426565286449 0.0723651478047 0.454308519009 0.86151064181"]
rvals = ["0.74802375876 0.872308123517 0.294830748897 0.803123780003 0.6321109955",
"0.00168744844365 0.278638315678 0.725436793755 0.7788480779 0.629885140994",
"0.545561221664 0.0992011009108 0.847112593242 0.188015424144 0.158046201298",
"0.054382248842 0.563459168529 0.192757276954 0.360455221883 0.174805216702",
"0.961942907777 0.49657319422 0.0316111492826 0.0915054717012 0.195877184515"]
for i in xrange(5):
s = " ".join([str(n) for n in m.f(N.ones(5))])
print s
assert s == rvals[i]
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论