提交 399cff79 authored 作者: james@X40's avatar james@X40

fixes to make RandomFunction work on 32bit

上级 e68f24dc
...@@ -118,7 +118,7 @@ class RandomFunction(gof.Op): ...@@ -118,7 +118,7 @@ class RandomFunction(gof.Op):
shape = tensor.as_tensor(shape, ndim=1) shape = tensor.as_tensor(shape, ndim=1)
#print 'SHAPE TYPE', shape.type, tensor.lvector #print 'SHAPE TYPE', shape.type, tensor.lvector
assert shape.type.ndim == 1 assert shape.type.ndim == 1
assert shape.type.dtype == 'int64' assert (shape.type.dtype == 'int64') or (shape.type.dtype == 'int32')
if not isinstance(r.type, RandomStateType): if not isinstance(r.type, RandomStateType):
print >> sys.stderr, 'WARNING: RandomState instances should be in RandomStateType' print >> sys.stderr, 'WARNING: RandomState instances should be in RandomStateType'
if 0: if 0:
......
...@@ -128,9 +128,11 @@ class RandomStreamsInstance(object): ...@@ -128,9 +128,11 @@ class RandomStreamsInstance(object):
old_r_seed = seedgen.randint(2**30) old_r_seed = seedgen.randint(2**30)
old_r_container = self.memo[old_r].value old_r_container = self.memo[old_r].value
if old_r_container.value is None: if old_r_container.value is None:
old_r_container.value = numpy.random.RandomState(old_r_seed) #the cast to int here makes it work on 32bit machines, not sure why
old_r_container.value = numpy.random.RandomState(int(old_r_seed))
else: else:
old_r_container.value.seed(old_r_seed) #the cast to int here makes it work on 32bit machines, not sure why
old_r_container.value.seed(int(old_r_seed))
def __getitem__(self, item): def __getitem__(self, item):
"""Retrieve the numpy RandomState instance associated with a particular stream """Retrieve the numpy RandomState instance associated with a particular stream
......
...@@ -25,7 +25,7 @@ class T_RandomStreams(unittest.TestCase): ...@@ -25,7 +25,7 @@ class T_RandomStreams(unittest.TestCase):
gn_val0 = made.gn() gn_val0 = made.gn()
rng_seed = numpy.random.RandomState(234).randint(2**30) rng_seed = numpy.random.RandomState(234).randint(2**30)
rng = numpy.random.RandomState(rng_seed) rng = numpy.random.RandomState(int(rng_seed)) #int() is for 32bit
#print fn_val0 #print fn_val0
numpy_val0 = rng.uniform(size=(2,2)) numpy_val0 = rng.uniform(size=(2,2))
...@@ -46,7 +46,7 @@ class T_RandomStreams(unittest.TestCase): ...@@ -46,7 +46,7 @@ class T_RandomStreams(unittest.TestCase):
fn_val1 = made.fn() fn_val1 = made.fn()
rng_seed = numpy.random.RandomState(888).randint(2**30) rng_seed = numpy.random.RandomState(888).randint(2**30)
rng = numpy.random.RandomState(rng_seed) rng = numpy.random.RandomState(int(rng_seed)) #int() is for 32bit
#print fn_val0 #print fn_val0
numpy_val0 = rng.uniform(size=(2,2)) numpy_val0 = rng.uniform(size=(2,2))
...@@ -69,7 +69,7 @@ class T_RandomStreams(unittest.TestCase): ...@@ -69,7 +69,7 @@ class T_RandomStreams(unittest.TestCase):
fn_val1 = made.fn() fn_val1 = made.fn()
rng_seed = numpy.random.RandomState(888).randint(2**30) rng_seed = numpy.random.RandomState(888).randint(2**30)
rng = numpy.random.RandomState(rng_seed) rng = numpy.random.RandomState(int(rng_seed)) #int() is for 32bit
#print fn_val0 #print fn_val0
numpy_val0 = rng.uniform(size=(2,2)) numpy_val0 = rng.uniform(size=(2,2))
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论