提交 6a7aa408 authored 作者: Pascal Lamblin's avatar Pascal Lamblin

dtype of tensor.sum(): forbid downcasting

If the output dtype would force a downcasting of the values (for instance, summing a float vector into an integer), a TypeError is now raised. Also, a Sum node has to have a non-None dtype. It is checked in make_node because for some reason the pretty-printing mechanism needs to instanciate "Sum()".
上级 41103b5d
......@@ -29,6 +29,10 @@ def TensorVariable(*inputs, **kwargs):
def TensorConstant(*inputs, **kwargs):
raise Exception("Circular dependencies prevent using this here. import tensor before elemwise")
# Define common subsets of dtypes (as strings).
discrete_dtypes = map(str, scalar.discrete_types)
continuous_dtypes = map(str, scalar.continuous_types)
##################
### DimShuffle ###
......@@ -1361,7 +1365,9 @@ class Sum(CAReduce):
return CAReduce.__hash__(self) ^ hash(self.dtype)
def _output_dtype(self, idtype):
if self.dtype is None:
dtype = self.dtype
if dtype is None:
# If input has an discrete dtype, upcast it to 64
return dict(
int8='int64',
int16='int64',
......@@ -1370,8 +1376,37 @@ class Sum(CAReduce):
uint16='uint64',
uint32='uint64',
).get(idtype, idtype)
elif dtype in continuous_dtypes and idtype in discrete_dtypes:
# Specifying a continuous output for discrete input is OK
return dtype
else:
# The conversion has to be considered an upcast.
upcasted_dtype = scalar.upcast(idtype, dtype)
if dtype != upcasted_dtype:
raise TypeError(
'Cannot build Sum node with input dtype %s '
'and output dtype %s, as precision would be lost. '
'To correct this error, you can either:\n'
' - not specify a dtype, or\n'
' - use a dtype at least as precise as %s.\n'
'If you are expecting the precision loss, you can '
'use tensor.cast(..., dtype="%s"), either on your '
'input, or on the output of the sum.'
% (idtype, dtype, upcasted_dtype, dtype))
return dtype
def make_node(self, input):
# We need to redefine make_node so that, if self.dtype is None,
# we can infer what dtype should be, and create a node from an Op
# of the appropriate dtype.
dtype = self._output_dtype(input.dtype)
assert dtype is not None
if dtype == self.dtype:
# Don't build another instance
op = self
else:
return self.dtype
op = self.__class__(axis=self.axis, dtype=dtype)
return CAReduce.make_node(op, input)
def grad(self, inp, grads):
x, = inp
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论