提交 66681c99 authored 作者: Olivier Delalleau's avatar Olivier Delalleau

Merged

......@@ -384,7 +384,7 @@ def _generate_broadcasting_indices(out_shape, *shapes):
return ret_indices
def uniform(random_state, size=None, low=0.0, high=1.0, ndim=None, dtype='floatX'):
def uniform(random_state, size=None, low=0.0, high=1.0, ndim=None, dtype=None):
"""
Sample from a uniform distribution between low and high.
......@@ -393,18 +393,21 @@ def uniform(random_state, size=None, low=0.0, high=1.0, ndim=None, dtype='floatX
If size is None, the output shape will be determined by the shapes
of low and high.
If dtype is not specified, it will be inferred from the dtype of
low and high, but will be at least as precise as floatX.
"""
if dtype == 'floatX':
dtype = theano.config.floatX
low = tensor.cast(tensor.as_tensor_variable(low), dtype)
high = tensor.cast(tensor.as_tensor_variable(high), dtype)
low = tensor.as_tensor_variable(low)
high = tensor.as_tensor_variable(high)
if dtype is None:
dtype = tensor.scal.upcast(theano.config.floatX, low.dtype, high.dtype)
ndim, size, bcast = _infer_ndim_bcast(ndim, size, low, high)
op = RandomFunction('uniform',
tensor.TensorType(dtype=dtype, broadcastable=bcast) )
tensor.TensorType(dtype=dtype, broadcastable=bcast))
return op(random_state, size, low, high)
def normal(random_state, size=None, avg=0.0, std=1.0, ndim=None, dtype='floatX'):
def normal(random_state, size=None, avg=0.0, std=1.0, ndim=None, dtype=None):
"""
Sample from a normal distribution centered on avg with
the specified standard deviation (std).
......@@ -414,11 +417,14 @@ def normal(random_state, size=None, avg=0.0, std=1.0, ndim=None, dtype='floatX')
If size is None, the output shape will be determined by the shapes
of avg and std.
If dtype is not specified, it will be inferred from the dtype of
avg and std, but will be at least as precise as floatX.
"""
if dtype == 'floatX':
dtype = theano.config.floatX
avg = tensor.cast(tensor.as_tensor_variable(avg), dtype)
std = tensor.cast(tensor.as_tensor_variable(std), dtype)
avg = tensor.as_tensor_variable(avg)
std = tensor.as_tensor_variable(std)
if dtype == None:
dtype = tensor.scal.upcast(theano.config.floatX, avg.dtype, std.dtype)
ndim, size, bcast = _infer_ndim_bcast(ndim, size, avg, std)
op = RandomFunction('normal',
tensor.TensorType(dtype=dtype, broadcastable=bcast))
......@@ -722,7 +728,7 @@ class RandomStreamsBase(object):
print >> sys.stderr, "DEPRECATION WARNING: the parameter prob to the binomal fct have been renamed to p to have the same name as numpy."
return self.gen(binomial, size, n, p, ndim=ndim, dtype=dtype)
def uniform(self, size=None, low=0.0, high=1.0, ndim=None, dtype=theano.config.floatX):
def uniform(self, size=None, low=0.0, high=1.0, ndim=None, dtype=None):
"""
Sample a tensor of given size whose element from a uniform
distribution between low and high.
......@@ -733,7 +739,7 @@ class RandomStreamsBase(object):
"""
return self.gen(uniform, size, low, high, ndim=ndim, dtype=dtype)
def normal(self, size=None, avg=0.0, std=1.0, ndim=None, dtype=theano.config.floatX):
def normal(self, size=None, avg=0.0, std=1.0, ndim=None, dtype=None):
"""
Sample from a normal distribution centered on avg with
the specified standard deviation (std).
......
......@@ -611,6 +611,33 @@ class T_SharedRandomStreams(unittest.TestCase):
assert val1.dtype == 'int8'
assert numpy.all(abs(val1) <= 1)
def test_default_dtype(self):
random = RandomStreams(utt.fetch_seed())
low = tensor.dscalar()
high = tensor.dscalar()
# Should not silently downcast from low and high
out0 = random.uniform(low=low, high=high, size=(42,))
assert out0.dtype == 'float64'
f0 = function([low, high], out0)
val0 = f0(-2.1, 3.1)
assert val0.dtype == 'float64'
# Should downcast, since asked explicitly
out1 = random.uniform(low=low, high=high, size=(42,), dtype='float32')
assert out1.dtype == 'float32'
f1 = function([low, high], out1)
val1 = f1(-1.1, 1.1)
assert val1.dtype == 'float32'
# Should use floatX
lowf = tensor.fscalar()
highf = tensor.fscalar()
outf = random.uniform(low=lowf, high=highf, size=(42,))
assert outf.dtype == config.floatX
ff = function([lowf, highf], outf)
valf = ff(numpy.float32(-0.1), numpy.float32(0.3))
assert valf.dtype == config.floatX
def test_shared_constructor_borrow(self):
rng = numpy.random.RandomState(123)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论