提交 b135df12 authored 作者: bergstra@ip05.m's avatar bergstra@ip05.m

merge

...@@ -686,12 +686,15 @@ class CAReduce(Op): ...@@ -686,12 +686,15 @@ class CAReduce(Op):
self.axis = axis self.axis = axis
self.ufunc = numpy.frompyfunc(scalar_op.impl, 2, 1) self.ufunc = numpy.frompyfunc(scalar_op.impl, 2, 1)
def _output_dtype(self, input_dtype):
return input_dtype
def make_node(self, input): def make_node(self, input):
input = as_tensor_variable(input) input = as_tensor_variable(input)
axis = self.axis axis = self.axis
if axis is None: if axis is None:
axis = range(len(input.type.broadcastable)) axis = range(len(input.type.broadcastable))
output = TensorType(dtype = input.type.dtype, output = TensorType(dtype = self._output_dtype(input.type.dtype),
broadcastable = [x for i, x in enumerate(input.type.broadcastable) if i not in axis])() broadcastable = [x for i, x in enumerate(input.type.broadcastable) if i not in axis])()
return Apply(self, [input], [output]) return Apply(self, [input], [output])
...@@ -817,6 +820,13 @@ class Sum(CAReduce): ...@@ -817,6 +820,13 @@ class Sum(CAReduce):
def __init__(self, axis = None): def __init__(self, axis = None):
CAReduce.__init__(self, scalar.add, axis) CAReduce.__init__(self, scalar.add, axis)
def _output_dtype(self, idtype):
if idtype.startswith('int'):
return 'int64' #we want to protect against overflow
else:
return idtype
def grad(self, (x, ), (gz, )): def grad(self, (x, ), (gz, )):
gz = as_tensor_variable(gz) gz = as_tensor_variable(gz)
axis = self.axis axis = self.axis
......
...@@ -1887,6 +1887,12 @@ def test_var(): ...@@ -1887,6 +1887,12 @@ def test_var():
f = function([a], var(a, axis=2)) f = function([a], var(a, axis=2))
assert numpy.allclose(numpy.var(a_val, axis=2), f(a_val)) assert numpy.allclose(numpy.var(a_val, axis=2), f(a_val))
def test_sum_overflow():
"""Ensure that overflow errors are a little bit harder to get"""
a = Tensor(dtype='int8', broadcastable=[False])()
f = function([a], sum(a))
assert f([1]*300) == 300
if __name__ == '__main__': if __name__ == '__main__':
if len(sys.argv) >= 2 and sys.argv[1] == 'OPT': if len(sys.argv) >= 2 and sys.argv[1] == 'OPT':
default_mode = compile.Mode(linker = 'c&py', default_mode = compile.Mode(linker = 'c&py',
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论