提交 4d6bae3e authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Add acc_dtype to the interface of prod/sum/mean

上级 fffe6d61
......@@ -1826,13 +1826,15 @@ class _tensor_py_operators:
dot = __dot__
def sum(self, axis=None, dtype=None, keepdims=False):
def sum(self, axis=None, dtype=None, acc_dtype=None, keepdims=False):
"""See `theano.tensor.sum`"""
return sum(self, axis=axis, dtype=dtype, keepdims=keepdims)
return sum(self, axis=axis, dtype=dtype, acc_dtype=acc_dtype,
keepdims=keepdims)
def prod(self, axis=None, dtype=None, keepdims=False):
def prod(self, axis=None, dtype=None, acc_dtype=None, keepdims=False):
"""See `theano.tensor.prod`"""
return prod(self, axis=axis, dtype=dtype, keepdims=keepdims)
return prod(self, axis=axis, dtype=dtype, acc_dtype=acc_dtype,
keepdims=keepdims)
def norm(self, L, axis=None):
if L == 0:
......@@ -1842,9 +1844,10 @@ class _tensor_py_operators:
# optimizations will/should catch cases like L=1, L=2
return pow(pow(abs_(self), L).sum(axis=axis), 1.0 / L)
def mean(self, axis=None, dtype=None, keepdims=False):
def mean(self, axis=None, dtype=None, acc_dtype=None, keepdims=False):
"""See `theano.tensor.mean`"""
return mean(self, axis=axis, dtype=dtype, keepdims=keepdims)
return mean(self, axis=axis, dtype=dtype, acc_dtype=acc_dtype,
keepdims=keepdims)
def var(self, axis=None, keepdims=False):
"""See `theano.tensor.var`"""
......@@ -3777,7 +3780,7 @@ pprint.assign(tensor_copy, printing.IgnorePrinter())
@constructor
def sum(input, axis=None, dtype=None, keepdims=False):
def sum(input, axis=None, dtype=None, acc_dtype=None, keepdims=False):
"""
Computes the sum along the given axis(es) of a tensor `input`
......@@ -3790,10 +3793,10 @@ def sum(input, axis=None, dtype=None, keepdims=False):
For full documentation see ``tensor.elemwise.Sum``.
In particular please pay attention to the important warning when using
a custom dtype.
a custom acc_dtype.
"""
out = elemwise.Sum(axis=axis, dtype=dtype)(input)
out = elemwise.Sum(axis=axis, dtype=dtype, acc_dtype=acc_dtype)(input)
if keepdims:
out = makeKeepDims(input, out, axis)
......@@ -3803,7 +3806,7 @@ pprint.assign(Sum(), printing.FunctionPrinter('sum'))
@constructor
def prod(input, axis=None, dtype=None, keepdims=False):
def prod(input, axis=None, dtype=None, acc_dtype=None, keepdims=False):
"""
Computes the product along the given axis(es) of a tensor `input`
......@@ -3817,7 +3820,7 @@ def prod(input, axis=None, dtype=None, keepdims=False):
For full documentation see ``tensor.elemwise.Prod``.
"""
out = elemwise.Prod(axis, dtype=dtype)(input)
out = elemwise.Prod(axis, dtype=dtype, acc_dtype=acc_dtype)(input)
if keepdims:
out = makeKeepDims(input, out, axis)
......@@ -3868,7 +3871,8 @@ class Mean(elemwise.CAReduce):
@constructor
def mean(input, axis=None, dtype=None, op=False, keepdims=False):
def mean(input, axis=None, dtype=None, acc_dtype=None, op=False,
keepdims=False):
"""
Computes the mean value along the given axis(es) of a tensor `input`
......@@ -3876,13 +3880,19 @@ def mean(input, axis=None, dtype=None, op=False, keepdims=False):
None means all axes (like numpy).
:type axis: None or int or (list of int) (see `Sum`)
:param dtype: dtype to use for the inner summation. This will not
:param acc_dtype: dtype to use for the inner summation. This will not
necessarily be the dtype of the output (in particular
if it is a discrete (int/uint) dtype, the output will
be in a float type).
If None, then we use the same rules as `sum()`.
:type dtype: None or string
:param dtype: dtype to cast the result of the inner summation into.
For instance, by default, a sum of a float32 tensor will be
done in float64 (acc_dtype would be float64 by default),
but that result will be casted back in float32.
:type dtype: None or string
:param keepdims: If this is set to True, the axes which are reduced are
left in the result as dimensions with size one. With this option,
the result will broadcast correctly against the original tensor.
......@@ -3898,6 +3908,12 @@ def mean(input, axis=None, dtype=None, op=False, keepdims=False):
'and will always use float64. If you want to specify '
'the dtype, call tensor.mean(..., op=False).',
dtype)
if acc_dtype not in (None, 'float64'):
raise NotImplementedError(
'The Mean op does not support the acc_dtype argument, '
'and will always use float64. If you want to specify '
'acc_dtype, call tensor.mean(..., op=False).',
dtype)
out = Mean(axis)(input)
if keepdims:
out = makeKeepDims(input, out, axis)
......@@ -3911,7 +3927,8 @@ def mean(input, axis=None, dtype=None, op=False, keepdims=False):
# Let sum() infer the appropriate dtype.
sum_dtype = None
s = sum(input, axis=axis, dtype=sum_dtype, keepdims=keepdims)
s = sum(input, axis=axis, dtype=sum_dtype, acc_dtype=acc_dtype,
keepdims=keepdims)
shp = shape(input)
# Cast shp into a float type
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论