提交 332cd5d4 authored 作者: Olivier Delalleau's avatar Olivier Delalleau

Merge pull request #439 from lamblin/finish_careduce_dtype

Update code for mean's dtype after code review
......@@ -2703,27 +2703,27 @@ def mean(input, axis=None, dtype=None, op=False):
necessarily be the dtype of the output (in particular
if it is a discrete (int/uint) dtype, the output will
be in a float type).
If None, then we use float64 for a discrete input, and the
same rules as `sum()` for a continuous input.
If None, then we use the same rules as `sum()`.
:type dtype: None or string
:note: for gpu, if you specify dtype=float32, everything will be done
on the gpu.
"""
if op:
if dtype not in (None, 'float64'):
raise NotImplementedError(
'The Mean op does not support the dtype argument, '
'and will always use float64. If you want to specify '
'the dtype, call tensor.mean(..., op=False).',
dtype)
return Mean(axis)(input)
if dtype is not None:
# The summation will be done with the specified dtype.
# sum() will complain if it is not suitable.
sum_dtype = dtype
elif input.dtype in discrete_dtypes:
# we need to cast eventually anyway, and this helps
# to prevents overflow. Numpy uses 'float64'.
# TODO: use floatX? let casting_policy decide?
sum_dtype = 'float64'
else:
# Let sum() infer the appropriate dtype
# Let sum() infer the appropriate dtype.
sum_dtype = None
s = sum(input, axis=axis, dtype=sum_dtype)
......@@ -2741,8 +2741,9 @@ def mean(input, axis=None, dtype=None, op=False):
axis = range(input.ndim)
elif isinstance(axis, int):
axis = [axis]
for i in axis:
s = s / shp[i]
s = true_div(s, shp[i])
return s
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论