提交 bebe79eb authored 作者: Kumar Krishna Agrawal's avatar Kumar Krishna Agrawal

Updated documentation, added test

上级 4390245e
...@@ -3142,11 +3142,13 @@ def var(input, axis=None, keepdims=False, corrected=False): ...@@ -3142,11 +3142,13 @@ def var(input, axis=None, keepdims=False, corrected=False):
corrected : bool corrected : bool
If this is set to True, the 'corrected_two_pass' algorithm is If this is set to True, the 'corrected_two_pass' algorithm is
used to compute the variance. used to compute the variance.
Refer : http://www.cs.yale.edu/publications/techreports/tr222.pdf
Notes Notes
----- -----
It uses the two-pass algorithm for more stable results. By default, uses the two-pass algorithm for more stable results.
https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Two-pass_algorithm https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Two-pass_algorithm
Also supports 'corrected_two_pass' algorithm, as mentioned above.
There exist other implementations that are even more stable, but probably There exist other implementations that are even more stable, but probably
slower. slower.
......
...@@ -6319,8 +6319,6 @@ def test_var(): ...@@ -6319,8 +6319,6 @@ def test_var():
f = function([a], var(a)) f = function([a], var(a))
a_val = numpy.arange(60).reshape(3, 4, 5) a_val = numpy.arange(60).reshape(3, 4, 5)
# print numpy.var(a_val)
# print f(a_val)
assert numpy.allclose(numpy.var(a_val), f(a_val)) assert numpy.allclose(numpy.var(a_val), f(a_val))
f = function([a], var(a, axis=0)) f = function([a], var(a, axis=0))
...@@ -6333,7 +6331,13 @@ def test_var(): ...@@ -6333,7 +6331,13 @@ def test_var():
assert numpy.allclose(numpy.var(a_val, axis=2), f(a_val)) assert numpy.allclose(numpy.var(a_val, axis=2), f(a_val))
f = function([a], var(a, corrected=True)) f = function([a], var(a, corrected=True))
assert numpy.allclose(numpy.var(a_val), f(a_val)) mean_a = numpy.mean(a_val)
centered_a = a_val - mean_a
v = numpy.mean(centered_a ** 2)
error = (numpy.mean(centered_a)) ** 2
v = v - error
assert numpy.allclose(v, f(a_val))
class T_sum(unittest.TestCase): class T_sum(unittest.TestCase):
def test_sum_overflow(self): def test_sum_overflow(self):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论