提交 d5b67afd authored 作者: Olivier Delalleau's avatar Olivier Delalleau

Resolved #687: fixed output dtype of uniform and normal random samples

This output dtype is now always of the type given by the dtype argument, and the other variables used in the computation (e.g. low & high for uniform) are now cast to this dtype to make sure they do not upcast the result.
上级 d6780ea5
...@@ -396,30 +396,15 @@ def uniform(random_state, size=None, low=0.0, high=1.0, ndim=None, dtype='floatX ...@@ -396,30 +396,15 @@ def uniform(random_state, size=None, low=0.0, high=1.0, ndim=None, dtype='floatX
""" """
if dtype == 'floatX': if dtype == 'floatX':
dtype = theano.config.floatX dtype = theano.config.floatX
# Handle special case of untyped Python int / float: we cast them into low = tensor.cast(tensor.as_tensor_variable(low), dtype)
# `dtype` to ensure they cannot upcast the end result accidentally. high = tensor.cast(tensor.as_tensor_variable(high), dtype)
def as_tensor(x):
if isinstance(x, int) or isinstance(x, float):
return tensor.constant(x, dtype=dtype)
else:
return tensor.as_tensor_variable(x)
low = as_tensor(low)
high = as_tensor(high)
ndim, size, bcast = _infer_ndim_bcast(ndim, size, low, high) ndim, size, bcast = _infer_ndim_bcast(ndim, size, low, high)
out_dtype = tensor.scal.upcast(dtype, low.dtype, high.dtype)
# It would be confusing if the resulting dtype was not the same as the one
# passed as argument to the function.
assert out_dtype == dtype, (
"Output dtype should be %s, but it would be %s given the provided "
"low (%s) and high (%s) arguments: cast those into a more "
"appropriate type to solve this issue" %
(dtype, out_dtype, low.dtype, high.dtype))
op = RandomFunction('uniform', op = RandomFunction('uniform',
tensor.TensorType(dtype=out_dtype, broadcastable=bcast) ) tensor.TensorType(dtype=dtype, broadcastable=bcast) )
return op(random_state, size, low, high) return op(random_state, size, low, high)
def normal(random_state, size=None, avg=0.0, std=1.0, ndim=None, dtype=theano.config.floatX): def normal(random_state, size=None, avg=0.0, std=1.0, ndim=None, dtype='floatX'):
""" """
Sample from a normal distribution centered on avg with Sample from a normal distribution centered on avg with
the specified standard deviation (std). the specified standard deviation (std).
...@@ -430,10 +415,11 @@ def normal(random_state, size=None, avg=0.0, std=1.0, ndim=None, dtype=theano.co ...@@ -430,10 +415,11 @@ def normal(random_state, size=None, avg=0.0, std=1.0, ndim=None, dtype=theano.co
If size is None, the output shape will be determined by the shapes If size is None, the output shape will be determined by the shapes
of avg and std. of avg and std.
""" """
avg = tensor.as_tensor_variable(avg) if dtype == 'floatX':
std = tensor.as_tensor_variable(std) dtype = theano.config.floatX
avg = tensor.cast(tensor.as_tensor_variable(avg), dtype)
std = tensor.cast(tensor.as_tensor_variable(std), dtype)
ndim, size, bcast = _infer_ndim_bcast(ndim, size, avg, std) ndim, size, bcast = _infer_ndim_bcast(ndim, size, avg, std)
dtype = tensor.scal.upcast(dtype, avg.dtype, std.dtype)
op = RandomFunction('normal', op = RandomFunction('normal',
tensor.TensorType(dtype=dtype, broadcastable=bcast)) tensor.TensorType(dtype=dtype, broadcastable=bcast))
return op(random_state, size, avg, std) return op(random_state, size, avg, std)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论