提交 c7ec0733 authored 作者: Olivier Delalleau's avatar Olivier Delalleau

Make sure the output dtype of uniform is the one given as dtype parameter, and…

Make sure the output dtype of uniform is the one given as dtype parameter, and that normal Python int/float do not upcast it
上级 86164f8c
...@@ -384,7 +384,7 @@ def _generate_broadcasting_indices(out_shape, *shapes): ...@@ -384,7 +384,7 @@ def _generate_broadcasting_indices(out_shape, *shapes):
return ret_indices return ret_indices
def uniform(random_state, size=None, low=0.0, high=1.0, ndim=None, dtype=theano.config.floatX): def uniform(random_state, size=None, low=0.0, high=1.0, ndim=None, dtype='floatX'):
""" """
Sample from a uniform distribution between low and high. Sample from a uniform distribution between low and high.
...@@ -394,12 +394,28 @@ def uniform(random_state, size=None, low=0.0, high=1.0, ndim=None, dtype=theano. ...@@ -394,12 +394,28 @@ def uniform(random_state, size=None, low=0.0, high=1.0, ndim=None, dtype=theano.
If size is None, the output shape will be determined by the shapes If size is None, the output shape will be determined by the shapes
of low and high. of low and high.
""" """
low = tensor.as_tensor_variable(low) if dtype == 'floatX':
high = tensor.as_tensor_variable(high) dtype = theano.config.floatX
# Handle special case of untyped Python int / float: we cast them into
# `dtype` to ensure they cannot upcast the end result accidentally.
def as_tensor(x):
if isinstance(x, int) or isinstance(x, float):
return tensor.constant(x, dtype=dtype)
else:
return tensor.as_tensor_variable(x)
low = as_tensor(low)
high = as_tensor(high)
ndim, size, bcast = _infer_ndim_bcast(ndim, size, low, high) ndim, size, bcast = _infer_ndim_bcast(ndim, size, low, high)
dtype = tensor.scal.upcast(dtype, low.dtype, high.dtype) out_dtype = tensor.scal.upcast(dtype, low.dtype, high.dtype)
# It would be confusing if the resulting dtype was not the same as the one
# passed as argument to the function.
assert out_dtype == dtype, (
"Output dtype should be %s, but it would be %s given the provided "
"low (%s) and high (%s) arguments: cast those into a more "
"appropriate type to solve this issue" %
(dtype, out_dtype, low.dtype, high.dtype))
op = RandomFunction('uniform', op = RandomFunction('uniform',
tensor.TensorType(dtype=dtype, broadcastable=bcast) ) tensor.TensorType(dtype=out_dtype, broadcastable=bcast) )
return op(random_state, size, low, high) return op(random_state, size, low, high)
def binomial(random_state, size=None, n=1, p=0.5, ndim=None, dtype='int64', prob=None): def binomial(random_state, size=None, n=1, p=0.5, ndim=None, dtype='int64', prob=None):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论