Commit 9691e746 authored by David Warde-Farley

Merge pull request #361 from delallea/sum_dtype

Fixed gh-356: dtype of tensor.sum()
......@@ -1418,8 +1418,8 @@ class _tensor_py_operators:
def __rdot__(right, left):
    """Reflected matrix product: delegate ``left @ right`` to ``dot``."""
    result = dot(left, right)
    return result
def sum(self, *args, **kw):
    """Sum this tensor; thin wrapper around the module-level ``sum``.

    Accepts the same arguments as ``tensor.sum`` (notably ``axis`` and
    ``dtype``); see ``tensor.elemwise.Sum`` for full documentation.
    """
    # NOTE: ``sum`` below resolves to the module-level tensor ``sum``
    # function at call time, not the builtin.
    # The stale pre-refactor definition (``def sum(self, axis=None)``)
    # that shadowed this one has been removed.
    return sum(self, *args, **kw)
def norm(self, L, axis=None):
if L==0:
......@@ -2579,13 +2579,21 @@ def tensor_copy(a):
"""Create a duplicate of `a` (with duplicated storage)"""
pprint.assign(tensor_copy, printing.IgnorePrinter())
@constructor
def sum(input, axis=None, dtype=None):
    """
    Sum a tensor along the given axis(es).

    :param input: The tensor to be summed.
    :param axis: Axis(es) to sum over; None means sum over all axes.
    :param dtype: Optional dtype of the internal accumulator and of the
        returned tensor.

    For full documentation see ``tensor.elemwise.Sum``.
    In particular please pay attention to the important warning when using
    a custom dtype.
    """
    # The obsolete duplicate definition (axis-only signature) that
    # previously preceded this one has been removed.
    return elemwise.Sum(axis=axis, dtype=dtype)(input)
pprint.assign(Sum(), printing.FunctionPrinter('sum'))
@constructor
def prod(input, axis = None):
"""WRITEME"""
......
......@@ -1004,7 +1004,7 @@ class CAReduce(Op):
subtract, divide or power).
"""
def __init__(self, scalar_op, axis = None):
def __init__(self, scalar_op, axis=None):
"""
Usage: CAReduce(scalar_op, axis = None)
......@@ -1071,9 +1071,10 @@ class CAReduce(Op):
op = self.__class__(self.scalar_op, axis)
else:
op = self
output = TensorType(dtype = self._output_dtype(input.type.dtype),
broadcastable = [x for i, x in enumerate(input.type.broadcastable) if i not in axis])()
broadcastable = [x for i, x in enumerate(input.type.broadcastable)
if i not in axis]
output = TensorType(dtype=self._output_dtype(input.type.dtype),
broadcastable=broadcastable)()
return Apply(op, [input], [output])
def __getstate__(self):
......@@ -1315,26 +1316,62 @@ class Any(CAReduce):
class Sum(CAReduce):
"""
Sums all the values of a tensor along the specified axis(es).
Equivalent to CAReduce(scalar.add, axis = axis), with the
Equivalent to CAReduce(scalar.add, axis=axis), with the
difference that this defines the gradient of sum wrt its tensor
input.
"""
def __init__(self, axis=None, dtype=None):
    """
    Constructor.

    :param axis: Axis(es) along which the tensor should be summed
    (use None to sum over all axes, and a list or tuple to sum along more
    than one axis).

    :param dtype: The dtype of the internal accumulator and returned
    tensor. If None, then we use the default dtype which is the same as the
    input tensor's dtype except when:
    - the input dtype is a signed integer of precision < 64 bit, in
      which case we use int64
    - the input dtype is an unsigned integer of precision < 64 bit, in
      which case we use uint64
    This behavior is similar in spirit to that of numpy (except numpy
    uses the default machine integer while we always use 64 bit integers to
    avoid platform-dependent behavior).

    IMPORTANT: If you use a custom dtype (!= None), it is strongly advised
    to set `config.on_opt_error` to 'raise' and to run your code in
    DebugMode at least once. This is because some optimizations may not
    currently be able to properly deal with such custom dtypes. Also please
    note that using a custom dtype may prevent some optimizations from
    being applied.
    """
    # The stale one-argument ``__init__`` signature left over from the
    # diff (which made this a syntax error) has been removed.
    CAReduce.__init__(self, scalar.add, axis)
    self.dtype = dtype
def __eq__(self, other):
    """Two Sum ops are equal when the CAReduce parts and dtypes both match."""
    base_equal = CAReduce.__eq__(self, other)
    return base_equal and self.dtype == other.dtype
def __hash__(self):
    """Combine the base CAReduce hash with the dtype's hash."""
    dtype_hash = hash(self.dtype)
    return CAReduce.__hash__(self) ^ dtype_hash
def _output_dtype(self, idtype):
# we want to protect against overflow
return dict(
int8='int32',
int16='int32',
int32='int64',
uint8='uint32',
uint16='uint32',
uint32='uint64',
).get(idtype, idtype)
if self.dtype is None:
return dict(
int8='int64',
int16='int64',
int32='int64',
uint8='uint64',
uint16='uint64',
uint32='uint64',
).get(idtype, idtype)
else:
return self.dtype
def grad(self, inp, grads):
x, = inp
......@@ -1353,7 +1390,8 @@ class Sum(CAReduce):
else:
new_dims.append(i)
i += 1
return Elemwise(scalar.second)(x, DimShuffle(gz.type.broadcastable, new_dims)(gz)),
return Elemwise(scalar.second)(
x, DimShuffle(gz.type.broadcastable, new_dims)(gz)),
def R_op(self, inputs, eval_points):
# There is just one element in inputs and eval_points, the axis are
......@@ -1368,6 +1406,7 @@ class Sum(CAReduce):
else:
return "Sum{%s}" % ", ".join(map(str, self.axis))
class Prod(CAReduce):
"""
Multiplies all the values of a tensor along the specified axis(es).
......
......@@ -2307,6 +2307,8 @@ if 0:
# that if we can prove the output to this sum has a
# zero-size dimension, then it can be replaced by an
# appropriately typed and broadcasted zero.
# TODO: Remember to take into account the new sum dtype argument if this
# optimization is enabled.
@register_canonicalize
@gof.local_optimizer([])
def local_sum_over_empty(node):
......@@ -2858,7 +2860,8 @@ def local_sum_mul_by_scalar(node):
"""
# TODO: if the the thing inside the Sum is a division,
# we should get at the numerator....
if isinstance(node.op, T.Sum):
# TODO Implement for sum.dtype != None.
if isinstance(node.op, T.Sum) and node.op.dtype is None:
thing_summed, = node.inputs
if thing_summed.owner and thing_summed.owner.op == T.mul:
terms = thing_summed.owner.inputs
......@@ -2912,7 +2915,9 @@ def local_sum_div_dimshuffle(node):
# dimshuffle is in the numerator, since elemwise inversion of the
# denominator would still be needed before the summation.
if isinstance(node.op, T.Sum):
# TODO Implement for sum.dtype != None.
if isinstance(node.op, T.Sum) and node.op.dtype is None:
axis = node.op.axis
if axis is None:
axis = range(node.inputs[0].ndim)
......@@ -3001,17 +3006,25 @@ def local_sum_all_to_none(node):
if node.op.axis is None:
return
if set(node.op.axis) == set(range(node.inputs[0].type.ndim)):
return [T.Sum(axis=None)(node.inputs[0])]
return [T.Sum(axis=None, dtype=node.op.dtype)(node.inputs[0])]
@register_canonicalize
@gof.local_optimizer([])
def local_sum_sum(node):
"""Sum(Sum()) -> Sum"""
if isinstance(node.op, T.Sum):
"""
Sum(Sum()) -> Sum
Note that currently we only replace sums with default dtypes, to avoid
potential dtype conflict issues.
"""
if isinstance(node.op, T.Sum) and node.op.dtype is None:
summed, = node.inputs
if len(summed.clients) == 1:
if summed.owner and isinstance(summed.owner.op, T.Sum):
if (summed.owner and
isinstance(summed.owner.op, T.Sum)
and summed.owner.op.dtype is None):
if summed.owner.op.axis is None:
# special case of local_cut_useless_reduce
return [T.Sum(None)(summed.owner.inputs[0])]
......@@ -3080,7 +3093,8 @@ def local_cut_useless_reduce(node):
@gof.local_optimizer([])
def local_sum_alloc(node):
""" sum(alloc(constant,shapes...)) => constant*prod(shapes)"""
if isinstance(node.op, T.Sum):
# TODO Implement for sum.dtype != None
if isinstance(node.op, T.Sum) and node.op.dtype is None:
summed, = node.inputs
if summed.owner and isinstance(summed.owner.op, T.Alloc):
input = summed.owner.inputs[0]
......
import cPickle, time, unittest
from itertools import imap
from numpy.testing import dec
......@@ -498,6 +499,40 @@ class test_IsInf_IsNan(unittest.TestCase):
return self.run_isfunc('isnan')
def test_sum_default_dtype():
    """
    Check the accumulator dtype that sum() picks by default.
    """
    # The axis should have no influence on the dtype, so we rotate
    # through several axis configurations anyway.
    axes = [None, 0, 1, [0], [1], [0, 1]]
    # Integer inputs narrower than 64 bits are widened to (u)int64;
    # every other dtype is preserved.
    expected_widening = dict(
        int8='int64',
        int16='int64',
        int32='int64',
        uint8='uint64',
        uint16='uint64',
        uint32='uint64',
    )
    for idx, dtype in enumerate(imap(str, theano.scalar.all_types)):
        axis = axes[idx % len(axes)]
        summed = tensor.matrix(dtype=dtype).sum(axis=axis)
        assert summed.dtype == expected_widening.get(dtype, dtype)
def test_sum_custom_dtype():
    """
    Verify that sum() honors an explicitly requested output dtype.
    """
    # Axis should be irrelevant to the dtype; vary it across iterations.
    axes = [None, 0, 1, [0], [1], [0, 1]]
    counter = 0
    for input_dtype in imap(str, theano.scalar.all_types):
        matrix_var = tensor.matrix(dtype=input_dtype)
        for output_dtype in imap(str, theano.scalar.all_types):
            axis = axes[counter % len(axes)]
            summed = matrix_var.sum(dtype=output_dtype, axis=axis)
            assert summed.dtype == output_dtype
            counter += 1
if __name__ == '__main__':
#unittest.main()
suite = unittest.TestSuite([test_Prod('test_mul_without_zeros_zeros')])
......
......@@ -3134,6 +3134,22 @@ class T_local_sum(unittest.TestCase):
finally:
config.warn.sum_sum_bug = backup
def test_local_sum_sum_int8(self):
    """
    Regression test for gh-356: local_sum_sum must handle two chained
    sums over an int8 array.
    """
    x = tensor.tensor3(dtype='int8')
    y = x.sum(axis=0).sum(axis=1)
    saved_setting = config.on_opt_error
    config.on_opt_error = 'raise'
    try:
        # Before the fix, this compilation raised an error.
        f = theano.function([x], y)
    finally:
        config.on_opt_error = saved_setting
class T_local_sum_dimshuffle(unittest.TestCase):
def setUp(self):
......
Markdown format
0%
You are adding 0 people to this discussion. Please proceed with caution.
Please finish editing this comment first!
Register or sign in to comment