提交 0e861a54 authored 作者: Pascal Lamblin's avatar Pascal Lamblin

New keyword to explicitly choose a number of streams in uniform()

上级 008ce348
...@@ -49,7 +49,7 @@ MASK12 = numpy.int32(511) #2^9 - 1 ...@@ -49,7 +49,7 @@ MASK12 = numpy.int32(511) #2^9 - 1
MASK13 = numpy.int32(16777215) #2^24 - 1 MASK13 = numpy.int32(16777215) #2^24 - 1
MASK2 = numpy.int32(65535) #2^16 - 1 MASK2 = numpy.int32(65535) #2^16 - 1
MULT2 = numpy.int32(21069) MULT2 = numpy.int32(21069)
NORM = 4.656612873077392578125e-10; NORM = 4.656612873077392578125e-10; #1./2^31
A1p0 = numpy.asarray([[0, 4194304, 129], [1, 0, 0], [0, 1, 0]]) A1p0 = numpy.asarray([[0, 4194304, 129], [1, 0, 0], [0, 1, 0]])
A2p0 = numpy.asarray([[32768, 0, 32769], [1, 0, 0], [0, 1, 0]]) A2p0 = numpy.asarray([[32768, 0, 32769], [1, 0, 0], [0, 1, 0]])
...@@ -593,6 +593,7 @@ class MRG_RandomStreams(object): ...@@ -593,6 +593,7 @@ class MRG_RandomStreams(object):
return rval return rval
def n_streams(self, size): def n_streams(self, size):
# TODO: a smart way of choosing the number of streams
if isinstance(size, (tuple, list)): if isinstance(size, (tuple, list)):
r = 1 r = 1
for s in size: for s in size:
...@@ -601,12 +602,7 @@ class MRG_RandomStreams(object): ...@@ -601,12 +602,7 @@ class MRG_RandomStreams(object):
return r/6 # chosen as fastest for rbm_benchmark return r/6 # chosen as fastest for rbm_benchmark
else: else:
return r return r
try:
rval = int(size)
assert rval > 0
return rval
except:
pass
print >> sys.stderr, "MRG_RandomStreams Can't determine #streams from size (%s), guessing 30*256"%str(size) print >> sys.stderr, "MRG_RandomStreams Can't determine #streams from size (%s), guessing 30*256"%str(size)
return 30*256 return 30*256
...@@ -616,7 +612,7 @@ class MRG_RandomStreams(object): ...@@ -616,7 +612,7 @@ class MRG_RandomStreams(object):
node_rstate.default_update = new_rstate node_rstate.default_update = new_rstate
return sample return sample
def uniform(self, size=None, low=0.0, high=1.0, ndim=None, dtype=config.floatX): def uniform(self, size=None, low=0.0, high=1.0, ndim=None, dtype=config.floatX, nstreams=None):
""" """
Sample a tensor of given size whose element from a uniform Sample a tensor of given size whose element from a uniform
distribution between low and high. distribution between low and high.
...@@ -625,8 +621,10 @@ class MRG_RandomStreams(object): ...@@ -625,8 +621,10 @@ class MRG_RandomStreams(object):
ndim may be a plain integer to supplement the missing ndim may be a plain integer to supplement the missing
information. information.
""" """
if nstreams is None:
nstreams = self.n_streams(size)
if self.use_cuda and dtype=='float32': if self.use_cuda and dtype=='float32':
rstates = self.get_substream_rstates(self.n_streams(size)) rstates = self.get_substream_rstates(nstreams)
rstates = rstates.flatten() rstates = rstates.flatten()
# HACK - we use fact that int32 and float32 have same size to # HACK - we use fact that int32 and float32 have same size to
# sneak ints into the CudaNdarray type. # sneak ints into the CudaNdarray type.
...@@ -643,11 +641,11 @@ class MRG_RandomStreams(object): ...@@ -643,11 +641,11 @@ class MRG_RandomStreams(object):
u = self.pretty_return(node_rstate, u = self.pretty_return(node_rstate,
*GPU_mrg_uniform.new(node_rstate, ndim, dtype, size)) *GPU_mrg_uniform.new(node_rstate, ndim, dtype, size))
else: else:
node_rstate = shared(self.get_substream_rstates(self.n_streams(size))) node_rstate = shared(self.get_substream_rstates(nstreams))
u = self.pretty_return(node_rstate, u = self.pretty_return(node_rstate,
*mrg_uniform.new(node_rstate, ndim, dtype, size)) *mrg_uniform.new(node_rstate, ndim, dtype, size))
r = u * (high-low) + low r = u * (high-low) + low
if u.type.broadcastable != r.type.broadcastable: if u.type.broadcastable != r.type.broadcastable:
raise NotImplementedError( 'Increase the size to match the broadcasting pattern of `low` and `high` arguments') raise NotImplementedError( 'Increase the size to match the broadcasting pattern of `low` and `high` arguments')
return r return r
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论