Commit a2b8a144 authored by Olivier Delalleau

Fixed gh-356: dtype of tensor.sum()

The default dtype for signed integer input types is now int64 (and uint64 for unsigned integers). The default dtype can be overridden by specifying a dtype manually. However this does not appear to be very safe, so a warning was added to the docstrings when doing so. This commit also contains a few minor PEP8 fixes.
Parent 34caad67
...@@ -1389,8 +1389,8 @@ class _tensor_py_operators: ...@@ -1389,8 +1389,8 @@ class _tensor_py_operators:
def __rdot__(right, left): def __rdot__(right, left):
return dot(left, right) return dot(left, right)
def sum(self, axis=None): def sum(self, *args, **kw):
return elemwise.Sum(axis)(self) return sum(self, *args, **kw)
def norm(self, L, axis=None): def norm(self, L, axis=None):
if L==0: if L==0:
...@@ -2550,13 +2550,21 @@ def tensor_copy(a): ...@@ -2550,13 +2550,21 @@ def tensor_copy(a):
"""Create a duplicate of `a` (with duplicated storage)""" """Create a duplicate of `a` (with duplicated storage)"""
pprint.assign(tensor_copy, printing.IgnorePrinter()) pprint.assign(tensor_copy, printing.IgnorePrinter())
@constructor @constructor
def sum(input, axis = None): def sum(input, axis=None, dtype=None):
"""WRITEME""" """
return elemwise.Sum(axis)(input) Sum a tensor along the given axis(es).
For full documentation see ``tensor.elemwise.Sum``.
In particular please pay attention to the important warning when using
a custom dtype.
"""
return elemwise.Sum(axis=axis, dtype=dtype)(input)
pprint.assign(Sum(), printing.FunctionPrinter('sum')) pprint.assign(Sum(), printing.FunctionPrinter('sum'))
@constructor @constructor
def prod(input, axis = None): def prod(input, axis = None):
"""WRITEME""" """WRITEME"""
......
...@@ -1003,7 +1003,7 @@ class CAReduce(Op): ...@@ -1003,7 +1003,7 @@ class CAReduce(Op):
subtract, divide or power). subtract, divide or power).
""" """
def __init__(self, scalar_op, axis = None): def __init__(self, scalar_op, axis=None):
""" """
Usage: CAReduce(scalar_op, axis = None) Usage: CAReduce(scalar_op, axis = None)
...@@ -1074,9 +1074,10 @@ class CAReduce(Op): ...@@ -1074,9 +1074,10 @@ class CAReduce(Op):
op = self.__class__(self.scalar_op, axis) op = self.__class__(self.scalar_op, axis)
else: else:
op = self op = self
output = TensorType(dtype = self._output_dtype(input.type.dtype), broadcastable = [x for i, x in enumerate(input.type.broadcastable)
broadcastable = [x for i, x in enumerate(input.type.broadcastable) if i not in axis])() if i not in axis]
output = TensorType(dtype=self._output_dtype(input.type.dtype),
broadcastable=broadcastable)()
return Apply(op, [input], [output]) return Apply(op, [input], [output])
def __getstate__(self): def __getstate__(self):
...@@ -1307,26 +1308,62 @@ class Any(CAReduce): ...@@ -1307,26 +1308,62 @@ class Any(CAReduce):
class Sum(CAReduce): class Sum(CAReduce):
""" """
Sums all the values of a tensor along the specified axis(es). Sums all the values of a tensor along the specified axis(es).
Equivalent to CAReduce(scalar.add, axis = axis), with the Equivalent to CAReduce(scalar.add, axis=axis), with the
difference that this defines the gradient of sum wrt its tensor difference that this defines the gradient of sum wrt its tensor
input. input.
""" """
def __init__(self, axis = None):
def __init__(self, axis=None, dtype=None):
"""
Constructor.
:param axis: Axis(es) along which the tensor should be summed
(use None to sum over all axes, and a list or tuple to sum along more
than one axis).
:param dtype: The dtype of the internal accumulator and returned
tensor. If None, then we use the default dtype which is the same as the
input tensor's dtype except when:
- the input dtype is a signed integer of precision < 64 bit, in
which case we use int64
- the input dtype is an unsigned integer of precision < 64 bit, in
which case we use uint64
This behavior is similar in spirit to that of numpy (except numpy
uses the default machine integer while we always use 64 bit integers to
avoid platform-dependent behavior).
IMPORTANT: If you use a custom dtype (!= None), it is strongly advised
to set `config.on_opt_error` to 'raise' and to run your code in
DebugMode at least once. This is because some optimizations may not
currently be able to properly deal with such custom dtypes. Also please
note that using a custom dtype may prevent some optimizations from
being applied.
"""
CAReduce.__init__(self, scalar.add, axis) CAReduce.__init__(self, scalar.add, axis)
self.dtype = dtype
def __eq__(self, other):
    """Two Sum ops are equal iff the CAReduce parts match and the dtypes match."""
    if CAReduce.__eq__(self, other):
        return self.dtype == other.dtype
    return False
def __hash__(self):
    """Hash consistent with __eq__: fold the dtype into the CAReduce hash."""
    return hash(self.dtype) ^ CAReduce.__hash__(self)
def _output_dtype(self, idtype): def _output_dtype(self, idtype):
# we want to protect against overflow if self.dtype is None:
return dict( return dict(
int8='int32', int8='int64',
int16='int32', int16='int64',
int32='int64', int32='int64',
uint8='uint32', uint8='uint64',
uint16='uint32', uint16='uint64',
uint32='uint64', uint32='uint64',
).get(idtype, idtype) ).get(idtype, idtype)
else:
return self.dtype
def grad(self, inp, grads): def grad(self, inp, grads):
x, = inp x, = inp
...@@ -1345,7 +1382,8 @@ class Sum(CAReduce): ...@@ -1345,7 +1382,8 @@ class Sum(CAReduce):
else: else:
new_dims.append(i) new_dims.append(i)
i += 1 i += 1
return Elemwise(scalar.second)(x, DimShuffle(gz.type.broadcastable, new_dims)(gz)), return Elemwise(scalar.second)(
x, DimShuffle(gz.type.broadcastable, new_dims)(gz)),
def R_op(self, inputs, eval_points): def R_op(self, inputs, eval_points):
# There is just one element in inputs and eval_points, the axis are # There is just one element in inputs and eval_points, the axis are
...@@ -1360,6 +1398,7 @@ class Sum(CAReduce): ...@@ -1360,6 +1398,7 @@ class Sum(CAReduce):
else: else:
return "Sum{%s}" % ", ".join(map(str, self.axis)) return "Sum{%s}" % ", ".join(map(str, self.axis))
class Prod(CAReduce): class Prod(CAReduce):
""" """
Multiplies all the values of a tensor along the specified axis(es). Multiplies all the values of a tensor along the specified axis(es).
......
...@@ -2301,6 +2301,8 @@ if 0: ...@@ -2301,6 +2301,8 @@ if 0:
# that if we can prove the output to this sum has a # that if we can prove the output to this sum has a
# zero-size dimension, then it can be replaced by an # zero-size dimension, then it can be replaced by an
# appropriately typed and broadcasted zero. # appropriately typed and broadcasted zero.
# TODO: Remember to take into account the new sum dtype argument if this
# optimization is enabled.
@register_canonicalize @register_canonicalize
@gof.local_optimizer([]) @gof.local_optimizer([])
def local_sum_over_empty(node): def local_sum_over_empty(node):
...@@ -2852,7 +2854,8 @@ def local_sum_mul_by_scalar(node): ...@@ -2852,7 +2854,8 @@ def local_sum_mul_by_scalar(node):
""" """
# TODO: if the the thing inside the Sum is a division, # TODO: if the the thing inside the Sum is a division,
# we should get at the numerator.... # we should get at the numerator....
if isinstance(node.op, T.Sum): # TODO Implement for sum.dtype != None.
if isinstance(node.op, T.Sum) and node.op.dtype is None:
thing_summed, = node.inputs thing_summed, = node.inputs
if thing_summed.owner and thing_summed.owner.op == T.mul: if thing_summed.owner and thing_summed.owner.op == T.mul:
terms = thing_summed.owner.inputs terms = thing_summed.owner.inputs
...@@ -2906,7 +2909,9 @@ def local_sum_div_dimshuffle(node): ...@@ -2906,7 +2909,9 @@ def local_sum_div_dimshuffle(node):
# dimshuffle is in the numerator, since elemwise inversion of the # dimshuffle is in the numerator, since elemwise inversion of the
# denominator would still be needed before the summation. # denominator would still be needed before the summation.
if isinstance(node.op, T.Sum): # TODO Implement for sum.dtype != None.
if isinstance(node.op, T.Sum) and node.op.dtype is None:
axis = node.op.axis axis = node.op.axis
if axis is None: if axis is None:
axis = range(node.inputs[0].ndim) axis = range(node.inputs[0].ndim)
...@@ -2995,17 +3000,25 @@ def local_sum_all_to_none(node): ...@@ -2995,17 +3000,25 @@ def local_sum_all_to_none(node):
if node.op.axis is None: if node.op.axis is None:
return return
if set(node.op.axis) == set(range(node.inputs[0].type.ndim)): if set(node.op.axis) == set(range(node.inputs[0].type.ndim)):
return [T.Sum(axis=None)(node.inputs[0])] return [T.Sum(axis=None, dtype=node.op.dtype)(node.inputs[0])]
@register_canonicalize @register_canonicalize
@gof.local_optimizer([]) @gof.local_optimizer([])
def local_sum_sum(node): def local_sum_sum(node):
"""Sum(Sum()) -> Sum""" """
if isinstance(node.op, T.Sum): Sum(Sum()) -> Sum
Note that currently we only replace sums with default dtypes, to avoid
potential dtype conflict issues.
"""
if isinstance(node.op, T.Sum) and node.op.dtype is None:
summed, = node.inputs summed, = node.inputs
if len(summed.clients) == 1: if len(summed.clients) == 1:
if summed.owner and isinstance(summed.owner.op, T.Sum): if (summed.owner and
isinstance(summed.owner.op, T.Sum)
and summed.owner.op.dtype is None):
if summed.owner.op.axis is None: if summed.owner.op.axis is None:
# special case of local_cut_useless_reduce # special case of local_cut_useless_reduce
return [T.Sum(None)(summed.owner.inputs[0])] return [T.Sum(None)(summed.owner.inputs[0])]
...@@ -3074,7 +3087,8 @@ def local_cut_useless_reduce(node): ...@@ -3074,7 +3087,8 @@ def local_cut_useless_reduce(node):
@gof.local_optimizer([]) @gof.local_optimizer([])
def local_sum_alloc(node): def local_sum_alloc(node):
""" sum(alloc(constant,shapes...)) => constant*prod(shapes)""" """ sum(alloc(constant,shapes...)) => constant*prod(shapes)"""
if isinstance(node.op, T.Sum): # TODO Implement for sum.dtype != None
if isinstance(node.op, T.Sum) and node.op.dtype is None:
summed, = node.inputs summed, = node.inputs
if summed.owner and isinstance(summed.owner.op, T.Alloc): if summed.owner and isinstance(summed.owner.op, T.Alloc):
input = summed.owner.inputs[0] input = summed.owner.inputs[0]
......
import cPickle, time, unittest import cPickle, time, unittest
from itertools import imap
from numpy.testing import dec from numpy.testing import dec
...@@ -498,6 +499,40 @@ class test_IsInf_IsNan(unittest.TestCase): ...@@ -498,6 +499,40 @@ class test_IsInf_IsNan(unittest.TestCase):
return self.run_isfunc('isnan') return self.run_isfunc('isnan')
def test_sum_default_dtype():
    """
    Verify the default output dtype of sum().

    Integer input dtypes below 64 bit must be upcast to int64 (signed)
    or uint64 (unsigned); every other dtype is preserved unchanged.
    """
    # Expected upcasts for small integer dtypes; any dtype not listed
    # here maps to itself.
    upcast = dict(
        int8='int64', int16='int64', int32='int64',
        uint8='uint64', uint16='uint64', uint32='uint64')
    # Cycle through several axis choices: the output dtype must not
    # depend on which axis (or axes) we sum over.
    axes = [None, 0, 1, [0], [1], [0, 1]]
    for idx, dtype in enumerate(imap(str, theano.scalar.all_types)):
        var = tensor.matrix(dtype=dtype).sum(axis=axes[idx % len(axes)])
        assert var.dtype == upcast.get(dtype, dtype)
def test_sum_custom_dtype():
    """
    Verify that an explicitly requested output dtype is honored by sum(),
    for every (input dtype, output dtype) combination.
    """
    # Cycle through several axis choices: the axis argument must not
    # affect the resulting dtype.
    axes = [None, 0, 1, [0], [1], [0, 1]]
    idx = 0
    for input_dtype in imap(str, theano.scalar.all_types):
        x = tensor.matrix(dtype=input_dtype)
        for output_dtype in imap(str, theano.scalar.all_types):
            summed = x.sum(dtype=output_dtype, axis=axes[idx % len(axes)])
            assert summed.dtype == output_dtype
            idx += 1
if __name__ == '__main__': if __name__ == '__main__':
#unittest.main() #unittest.main()
suite = unittest.TestSuite([test_Prod('test_mul_without_zeros_zeros')]) suite = unittest.TestSuite([test_Prod('test_mul_without_zeros_zeros')])
......
...@@ -3134,6 +3134,22 @@ class T_local_sum(unittest.TestCase): ...@@ -3134,6 +3134,22 @@ class T_local_sum(unittest.TestCase):
finally: finally:
config.warn.sum_sum_bug = backup config.warn.sum_sum_bug = backup
def test_local_sum_sum_int8(self):
    """
    Test that local_sum_sum works when combining two sums on an int8 array.

    This is a regression test for ticket gh-356.
    """
    x = tensor.tensor3(dtype='int8')
    y = x.sum(axis=0).sum(axis=1)
    backup = config.on_opt_error
    # Make any optimization error fatal so a regression is not silently
    # turned into a warning.
    config.on_opt_error = 'raise'
    try:
        # Prior to the gh-356 fix, this compilation would fail.
        # The compiled function itself is not needed, only the fact
        # that compilation succeeds (the original bound it to an
        # unused local `f`).
        theano.function([x], y)
    finally:
        config.on_opt_error = backup
class T_local_sum_dimshuffle(unittest.TestCase): class T_local_sum_dimshuffle(unittest.TestCase):
def setUp(self): def setUp(self):
......
Markdown is supported
0%
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment