提交 01ab8831 authored 作者: nouiz's avatar nouiz

Merge pull request #402 from lamblin/tensor_careduce_dtype

dtype for different Tensor CAReduce operations
...@@ -1418,8 +1418,13 @@ class _tensor_py_operators: ...@@ -1418,8 +1418,13 @@ class _tensor_py_operators:
def __rdot__(right, left): def __rdot__(right, left):
return dot(left, right) return dot(left, right)
def sum(self, *args, **kw): def sum(self, axis=None, dtype=None):
return sum(self, *args, **kw) """See `theano.tensor.sum`"""
return sum(self, axis=axis, dtype=dtype)
def prod(self, axis=None, dtype=None):
"""See `theano.tensor.prod`"""
return prod(self, axis=axis, dtype=dtype)
def norm(self, L, axis=None): def norm(self, L, axis=None):
if L==0: if L==0:
...@@ -1429,9 +1434,9 @@ class _tensor_py_operators: ...@@ -1429,9 +1434,9 @@ class _tensor_py_operators:
#optimizations will/should catch cases like L=1, L=2 #optimizations will/should catch cases like L=1, L=2
return pow(pow(abs_(self), L).sum(axis=axis), 1.0/L) return pow(pow(abs_(self), L).sum(axis=axis), 1.0/L)
def mean(self, axis=None): def mean(self, axis=None, dtype=None):
"""See `theano.tensor.mean`""" """See `theano.tensor.mean`"""
return mean(self, axis) return mean(self, axis=axis, dtype=dtype)
def var(self, axis=None): def var(self, axis=None):
"""See `theano.tensor.var`""" """See `theano.tensor.var`"""
...@@ -2630,9 +2635,13 @@ pprint.assign(Sum(), printing.FunctionPrinter('sum')) ...@@ -2630,9 +2635,13 @@ pprint.assign(Sum(), printing.FunctionPrinter('sum'))
@constructor @constructor
def prod(input, axis = None): def prod(input, axis=None, dtype=None):
"""WRITEME""" """
return elemwise.Prod(axis)(input) Returns the Product of a tensor's elements along the given axis(es).
For full documentation see ``tensor.elemwise.Prod``.
"""
return elemwise.Prod(axis, dtype=dtype)(input)
class Mean(elemwise.CAReduce): class Mean(elemwise.CAReduce):
def __init__(self, axis = None): def __init__(self, axis = None):
...@@ -2667,35 +2676,54 @@ class Mean(elemwise.CAReduce): ...@@ -2667,35 +2676,54 @@ class Mean(elemwise.CAReduce):
# return grad(mean(x, self.axis, op=False),[x]) # return grad(mean(x, self.axis, op=False),[x])
@constructor @constructor
def mean(input, axis = None, op = False): def mean(input, axis=None, dtype=None, op=False):
"""Compute the mean value along the given axis of a tensor `input` """Compute the mean value along the given axis of a tensor `input`
:param axis: compute the mean along this axis of the tensor. :param axis: compute the mean along this axis of the tensor.
None means all axes (like numpy). None means all axes (like numpy).
:type axis: None or int or (list of int) (see `Sum`) :type axis: None or int or (list of int) (see `Sum`)
:note: for gpu, if you manually cast the input to float32 before calling :param dtype: dtype to use for the inner summation. This will not
mean, everything will be done on the gpu. necessarily be the dtype of the output (in particular
if it is a discrete (int/uint) dtype, the output will
be in a float type)
:type dtype: string
:note: for gpu, if you specify dtype=float32, everything will be done
on the gpu.
""" """
if op: if op:
return Mean(axis)(input) return Mean(axis)(input)
if str(input.dtype) in discrete_dtypes: if dtype is not None:
# The summation will be done with the specified dtype.
# sum() will complain if it is not suitable.
sum_dtype = dtype
elif input.dtype in discrete_dtypes:
# we need to cast eventually anyway, and this helps # we need to cast eventually anyway, and this helps
# to prevent overflow # to prevent overflow. Numpy uses 'float64'.
input = cast(input, 'float64') # TODO: use floatX? let casting_policy decide?
s = sum(input, axis) sum_dtype = 'float64'
else:
# Let sum() infer the appropriate dtype
sum_dtype = None
s = sum(input, axis=axis, dtype=sum_dtype)
shp = shape(input) shp = shape(input)
if input.dtype == 'float32':
# Cast shp into a float type
if s.dtype in ('float32', 'complex64'):
shp = cast(shp, 'float32') shp = cast(shp, 'float32')
else:
shp = cast(shp, 'float64')
if axis is None: if axis is None:
axis = range(input.ndim) axis = range(input.ndim)
elif isinstance(axis, int): elif isinstance(axis, int):
axis = [axis] axis = [axis]
for i in axis: for i in axis:
s = s / shp[i] s = s / shp[i]
if str(input.dtype).startswith('float'):
assert input.dtype == s.dtype
return s return s
@constructor @constructor
......
...@@ -29,6 +29,10 @@ def TensorVariable(*inputs, **kwargs): ...@@ -29,6 +29,10 @@ def TensorVariable(*inputs, **kwargs):
def TensorConstant(*inputs, **kwargs): def TensorConstant(*inputs, **kwargs):
raise Exception("Circular dependencies prevent using this here. import tensor before elemwise") raise Exception("Circular dependencies prevent using this here. import tensor before elemwise")
# Define common subsets of dtypes (as strings).
discrete_dtypes = map(str, scalar.discrete_types)
continuous_dtypes = map(str, scalar.continuous_types)
################## ##################
### DimShuffle ### ### DimShuffle ###
...@@ -1315,23 +1319,27 @@ class Any(CAReduce): ...@@ -1315,23 +1319,27 @@ class Any(CAReduce):
return "Any{%s}" % ", ".join(map(str, self.axis)) return "Any{%s}" % ", ".join(map(str, self.axis))
class Sum(CAReduce): class CAReduceDtype(CAReduce):
""" """
Sums all the values of a tensor along the specified axis(es). Reduces a scalar operation along the specified axis(es).
Equivalent to CAReduce(scalar.add, axis=axis), with the This subclass of CAReduce accepts an additional "dtype" parameter,
difference that this defines the gradient of sum wrt its tensor that specifies which dtype will be used for the accumulation.
input.
If no dtype is provided, one will be inferred so as not to lose
too much precision.
""" """
def __init__(self, axis=None, dtype=None): def __init__(self, scalar_op, axis=None, dtype=None):
""" """
Constructor. Usage: CAReduceDtype(scalar_op, axis=None, dtype=None)
:param axis: Axis(es) along which the tensor should be summed :param scalar_op: a binary scalar op with only one output.
(use None to sum over all axes, and a list or tuple to sum along more It must be commutative and associative.
than one axis).
:axis: - the dimension along which we want to reduce
- list of dimensions that we want to reduce
- if None, all dimensions are reduced
:param dtype: The dtype of the internal accumulator and returned :param dtype: The dtype of the internal accumulator and returned
tensor. If None, then we use the default dtype which is the same as the tensor. If None, then we use the default dtype which is the same as the
...@@ -1344,14 +1352,8 @@ class Sum(CAReduce): ...@@ -1344,14 +1352,8 @@ class Sum(CAReduce):
uses the default machine integer while we always use 64 bit integers to uses the default machine integer while we always use 64 bit integers to
avoid platform-dependent behavior). avoid platform-dependent behavior).
IMPORTANT: If you use a custom dtype (!= None), it is strongly advised
to set `config.on_opt_error` to 'raise' and to run your code in
DebugMode at least once. This is because some optimizations may not
currently be able to properly deal with such custom dtypes. Also please
note that using a custom dtype may prevent some optimizations from
being applied.
""" """
CAReduce.__init__(self, scalar.add, axis) CAReduce.__init__(self, scalar_op, axis=axis)
self.dtype = dtype self.dtype = dtype
def __eq__(self, other): def __eq__(self, other):
...@@ -1361,7 +1363,9 @@ class Sum(CAReduce): ...@@ -1361,7 +1363,9 @@ class Sum(CAReduce):
return CAReduce.__hash__(self) ^ hash(self.dtype) return CAReduce.__hash__(self) ^ hash(self.dtype)
def _output_dtype(self, idtype): def _output_dtype(self, idtype):
if self.dtype is None: dtype = self.dtype
if dtype is None:
# If input has a discrete dtype, upcast it to 64 bits
return dict( return dict(
int8='int64', int8='int64',
int16='int64', int16='int64',
...@@ -1370,8 +1374,65 @@ class Sum(CAReduce): ...@@ -1370,8 +1374,65 @@ class Sum(CAReduce):
uint16='uint64', uint16='uint64',
uint32='uint64', uint32='uint64',
).get(idtype, idtype) ).get(idtype, idtype)
elif dtype in continuous_dtypes and idtype in discrete_dtypes:
# Specifying a continuous output for discrete input is OK
return dtype
else: else:
return self.dtype # The conversion has to be considered an upcast.
upcasted_dtype = scalar.upcast(idtype, dtype)
if dtype != upcasted_dtype:
raise TypeError(
'Cannot build %s node with input dtype %s '
'and output dtype %s, as precision would be lost. '
'To correct this error, you can either:\n'
' - not specify a dtype, or\n'
' - use a dtype at least as precise as %s.\n'
'If you are expecting the precision loss, you can '
'use tensor.cast(..., dtype="%s"), either on your '
'input, or on the output of the reduce operation.'
% (self, idtype, dtype, upcasted_dtype, dtype))
return dtype
def make_node(self, input):
    """
    Build the Apply node for `input` using an Op whose dtype is resolved.

    ``self.dtype`` may be None, meaning "infer it from the input". The
    dtype is resolved through ``self._output_dtype`` (which raises
    TypeError when the requested dtype would lose precision), and the
    node is created from an Op carrying that concrete dtype.

    :param input: the variable to reduce; only its ``dtype`` attribute
        is read here before delegating to ``CAReduce.make_node``.

    :returns: whatever ``CAReduce.make_node`` returns for the resolved Op.
    """
    # Function-scope import so that this method does not depend on the
    # module's import block.
    import copy

    dtype = self._output_dtype(input.dtype)
    assert dtype is not None
    if dtype == self.dtype:
        # Don't build another instance
        op = self
    else:
        # Clone this Op instead of calling
        # self.__class__(axis=self.axis, dtype=dtype): re-running the
        # constructor resets every other constructor argument to its
        # default (e.g. Prod's no_zeros_in_input), and it cannot work for
        # a class whose constructor also requires scalar_op.
        # NOTE(review): assumes Op instances support copy.copy (state kept
        # in instance attributes / __getstate__-__setstate__) -- confirm.
        op = copy.copy(self)
        op.dtype = dtype
    return CAReduce.make_node(op, input)
class Sum(CAReduceDtype):
"""
Sums all the values of a tensor along the specified axis(es).
Equivalent to CAReduceDtype(scalar.add, axis=axis, dtype=dtype),
with the difference that this defines the gradient of sum wrt its
tensor input.
"""
def __init__(self, axis=None, dtype=None):
"""
Constructor.
:param axis: Axis(es) along which the tensor should be summed
(use None to sum over all axes, and a list or tuple to sum along more
than one axis).
:param dtype: The dtype of the internal accumulator and returned
tensor. If None, then we use the default dtype which is the same as the
input tensor's dtype except when:
- the input dtype is a signed integer of precision < 64 bit, in
which case we use int64
- the input dtype is an unsigned integer of precision < 64 bit, in
which case we use uint64
"""
CAReduceDtype.__init__(self, scalar.add, axis=axis, dtype=dtype)
def grad(self, inp, grads): def grad(self, inp, grads):
x, = inp x, = inp
...@@ -1407,7 +1468,7 @@ class Sum(CAReduce): ...@@ -1407,7 +1468,7 @@ class Sum(CAReduce):
return "Sum{%s}" % ", ".join(map(str, self.axis)) return "Sum{%s}" % ", ".join(map(str, self.axis))
class Prod(CAReduce): class Prod(CAReduceDtype):
""" """
Multiplies all the values of a tensor along the specified axis(es). Multiplies all the values of a tensor along the specified axis(es).
...@@ -1415,9 +1476,8 @@ class Prod(CAReduce): ...@@ -1415,9 +1476,8 @@ class Prod(CAReduce):
difference that this defines the gradient of prod wrt its tensor difference that this defines the gradient of prod wrt its tensor
input. input.
""" """
def __init__(self, axis=None, no_zeros_in_input=False): def __init__(self, axis=None, dtype=None, no_zeros_in_input=False):
CAReduce.__init__(self, scalar.mul, axis) CAReduceDtype.__init__(self, scalar.mul, axis=axis, dtype=dtype)
self.no_zeros_in_input = no_zeros_in_input self.no_zeros_in_input = no_zeros_in_input
def __setstate__(self, dct): def __setstate__(self, dct):
...@@ -1427,24 +1487,12 @@ class Prod(CAReduce): ...@@ -1427,24 +1487,12 @@ class Prod(CAReduce):
self.no_zeros_in_input = False self.no_zeros_in_input = False
def __eq__(self, other): def __eq__(self, other):
return type(self) == type(other) and self.scalar_op == other.scalar_op and self.axis == other.axis and self.no_zeros_in_input == other.no_zeros_in_input return (CAReduceDtype.__eq__(self, other)
and self.no_zeros_in_input == other.no_zeros_in_input)
def __hash__(self): def __hash__(self):
if self.axis is None: return (CAReduceDtype.__hash__(self) ^
return hash(self.scalar_op) ^ hash(self.no_zeros_in_input) hash(self.no_zeros_in_input))
else:
return hash(self.scalar_op) ^ hash(tuple(self.axis)) ^ hash(self.no_zeros_in_input)
def _output_dtype(self, idtype):
# we want to protect against overflow
return dict(
int8='int64',
int16='int64',
int32='int64',
uint8='uint64',
uint16='uint64',
uint32='uint64',
).get(idtype, idtype)
def grad(self, inp, grads): def grad(self, inp, grads):
''' '''
...@@ -1596,20 +1644,9 @@ class MulWithoutZeros(scalar.BinaryScalarOp): ...@@ -1596,20 +1644,9 @@ class MulWithoutZeros(scalar.BinaryScalarOp):
return (1,) return (1,)
mul_without_zeros = MulWithoutZeros(scalar.upcast_out, name = 'mul_without_zeros') mul_without_zeros = MulWithoutZeros(scalar.upcast_out, name = 'mul_without_zeros')
class ProdWithoutZeros(CAReduce): class ProdWithoutZeros(CAReduceDtype):
def __init__(self, axis = None): def __init__(self, axis=None, dtype=None):
CAReduce.__init__(self, mul_without_zeros, axis) CAReduceDtype.__init__(self, mul_without_zeros, axis=axis, dtype=dtype)
def _output_dtype(self, idtype):
# we want to protect against overflow
return dict(
int8='int32',
int16='int32',
int32='int64',
uint8='uint32',
uint16='uint32',
uint32='uint64',
).get(idtype, idtype)
def __str__(self): def __str__(self):
if self.axis is None: if self.axis is None:
......
...@@ -2875,8 +2875,7 @@ def local_sum_mul_by_scalar(node): ...@@ -2875,8 +2875,7 @@ def local_sum_mul_by_scalar(node):
""" """
# TODO: if the thing inside the Sum is a division, # TODO: if the thing inside the Sum is a division,
# we should get at the numerator.... # we should get at the numerator....
# TODO Implement for sum.dtype != None. if isinstance(node.op, T.Sum):
if isinstance(node.op, T.Sum) and node.op.dtype is None:
thing_summed, = node.inputs thing_summed, = node.inputs
if thing_summed.owner and thing_summed.owner.op == T.mul: if thing_summed.owner and thing_summed.owner.op == T.mul:
terms = thing_summed.owner.inputs terms = thing_summed.owner.inputs
...@@ -2930,9 +2929,7 @@ def local_sum_div_dimshuffle(node): ...@@ -2930,9 +2929,7 @@ def local_sum_div_dimshuffle(node):
# dimshuffle is in the numerator, since elemwise inversion of the # dimshuffle is in the numerator, since elemwise inversion of the
# denominator would still be needed before the summation. # denominator would still be needed before the summation.
# TODO Implement for sum.dtype != None. if isinstance(node.op, T.Sum):
if isinstance(node.op, T.Sum) and node.op.dtype is None:
axis = node.op.axis axis = node.op.axis
if axis is None: if axis is None:
axis = range(node.inputs[0].ndim) axis = range(node.inputs[0].ndim)
...@@ -3029,24 +3026,21 @@ def local_sum_all_to_none(node): ...@@ -3029,24 +3026,21 @@ def local_sum_all_to_none(node):
def local_sum_sum(node): def local_sum_sum(node):
""" """
Sum(Sum()) -> Sum Sum(Sum()) -> Sum
Note that currently we only replace sums with default dtypes, to avoid
potential dtype conflict issues.
""" """
if isinstance(node.op, T.Sum) and node.op.dtype is None: if isinstance(node.op, T.Sum):
summed, = node.inputs summed, = node.inputs
out_dtype = node.op.dtype
if len(summed.clients) == 1: if len(summed.clients) == 1:
if (summed.owner and if (summed.owner and
isinstance(summed.owner.op, T.Sum) isinstance(summed.owner.op, T.Sum)):
and summed.owner.op.dtype is None):
if summed.owner.op.axis is None: if summed.owner.op.axis is None:
# special case of local_cut_useless_reduce # special case of local_cut_useless_reduce
return [T.Sum(None)(summed.owner.inputs[0])] return [T.Sum(None, dtype=out_dtype)(summed.owner.inputs[0])]
if node.op.axis is None: if node.op.axis is None:
# we're summing up everything anyway so lets # we're summing up everything anyway so lets
# do it all at once # do it all at once
return [T.Sum(None)(summed.owner.inputs[0])] return [T.Sum(None, dtype=out_dtype)(summed.owner.inputs[0])]
newaxis = list(tuple(summed.owner.op.axis)) newaxis = list(tuple(summed.owner.op.axis))
# figure out which dimensions of the original input # figure out which dimensions of the original input
...@@ -3089,7 +3083,7 @@ def local_sum_sum(node): ...@@ -3089,7 +3083,7 @@ def local_sum_sum(node):
"been fixed) set the theano flag " "been fixed) set the theano flag "
"`warn.sum_sum_bug` to False.") "`warn.sum_sum_bug` to False.")
combined_sum = T.Sum(newaxis) combined_sum = T.Sum(newaxis, dtype=out_dtype)
return [combined_sum(summed.owner.inputs[0])] return [combined_sum(summed.owner.inputs[0])]
...@@ -3108,8 +3102,7 @@ def local_cut_useless_reduce(node): ...@@ -3108,8 +3102,7 @@ def local_cut_useless_reduce(node):
@gof.local_optimizer([]) @gof.local_optimizer([])
def local_sum_alloc(node): def local_sum_alloc(node):
""" sum(alloc(constant,shapes...)) => constant*prod(shapes)""" """ sum(alloc(constant,shapes...)) => constant*prod(shapes)"""
# TODO Implement for sum.dtype != None if isinstance(node.op, T.Sum):
if isinstance(node.op, T.Sum) and node.op.dtype is None:
summed, = node.inputs summed, = node.inputs
if summed.owner and isinstance(summed.owner.op, T.Alloc): if summed.owner and isinstance(summed.owner.op, T.Alloc):
input = summed.owner.inputs[0] input = summed.owner.inputs[0]
......
...@@ -13,6 +13,7 @@ from theano.compile.mode import get_default_mode ...@@ -13,6 +13,7 @@ from theano.compile.mode import get_default_mode
from theano.tensor.elemwise import * from theano.tensor.elemwise import *
from theano.tests import unittest_tools from theano.tests import unittest_tools
complex_dtypes = map(str, scalar.complex_types)
def Env(i, o): def Env(i, o):
e = gof.Env(i, o) e = gof.Env(i, o)
...@@ -499,7 +500,8 @@ class test_IsInf_IsNan(unittest.TestCase): ...@@ -499,7 +500,8 @@ class test_IsInf_IsNan(unittest.TestCase):
return self.run_isfunc('isnan') return self.run_isfunc('isnan')
def test_sum_default_dtype(): class T_sum_dtype(unittest.TestCase):
def test_sum_default_dtype(self):
""" """
Test the default dtype of a sum(). Test the default dtype of a sum().
""" """
...@@ -517,8 +519,7 @@ def test_sum_default_dtype(): ...@@ -517,8 +519,7 @@ def test_sum_default_dtype():
uint32='uint64', uint32='uint64',
).get(dtype, dtype) ).get(dtype, dtype)
def test_sum_custom_dtype(self):
def test_sum_custom_dtype():
""" """
Test the ability to provide your own output dtype for a sum. Test the ability to provide your own output dtype for a sum.
""" """
...@@ -529,9 +530,172 @@ def test_sum_custom_dtype(): ...@@ -529,9 +530,172 @@ def test_sum_custom_dtype():
x = tensor.matrix(dtype=input_dtype) x = tensor.matrix(dtype=input_dtype)
for output_dtype in imap(str, theano.scalar.all_types): for output_dtype in imap(str, theano.scalar.all_types):
axis = axes[idx % len(axes)] axis = axes[idx % len(axes)]
assert x.sum(dtype=output_dtype, axis=axis).dtype == output_dtype # If output_dtype would force a downcast, we expect a TypeError
# We always allow int/uint inputs with float/complex outputs.
upcasted_dtype = scalar.upcast(input_dtype, output_dtype)
if (output_dtype == upcasted_dtype or
(input_dtype in discrete_dtypes and
output_dtype in continuous_dtypes)
):
sum_var = x.sum(dtype=output_dtype, axis=axis)
assert sum_var.dtype == output_dtype
# Check that we can take the gradient
grad_var = tensor.grad(sum_var.sum(), x,
disconnected_inputs='ignore')
else:
self.assertRaises(TypeError,
x.sum, dtype=output_dtype, axis=axis)
idx += 1 idx += 1
class T_mean_dtype(unittest.TestCase):
    """Check the dtype of ``mean()`` with and without an explicit dtype."""

    def test_mean_default_dtype(self):
        """
        Test the default dtype of a mean().
        """
        # We try multiple axis combinations even though axis should not matter.
        axes = [None, 0, 1, [0], [1], [0, 1]]
        for idx, dtype in enumerate(imap(str, theano.scalar.all_types)):
            axis = axes[idx % len(axes)]
            x = tensor.matrix(dtype=dtype).mean(axis=axis)
            if dtype in discrete_dtypes:
                # Integer inputs are averaged in (and returned as) float64.
                assert x.dtype == 'float64', (x, x.dtype, dtype)
            else:
                assert x.dtype == dtype, (x, x.dtype, dtype)

    def test_mean_custom_dtype(self):
        """
        Test the ability to provide your own output dtype for a mean.
        """
        # We try multiple axis combinations even though axis should not matter.
        axes = [None, 0, 1, [0], [1], [0, 1]]
        idx = 0
        for input_dtype in imap(str, theano.scalar.all_types):
            x = tensor.matrix(dtype=input_dtype)
            for sum_dtype in imap(str, theano.scalar.all_types):
                axis = axes[idx % len(axes)]
                # If the inner sum cannot be created, it will raise a TypeError.
                try:
                    mean_var = x.mean(dtype=sum_dtype, axis=axis)
                except TypeError:
                    pass
                else:
                    # Executed if no TypeError was raised
                    if sum_dtype in discrete_dtypes:
                        assert mean_var.dtype == 'float64', (
                            mean_var.dtype, sum_dtype)
                    else:
                        # The message must only use names defined in this
                        # test, otherwise a dtype mismatch is reported as
                        # a NameError instead of an AssertionError.
                        assert mean_var.dtype == sum_dtype, (
                            mean_var.dtype, sum_dtype)

                    # Check that we can take the gradient, when implemented
                    try:
                        tensor.grad(mean_var.sum(), x,
                                    disconnected_inputs='ignore')
                    except NotImplementedError:
                        # TrueDiv does not seem to have a gradient when
                        # the numerator is complex.
                        if mean_var.dtype not in complex_dtypes:
                            raise

                idx += 1
class T_prod_dtype(unittest.TestCase):
    """Dtype behaviour of ``prod()`` on matrices of every scalar type."""

    def test_prod_default_dtype(self):
        """
        Without an explicit dtype, narrow integers are widened to 64 bits
        and every other dtype is preserved.
        """
        widened = {
            'int8': 'int64',
            'int16': 'int64',
            'int32': 'int64',
            'uint8': 'uint64',
            'uint16': 'uint64',
            'uint32': 'uint64',
        }
        # The axis is not expected to matter, so rotate through a few.
        axis_choices = [None, 0, 1, [0], [1], [0, 1]]
        n_choices = len(axis_choices)
        for position, in_dtype in enumerate(
                imap(str, theano.scalar.all_types)):
            chosen_axis = axis_choices[position % n_choices]
            result = tensor.matrix(dtype=in_dtype).prod(axis=chosen_axis)
            expected = widened.get(in_dtype, in_dtype)
            assert result.dtype == expected

    def test_prod_custom_dtype(self):
        """
        An explicit dtype is honoured when it does not lose precision,
        and rejected with a TypeError otherwise.
        """
        # The axis is not expected to matter, so rotate through a few.
        axis_choices = [None, 0, 1, [0], [1], [0, 1]]
        n_choices = len(axis_choices)
        pair_count = 0
        for input_dtype in imap(str, theano.scalar.all_types):
            x = tensor.matrix(dtype=input_dtype)
            for output_dtype in imap(str, theano.scalar.all_types):
                chosen_axis = axis_choices[pair_count % n_choices]
                pair_count += 1

                # A request is acceptable when it is an upcast of the
                # input, or when an int/uint input goes to float/complex.
                is_upcast = (
                    scalar.upcast(input_dtype, output_dtype) == output_dtype)
                int_to_continuous = (input_dtype in discrete_dtypes and
                                     output_dtype in continuous_dtypes)

                if not (is_upcast or int_to_continuous):
                    # A forced downcast must be refused.
                    self.assertRaises(TypeError,
                                      x.prod,
                                      dtype=output_dtype,
                                      axis=chosen_axis)
                    continue

                prod_var = x.prod(dtype=output_dtype, axis=chosen_axis)
                assert prod_var.dtype == output_dtype
                # The gradient must be constructible as well.
                tensor.grad(prod_var.sum(), x,
                            disconnected_inputs='ignore')
class T_prod_without_zeros_dtype(unittest.TestCase):
    """Dtype behaviour of the ``ProdWithoutZeros`` Op."""

    def test_prod_without_zeros_default_dtype(self):
        """
        Without an explicit dtype, narrow integers are widened to 64 bits
        and every other dtype is preserved.
        """
        widened = {
            'int8': 'int64',
            'int16': 'int64',
            'int32': 'int64',
            'uint8': 'uint64',
            'uint16': 'uint64',
            'uint32': 'uint64',
        }
        # The axis is not expected to matter, so rotate through a few.
        axis_choices = [None, 0, 1, [0], [1], [0, 1]]
        n_choices = len(axis_choices)
        for position, in_dtype in enumerate(
                imap(str, theano.scalar.all_types)):
            chosen_axis = axis_choices[position % n_choices]
            reduce_op = ProdWithoutZeros(axis=chosen_axis)
            result = reduce_op(tensor.matrix(dtype=in_dtype))
            assert result.dtype == widened.get(in_dtype, in_dtype)

    def test_prod_without_zeros_custom_dtype(self):
        """
        An explicit dtype is honoured when it does not lose precision,
        and rejected with a TypeError otherwise.
        """
        # The axis is not expected to matter, so rotate through a few.
        axis_choices = [None, 0, 1, [0], [1], [0, 1]]
        n_choices = len(axis_choices)
        pair_count = 0
        for input_dtype in imap(str, theano.scalar.all_types):
            x = tensor.matrix(dtype=input_dtype)
            for output_dtype in imap(str, theano.scalar.all_types):
                chosen_axis = axis_choices[pair_count % n_choices]
                pair_count += 1

                # A request is acceptable when it is an upcast of the
                # input, or when an int/uint input goes to float/complex.
                is_upcast = (
                    scalar.upcast(input_dtype, output_dtype) == output_dtype)
                int_to_continuous = (input_dtype in discrete_dtypes and
                                     output_dtype in continuous_dtypes)

                # The Op itself is always constructible; only applying it
                # to the input may be refused.
                reduce_op = ProdWithoutZeros(axis=chosen_axis,
                                             dtype=output_dtype)
                if is_upcast or int_to_continuous:
                    prod_woz_var = reduce_op(x)
                    assert prod_woz_var.dtype == output_dtype
                else:
                    self.assertRaises(TypeError, reduce_op, x)
if __name__ == '__main__': if __name__ == '__main__':
#unittest.main() #unittest.main()
......
...@@ -3151,6 +3151,20 @@ class T_local_sum(unittest.TestCase): ...@@ -3151,6 +3151,20 @@ class T_local_sum(unittest.TestCase):
finally: finally:
config.on_opt_error = backup config.on_opt_error = backup
def test_local_sum_sum_dtype(self):
    """
    Compiling two nested sums that each carry an explicit dtype must not
    raise when optimization errors are turned into exceptions.
    """
    x = tensor.tensor3(dtype='int8')
    inner_sum = x.sum(axis=0, dtype='int32')
    outer_sum = inner_sum.sum(axis=1, dtype='int64')

    # Make any optimizer failure fatal for the duration of the compile,
    # then put the previous setting back no matter what happens.
    previous_setting = config.on_opt_error
    config.on_opt_error = 'raise'
    try:
        # This compilation would fail prior to fix.
        theano.function([x], outer_sum)
    finally:
        config.on_opt_error = previous_setting
class T_local_sum_dimshuffle(unittest.TestCase): class T_local_sum_dimshuffle(unittest.TestCase):
def setUp(self): def setUp(self):
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论