提交 be9cc434 authored 作者: abergeron's avatar abergeron 提交者: GitHub

Merge pull request #5144 from nouiz/float16_var

make var() return float16 when input is float16.
...@@ -3249,11 +3249,12 @@ def var(input, axis=None, ddof=0, keepdims=False, corrected=False): ...@@ -3249,11 +3249,12 @@ def var(input, axis=None, ddof=0, keepdims=False, corrected=False):
centered_input = input - mean_input centered_input = input - mean_input
# return the mean sqr # return the mean sqr
two = constant(2, dtype=centered_input.dtype)
if ddof == 0: if ddof == 0:
v = mean((centered_input ** 2), axis, keepdims=keepdims) v = mean((centered_input ** two), axis, keepdims=keepdims)
else: else:
shp = shape(input) - ddof shp = shape(input) - ddof
v = sum((centered_input ** 2), axis=axis, keepdims=keepdims) v = sum((centered_input ** two), axis=axis, keepdims=keepdims)
for i in axis: for i in axis:
v = true_div(v, shp[i]) v = true_div(v, shp[i])
......
...@@ -6400,6 +6400,9 @@ def test_var(): ...@@ -6400,6 +6400,9 @@ def test_var():
v = v - error v = v - error
assert numpy.allclose(v, f(a_val)) assert numpy.allclose(v, f(a_val))
# Test that we don't upcast float16 computation
assert theano.tensor.vector(dtype='float16').var().dtype == 'float16'
class T_sum(unittest.TestCase): class T_sum(unittest.TestCase):
def test_sum_overflow(self): def test_sum_overflow(self):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论