提交 d43eb7e2 authored 作者: Kumar Krishna Agrawal's avatar Kumar Krishna Agrawal

Added support for ddof, corrected in var(), std(). Fixed docstring

上级 d311ac14
...@@ -3186,16 +3186,17 @@ def var(input, axis=None, ddof=0, keepdims=False, corrected=False): ...@@ -3186,16 +3186,17 @@ def var(input, axis=None, ddof=0, keepdims=False, corrected=False):
Notes Notes
----- -----
By default, uses the two-pass algorithm for more stable results. Default uses the two-pass algorithm (reference below).
https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Two-pass_algorithm https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Two-pass_algorithm
Also supports 'corrected_two_pass' algorithm, as mentioned above. Also supports 'corrected_two_pass' algorithm (using the 'corrected' flag)
There exist other implementations that are even more stable, but probably which is numerically more stable. There exist other implementations that
slower. offer better stability, but probably slower.
""" """
if isinstance(ddof, (bool)): if isinstance(ddof, (bool)):
raise ValueError('Parameter keepdims is now at index 3: (input, axis=None, ddof=0, keepdims=False)') raise ValueError('Parameter keepdims is now at index 3: (input, \
axis=None, ddof=0, keepdims=False, corrected=False)')
input_ndim = input.type.ndim input_ndim = input.type.ndim
if axis is None: if axis is None:
...@@ -3239,34 +3240,42 @@ def var(input, axis=None, ddof=0, keepdims=False, corrected=False): ...@@ -3239,34 +3240,42 @@ def var(input, axis=None, ddof=0, keepdims=False, corrected=False):
@constructor @constructor
def std(input, axis=None, ddof=0, keepdims=False): def std(input, axis=None, ddof=0, keepdims=False, corrected=False):
""" """
Computes the standard deviation along the given axis(es) of a tensor `input`. Computes the standard deviation along the given axis(es) of a tensor `input`.
Parameters Parameters
---------- ----------
axis : None or int or (list of int) (see `Sum`) axis: None or int or (list of int) (see `Sum`)
Compute the standard deviation along this axis of the tensor. Compute the variance along this axis of the tensor.
None means all axes (like numpy). None means all axes (like numpy).
ddof: Degrees of freedom; 0 would compute the ML estimate, 1 would compute
the unbiased estimate.
keepdims : bool keepdims : bool
If this is set to True, the axes which are reduced are left in the If this is set to True, the axes which are reduced are
result as dimensions with size one. With this option, the result will left in the result as dimensions with size one. With this option,
broadcast correctly against the original tensor. the result will broadcast correctly against the original tensor.
corrected : bool
If this is set to True, the 'corrected_two_pass' algorithm is
used to compute the variance.
Refer : http://www.cs.yale.edu/publications/techreports/tr222.pdf
Notes Notes
----- -----
It calls `var()` and `var()` uses the two-pass algorithm for more stable It calls 'var()' and 'var()' uses the two-pass algorithm (reference below).
results.
https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Two-pass_algorithm https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Two-pass_algorithm
There exist other implementations that are even more stable, but probably Function 'var()' also supports 'corrected_two_pass' algorithm (using the
slower. 'corrected' flag) which is numerically more stable. There exist other
implementations that offer better stability, but probably slower.
""" """
if isinstance(ddof, (bool)): if isinstance(ddof, (bool)):
raise ValueError('Parameter keepdims is now at index 3: (input, axis=None, ddof=0, keepdims=False)') raise ValueError('Parameter keepdims is now at index 3: (input, \
axis=None, ddof=0, keepdims=False, corrected=False)')
ret = sqrt(var(input=input, axis=axis, ddof=ddof, keepdims=keepdims)) ret = sqrt(var(input=input, axis=axis, ddof=ddof,
keepdims=keepdims, corrected=corrected))
ret.name = 'std' ret.name = 'std'
return ret return ret
......
...@@ -643,14 +643,15 @@ class _tensor_py_operators(object): ...@@ -643,14 +643,15 @@ class _tensor_py_operators(object):
dtype=dtype, keepdims=keepdims, dtype=dtype, keepdims=keepdims,
acc_dtype=acc_dtype) acc_dtype=acc_dtype)
def var(self, axis=None, keepdims=False, corrected=False): def var(self, axis=None, ddof=0, keepdims=False, corrected=False):
"""See `theano.tensor.var`.""" """See `theano.tensor.var`."""
return theano.tensor.basic.var(self, axis, keepdims=keepdims, return theano.tensor.basic.var(self, axis=axis, ddof=ddof,
corrected=corrected) keepdims=keepdims, corrected=corrected)
def std(self, axis=None, keepdims=False): def std(self, axis=None, ddof=0, keepdims=False, corrected=False):
"""See `theano.tensor.std`.""" """See `theano.tensor.std`."""
return theano.tensor.basic.std(self, axis, keepdims=keepdims) return theano.tensor.basic.std(self, axis=axis, ddof=ddof,
keepdims=keepdims, corrected=corrected)
def min(self, axis=None, keepdims=False): def min(self, axis=None, keepdims=False):
"""See `theano.tensor.min`.""" """See `theano.tensor.min`."""
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论