提交 d7cf98bf authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Add 'dtype' keyword arg in tensor.mean().

This enables specifying the dtype in which the internal summation should be done, like in Numpy.
上级 8c02d5a6
...@@ -1430,9 +1430,9 @@ class _tensor_py_operators: ...@@ -1430,9 +1430,9 @@ class _tensor_py_operators:
#optimizations will/should catch cases like L=1, L=2 #optimizations will/should catch cases like L=1, L=2
return pow(pow(abs_(self), L).sum(axis=axis), 1.0/L) return pow(pow(abs_(self), L).sum(axis=axis), 1.0/L)
def mean(self, axis=None): def mean(self, axis=None, dtype=None):
"""See `theano.tensor.mean`""" """See `theano.tensor.mean`"""
return mean(self, axis) return mean(self, axis=axis, dtype=dtype)
def var(self, axis=None): def var(self, axis=None):
"""See `theano.tensor.var`""" """See `theano.tensor.var`"""
...@@ -2668,35 +2668,54 @@ class Mean(elemwise.CAReduce): ...@@ -2668,35 +2668,54 @@ class Mean(elemwise.CAReduce):
# return grad(mean(x, self.axis, op=False),[x]) # return grad(mean(x, self.axis, op=False),[x])
@constructor @constructor
def mean(input, axis = None, op = False): def mean(input, axis=None, dtype=None, op=False):
"""Compute the mean value along the given axis of a tensor `input` """Compute the mean value along the given axis of a tensor `input`
:param axis: compute the mean along this axis of the tensor. :param axis: compute the mean along this axis of the tensor.
None means all axes (like numpy). None means all axes (like numpy).
:type axis: None or int or (list of int) (see `Sum`) :type axis: None or int or (list of int) (see `Sum`)
:note: for gpu, if you manually cast the input to float32 before calling :param dtype: dtype to use for the inner summation. This will not
mean, everything will be done on the gpu. necessarily be the dtype of the output (in particular
if it is a discrete (int/uint) dtype, the output will
be in a float type)
:type dtype: string
:note: for gpu, if you specify dtype=float32, everything will be done
on the gpu.
""" """
if op: if op:
return Mean(axis)(input) return Mean(axis)(input)
if str(input.dtype) in discrete_dtypes: if dtype is not None:
# The summation will be done with the specified dtype.
# sum() will complain if it is not suitable.
sum_dtype = dtype
elif input.dtype in discrete_dtypes:
# we need to cast eventually anyway, and this helps # we need to cast eventually anyway, and this helps
# to prevents overflow # to prevents overflow. Numpy uses 'float64'.
input = cast(input, 'float64') # TODO: use floatX? let casting_policy decide?
s = sum(input, axis) sum_dtype = 'float64'
else:
# Let sum() infer the appropriate dtype
sum_dtype = None
s = sum(input, axis=axis, dtype=sum_dtype)
shp = shape(input) shp = shape(input)
if input.dtype == 'float32':
# Cast shp into a float type
if s.dtype in ('float32', 'complex64'):
shp = cast(shp, 'float32') shp = cast(shp, 'float32')
else:
shp = cast(shp, 'float64')
if axis is None: if axis is None:
axis = range(input.ndim) axis = range(input.ndim)
elif isinstance(axis, int): elif isinstance(axis, int):
axis = [axis] axis = [axis]
for i in axis: for i in axis:
s = s / shp[i] s = s / shp[i]
if str(input.dtype).startswith('float'):
assert input.dtype == s.dtype
return s return s
@constructor @constructor
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论