Commit 01ab8831 authored by nouiz

Merge pull request #402 from lamblin/tensor_careduce_dtype

dtype for different Tensor CAReduce operations
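For readers skimming the diff, a minimal usage sketch of the behaviour this merge enables (illustrative only, not taken from the patch; assumes a standard Theano setup):

```python
import numpy
import theano
import theano.tensor as T

x = T.matrix(dtype='int8')
s_default = x.sum()                  # int8 input accumulates in int64 by default
s_int32 = x.sum(dtype='int32')       # request an int32 accumulator instead
m_f32 = x.mean(dtype='float32')      # inner summation done in float32

f = theano.function([x], [s_default, s_int32, m_f32])
print(f(numpy.ones((3, 4), dtype='int8')))
```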
......@@ -1418,8 +1418,13 @@ class _tensor_py_operators:
def __rdot__(right, left):
return dot(left, right)
def sum(self, *args, **kw):
return sum(self, *args, **kw)
def sum(self, axis=None, dtype=None):
"""See `theano.tensor.sum`"""
return sum(self, axis=axis, dtype=dtype)
def prod(self, axis=None, dtype=None):
"""See `theano.tensor.prod`"""
return prod(self, axis=axis, dtype=dtype)
def norm(self, L, axis=None):
if L==0:
......@@ -1429,9 +1434,9 @@ class _tensor_py_operators:
#optimizations will/should catch cases like L=1, L=2
return pow(pow(abs_(self), L).sum(axis=axis), 1.0/L)
def mean(self, axis=None):
def mean(self, axis=None, dtype=None):
"""See `theano.tensor.mean`"""
return mean(self, axis)
return mean(self, axis=axis, dtype=dtype)
def var(self, axis=None):
"""See `theano.tensor.var`"""
......@@ -2630,9 +2635,13 @@ pprint.assign(Sum(), printing.FunctionPrinter('sum'))
@constructor
def prod(input, axis = None):
"""WRITEME"""
return elemwise.Prod(axis)(input)
def prod(input, axis=None, dtype=None):
"""
Returns the product of a tensor's elements along the given axis(es).
For full documentation see ``tensor.elemwise.Prod``.
"""
return elemwise.Prod(axis, dtype=dtype)(input)
class Mean(elemwise.CAReduce):
def __init__(self, axis = None):
......@@ -2667,35 +2676,54 @@ class Mean(elemwise.CAReduce):
# return grad(mean(x, self.axis, op=False),[x])
@constructor
def mean(input, axis = None, op = False):
def mean(input, axis=None, dtype=None, op=False):
"""Compute the mean value along the given axis of a tensor `input`
:param axis: compute the mean along this axis of the tensor.
None means all axes (like numpy).
:type axis: None or int or (list of int) (see `Sum`)
:note: for gpu, if you manually cast the input to float32 before calling
mean, everything will be done on the gpu.
:param dtype: dtype to use for the inner summation. This will not
necessarily be the dtype of the output (in particular
if it is a discrete (int/uint) dtype, the output will
be in a float type)
:type dtype: string
:note: for gpu, if you specify dtype=float32, everything will be done
on the gpu.
"""
if op:
return Mean(axis)(input)
if str(input.dtype) in discrete_dtypes:
# we need to cast eventually anyway, and this helps
# to prevents overflow
input = cast(input, 'float64')
s = sum(input, axis)
if dtype is not None:
# The summation will be done with the specified dtype.
# sum() will complain if it is not suitable.
sum_dtype = dtype
elif input.dtype in discrete_dtypes:
# we need to cast eventually anyway, and this helps
# to prevent overflow. Numpy uses 'float64'.
# TODO: use floatX? let casting_policy decide?
sum_dtype = 'float64'
else:
# Let sum() infer the appropriate dtype
sum_dtype = None
s = sum(input, axis=axis, dtype=sum_dtype)
shp = shape(input)
if input.dtype == 'float32':
# Cast shp into a float type
if s.dtype in ('float32', 'complex64'):
shp = cast(shp, 'float32')
else:
shp = cast(shp, 'float64')
if axis is None:
axis = range(input.ndim)
elif isinstance(axis, int):
axis = [axis]
for i in axis:
s = s / shp[i]
if str(input.dtype).startswith('float'):
assert input.dtype == s.dtype
return s
@constructor
......
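A brief illustrative sketch (not part of the patch) of what the mean() change above means in practice, assuming the module-level tensor.mean:

```python
import theano.tensor as T

x = T.matrix(dtype='int8')

# Default: the inner summation for a discrete (int8) input is done in
# float64, mirroring numpy, and the mean is returned as float64.
m_default = T.mean(x)

# With dtype='float32' the accumulation itself is float32; per the note
# in the docstring, this is what lets the computation stay on the GPU.
m_f32 = T.mean(x, dtype='float32')
```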
......@@ -29,6 +29,10 @@ def TensorVariable(*inputs, **kwargs):
def TensorConstant(*inputs, **kwargs):
raise Exception("Circular dependencies prevent using this here. import tensor before elemwise")
# Define common subsets of dtypes (as strings).
discrete_dtypes = map(str, scalar.discrete_types)
continuous_dtypes = map(str, scalar.continuous_types)
##################
### DimShuffle ###
......@@ -1315,23 +1319,27 @@ class Any(CAReduce):
return "Any{%s}" % ", ".join(map(str, self.axis))
class Sum(CAReduce):
class CAReduceDtype(CAReduce):
"""
Sums all the values of a tensor along the specified axis(es).
Reduces a scalar operation along the specified axis(es).
Equivalent to CAReduce(scalar.add, axis=axis), with the
difference that this defines the gradient of sum wrt its tensor
input.
This subclass of CAReduce accepts an additional "dtype" parameter,
that specifies which dtype will be used for the accumulation.
If no dtype is provided, one will be inferred so as not to lose
too much precision.
"""
def __init__(self, axis=None, dtype=None):
def __init__(self, scalar_op, axis=None, dtype=None):
"""
Constructor.
Usage: CAReduceDtype(scalar_op, axis=None, dtype=None)
:param axis: Axis(es) along which the tensor should be summed
(use None to sum over all axes, and a list or tuple to sum along more
than one axis).
:param scalar_op: a binary scalar op with only one output.
It must be commutative and associative.
:param axis: - the dimension along which we want to reduce
- list of dimensions that we want to reduce
- if None, all dimensions are reduced
:param dtype: The dtype of the internal accumulator and returned
tensor. If None, then we use the default dtype which is the same as the
......@@ -1344,14 +1352,8 @@ class Sum(CAReduce):
uses the default machine integer while we always use 64 bit integers to
avoid platform-dependent behavior).
IMPORTANT: If you use a custom dtype (!= None), it is strongly advised
to set `config.on_opt_error` to 'raise' and to run your code in
DebugMode at least once. This is because some optimizations may not
currently be able to properly deal with such custom dtypes. Also please
note that using a custom dtype may prevent some optimizations from
being applied.
"""
CAReduce.__init__(self, scalar.add, axis)
CAReduce.__init__(self, scalar_op, axis=axis)
self.dtype = dtype
def __eq__(self, other):
......@@ -1361,7 +1363,9 @@ class Sum(CAReduce):
return CAReduce.__hash__(self) ^ hash(self.dtype)
def _output_dtype(self, idtype):
if self.dtype is None:
dtype = self.dtype
if dtype is None:
# If input has a discrete dtype, upcast it to 64 bits
return dict(
int8='int64',
int16='int64',
......@@ -1370,8 +1374,65 @@ class Sum(CAReduce):
uint16='uint64',
uint32='uint64',
).get(idtype, idtype)
elif dtype in continuous_dtypes and idtype in discrete_dtypes:
# Specifying a continuous output for discrete input is OK
return dtype
else:
# The conversion has to be considered an upcast.
upcasted_dtype = scalar.upcast(idtype, dtype)
if dtype != upcasted_dtype:
raise TypeError(
'Cannot build %s node with input dtype %s '
'and output dtype %s, as precision would be lost. '
'To correct this error, you can either:\n'
' - not specify a dtype, or\n'
' - use a dtype at least as precise as %s.\n'
'If you are expecting the precision loss, you can '
'use tensor.cast(..., dtype="%s"), either on your '
'input, or on the output of the reduce operation.'
% (self, idtype, dtype, upcasted_dtype, dtype))
return dtype
def make_node(self, input):
# We need to redefine make_node so that, if self.dtype is None,
# we can infer what dtype should be, and create a node from an Op
# of the appropriate dtype.
dtype = self._output_dtype(input.dtype)
assert dtype is not None
if dtype == self.dtype:
# Don't build another instance
op = self
else:
return self.dtype
op = self.__class__(axis=self.axis, dtype=dtype)
return CAReduce.make_node(op, input)
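To make the precision rule enforced in _output_dtype concrete, here is a hedged sketch (illustrative, not from the patch) of what is accepted and what raises:

```python
import theano.tensor as T

x = T.vector(dtype='int32')

s_up = x.sum(dtype='int64')     # upcast: accepted
s_flt = x.sum(dtype='float64')  # continuous dtype for discrete input: accepted

try:
    x.sum(dtype='int16')        # would lose precision: TypeError at node build
except TypeError:
    # If the loss is intended, cast explicitly, as the error message suggests.
    s_low = T.cast(x, 'int16').sum(dtype='int16')
```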
class Sum(CAReduceDtype):
"""
Sums all the values of a tensor along the specified axis(es).
Equivalent to CAReduceDtype(scalar.add, axis=axis, dtype=dtype),
with the difference that this defines the gradient of sum wrt its
tensor input.
"""
def __init__(self, axis=None, dtype=None):
"""
Constructor.
:param axis: Axis(es) along which the tensor should be summed
(use None to sum over all axes, and a list or tuple to sum along more
than one axis).
:param dtype: The dtype of the internal accumulator and returned
tensor. If None, then we use the default dtype which is the same as the
input tensor's dtype except when:
- the input dtype is a signed integer of precision < 64 bit, in
which case we use int64
- the input dtype is an unsigned integer of precision < 64 bit, in
which case we use uint64
"""
CAReduceDtype.__init__(self, scalar.add, axis=axis, dtype=dtype)
def grad(self, inp, grads):
x, = inp
......@@ -1407,7 +1468,7 @@ class Sum(CAReduce):
return "Sum{%s}" % ", ".join(map(str, self.axis))
class Prod(CAReduce):
class Prod(CAReduceDtype):
"""
Multiplies all the values of a tensor along the specified axis(es).
......@@ -1415,9 +1476,8 @@ class Prod(CAReduce):
difference that this defines the gradient of prod wrt its tensor
input.
"""
def __init__(self, axis=None, no_zeros_in_input=False):
CAReduce.__init__(self, scalar.mul, axis)
def __init__(self, axis=None, dtype=None, no_zeros_in_input=False):
CAReduceDtype.__init__(self, scalar.mul, axis=axis, dtype=dtype)
self.no_zeros_in_input = no_zeros_in_input
def __setstate__(self, dct):
......@@ -1427,24 +1487,12 @@ class Prod(CAReduce):
self.no_zeros_in_input = False
def __eq__(self, other):
return type(self) == type(other) and self.scalar_op == other.scalar_op and self.axis == other.axis and self.no_zeros_in_input == other.no_zeros_in_input
return (CAReduceDtype.__eq__(self, other)
and self.no_zeros_in_input == other.no_zeros_in_input)
def __hash__(self):
if self.axis is None:
return hash(self.scalar_op) ^ hash(self.no_zeros_in_input)
else:
return hash(self.scalar_op) ^ hash(tuple(self.axis)) ^ hash(self.no_zeros_in_input)
def _output_dtype(self, idtype):
# we want to protect against overflow
return dict(
int8='int64',
int16='int64',
int32='int64',
uint8='uint64',
uint16='uint64',
uint32='uint64',
).get(idtype, idtype)
return (CAReduceDtype.__hash__(self) ^
hash(self.no_zeros_in_input))
def grad(self, inp, grads):
'''
......@@ -1596,20 +1644,9 @@ class MulWithoutZeros(scalar.BinaryScalarOp):
return (1,)
mul_without_zeros = MulWithoutZeros(scalar.upcast_out, name = 'mul_without_zeros')
class ProdWithoutZeros(CAReduce):
def __init__(self, axis = None):
CAReduce.__init__(self, mul_without_zeros, axis)
def _output_dtype(self, idtype):
# we want to protect against overflow
return dict(
int8='int32',
int16='int32',
int32='int64',
uint8='uint32',
uint16='uint32',
uint32='uint64',
).get(idtype, idtype)
class ProdWithoutZeros(CAReduceDtype):
def __init__(self, axis=None, dtype=None):
CAReduceDtype.__init__(self, mul_without_zeros, axis=axis, dtype=dtype)
def __str__(self):
if self.axis is None:
......
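As above, a small illustrative sketch (an assumption-laden example, not from the patch) showing Prod going through the same dtype machinery:

```python
import theano.tensor as T

x = T.vector(dtype='uint8')

p_default = T.prod(x)               # uint8 input accumulates in uint64 by default
p_f64 = T.prod(x, dtype='float64')  # explicit continuous accumulator
```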
......@@ -2875,8 +2875,7 @@ def local_sum_mul_by_scalar(node):
"""
# TODO: if the thing inside the Sum is a division,
# we should get at the numerator....
# TODO Implement for sum.dtype != None.
if isinstance(node.op, T.Sum) and node.op.dtype is None:
if isinstance(node.op, T.Sum):
thing_summed, = node.inputs
if thing_summed.owner and thing_summed.owner.op == T.mul:
terms = thing_summed.owner.inputs
......@@ -2930,9 +2929,7 @@ def local_sum_div_dimshuffle(node):
# dimshuffle is in the numerator, since elemwise inversion of the
# denominator would still be needed before the summation.
# TODO Implement for sum.dtype != None.
if isinstance(node.op, T.Sum) and node.op.dtype is None:
if isinstance(node.op, T.Sum):
axis = node.op.axis
if axis is None:
axis = range(node.inputs[0].ndim)
......@@ -3029,24 +3026,21 @@ def local_sum_all_to_none(node):
def local_sum_sum(node):
"""
Sum(Sum()) -> Sum
Note that currently we only replace sums with default dtypes, to avoid
potential dtype conflict issues.
"""
if isinstance(node.op, T.Sum) and node.op.dtype is None:
if isinstance(node.op, T.Sum):
summed, = node.inputs
out_dtype = node.op.dtype
if len(summed.clients) == 1:
if (summed.owner and
isinstance(summed.owner.op, T.Sum)
and summed.owner.op.dtype is None):
isinstance(summed.owner.op, T.Sum)):
if summed.owner.op.axis is None:
# special case of local_cut_useless_reduce
return [T.Sum(None)(summed.owner.inputs[0])]
return [T.Sum(None, dtype=out_dtype)(summed.owner.inputs[0])]
if node.op.axis is None:
# we're summing up everything anyway so lets
# do it all at once
return [T.Sum(None)(summed.owner.inputs[0])]
return [T.Sum(None, dtype=out_dtype)(summed.owner.inputs[0])]
newaxis = list(tuple(summed.owner.op.axis))
# figure out which dimensions of the original input
......@@ -3089,7 +3083,7 @@ def local_sum_sum(node):
"been fixed) set the theano flag "
"`warn.sum_sum_bug` to False.")
combined_sum = T.Sum(newaxis)
combined_sum = T.Sum(newaxis, dtype=out_dtype)
return [combined_sum(summed.owner.inputs[0])]
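For context, an illustrative sketch (assuming theano.printing.debugprint; not part of the patch) of the nested-sum case this optimization now covers even when dtypes are specified:

```python
import theano
import theano.tensor as T

x = T.tensor3(dtype='int8')
# Nested sums with explicit dtypes; after optimization these should be
# fused into a single Sum node carrying the outer dtype (int64 here).
y = x.sum(axis=0, dtype='int32').sum(axis=1, dtype='int64')
f = theano.function([x], y)
theano.printing.debugprint(f)
```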
......@@ -3108,8 +3102,7 @@ def local_cut_useless_reduce(node):
@gof.local_optimizer([])
def local_sum_alloc(node):
""" sum(alloc(constant,shapes...)) => constant*prod(shapes)"""
# TODO Implement for sum.dtype != None
if isinstance(node.op, T.Sum) and node.op.dtype is None:
if isinstance(node.op, T.Sum):
summed, = node.inputs
if summed.owner and isinstance(summed.owner.op, T.Alloc):
input = summed.owner.inputs[0]
......
......@@ -3151,6 +3151,20 @@ class T_local_sum(unittest.TestCase):
finally:
config.on_opt_error = backup
def test_local_sum_sum_dtype(self):
"""
Test that local_sum_sum works when specifying dtypes manually.
"""
x = tensor.tensor3(dtype='int8')
y = x.sum(axis=0, dtype='int32').sum(axis=1, dtype='int64')
backup = config.on_opt_error
config.on_opt_error = 'raise'
try:
# This compilation would fail prior to the fix.
f = theano.function([x], y)
finally:
config.on_opt_error = backup
class T_local_sum_dimshuffle(unittest.TestCase):
def setUp(self):
......