Commit 41747183 authored by Pascal Lamblin

Factor changes into a CAReduceDtype Op.

Parent 9c55d4aa
@@ -1319,23 +1319,27 @@ class Any(CAReduce):
         return "Any{%s}" % ", ".join(map(str, self.axis))
 
 
-class Sum(CAReduce):
+class CAReduceDtype(CAReduce):
     """
-    Sums all the values of a tensor along the specified axis(es).
+    Reduces a scalar operation along the specified axis(es).
 
-    Equivalent to CAReduce(scalar.add, axis=axis), with the
-    difference that this defines the gradient of sum wrt its tensor
-    input.
+    This subclass of CAReduce accepts an additional "dtype" parameter,
+    that specifies which dtype will be used for the accumulation.
+
+    If no dtype is provided, one will be inferred so as not to lose
+    too much precision.
     """
-    def __init__(self, axis=None, dtype=None):
+    def __init__(self, scalar_op, axis=None, dtype=None):
         """
-        Constructor.
+        Usage: CAReduceDtype(scalar_op, axis=None, dtype=None)
 
-        :param axis: Axis(es) along which the tensor should be summed
-        (use None to sum over all axes, and a list or tuple to sum along more
-        than one axis).
+        :param scalar_op: a binary scalar op with only one output.
+                          It must be commutative and associative.
+
+        :axis: - the dimension along which we want to reduce
+               - list of dimensions that we want to reduce
+               - if None, all dimensions are reduced
 
         :param dtype: The dtype of the internal accumulator and returned
         tensor. If None, then we use the default dtype which is the same as the
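For context on what the new Op does: CAReduceDtype behaves like CAReduce, but widens the accumulator when no dtype is given. A minimal sketch of building a Sum-like reduction directly from a binary scalar op (an illustration, not from the commit; it assumes this commit is applied and a standard Theano install, where `tensor.bmatrix` constructs an int8 matrix):

    from theano import scalar, tensor
    from theano.tensor.elemwise import CAReduceDtype

    x = tensor.bmatrix('x')                      # int8 input
    # scalar.add is binary, commutative and associative, as required.
    s = CAReduceDtype(scalar.add, axis=None)(x)  # like Sum, minus the gradient
    print(s.dtype)                               # 'int64': inferred, not int8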
@@ -1348,14 +1352,8 @@ class Sum(CAReduce):
         uses the default machine integer while we always use 64 bit integers to
         avoid platform-dependent behavior).
-
-        IMPORTANT: If you use a custom dtype (!= None), it is strongly advised
-        to set `config.on_opt_error` to 'raise' and to run your code in
-        DebugMode at least once. This is because some optimizations may not
-        currently be able to properly deal with such custom dtypes. Also please
-        note that using a custom dtype may prevent some optimizations from
-        being applied.
         """
-        CAReduce.__init__(self, scalar.add, axis)
+        CAReduce.__init__(self, scalar_op, axis=axis)
         self.dtype = dtype
 
     def __eq__(self, other):
@@ -1384,15 +1382,15 @@ class Sum(CAReduce):
             upcasted_dtype = scalar.upcast(idtype, dtype)
             if dtype != upcasted_dtype:
                 raise TypeError(
-                        'Cannot build Sum node with input dtype %s '
+                        'Cannot build %s node with input dtype %s '
                         'and output dtype %s, as precision would be lost. '
                         'To correct this error, you can either:\n'
                         '  - not specify a dtype, or\n'
                         '  - use a dtype at least as precise as %s.\n'
                         'If you are expecting the precision loss, you can '
                         'use tensor.cast(..., dtype="%s"), either on your '
-                        'input, or on the output of the sum.'
-                        % (idtype, dtype, upcasted_dtype, dtype))
+                        'input, or on the output of the reduce operation.'
+                        % (self, idtype, dtype, upcasted_dtype, dtype))
             return dtype
 
     def make_node(self, input):
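The rewritten message now applies to every CAReduceDtype subclass, not just Sum. When the precision loss is actually intended, the suggested workaround looks like this (a sketch, not from the commit; `tensor.imatrix` is Theano's int32 matrix constructor):

    import theano.tensor as tensor
    from theano.tensor.elemwise import Sum

    x = tensor.imatrix('x')                   # int32 input
    # Sum(dtype='int16')(x) would raise the TypeError above, since
    # upcast(int32, int16) == int32, so int16 would lose precision.
    s = tensor.cast(Sum()(x), dtype='int16')  # explicit, intentional downcast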
@@ -1408,6 +1406,34 @@ class Sum(CAReduce):
             op = self.__class__(axis=self.axis, dtype=dtype)
         return CAReduce.make_node(op, input)
 
+
+class Sum(CAReduceDtype):
+    """
+    Sums all the values of a tensor along the specified axis(es).
+
+    Equivalent to CAReduceDtype(scalar.add, axis=axis, dtype=dtype),
+    with the difference that this defines the gradient of sum wrt its
+    tensor input.
+    """
+    def __init__(self, axis=None, dtype=None):
+        """
+        Constructor.
+
+        :param axis: Axis(es) along which the tensor should be summed
+        (use None to sum over all axes, and a list or tuple to sum along more
+        than one axis).
+
+        :param dtype: The dtype of the internal accumulator and returned
+        tensor. If None, then we use the default dtype which is the same as the
+        input tensor's dtype except when:
+        - the input dtype is a signed integer of precision < 64 bit, in
+          which case we use int64
+        - the input dtype is an unsigned integer of precision < 64 bit, in
+          which case we use uint64
+        """
+        CAReduceDtype.__init__(self, scalar.add, axis=axis, dtype=dtype)
+
     def grad(self, inp, grads):
         x, = inp
         gz, = grads
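The signed/unsigned defaults documented in this docstring are observable from the variable's dtype alone, without compiling a function. A quick check (a sketch; `tensor.matrix` accepts a dtype keyword in Theano):

    import theano.tensor as tensor

    print(tensor.matrix(dtype='int8').sum().dtype)     # 'int64'
    print(tensor.matrix(dtype='uint16').sum().dtype)   # 'uint64'
    print(tensor.matrix(dtype='float32').sum().dtype)  # 'float32' (unchanged)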
@@ -1442,7 +1468,7 @@ class Sum(CAReduce):
         return "Sum{%s}" % ", ".join(map(str, self.axis))
 
 
-class Prod(CAReduce):
+class Prod(CAReduceDtype):
     """
     Multiplies all the values of a tensor along the specified axis(es).
@@ -1451,8 +1477,7 @@ class Prod(CAReduce):
     input.
     """
     def __init__(self, axis=None, dtype=None, no_zeros_in_input=False):
-        CAReduce.__init__(self, scalar.mul, axis)
-        self.dtype = dtype
+        CAReduceDtype.__init__(self, scalar.mul, axis=axis, dtype=dtype)
         self.no_zeros_in_input = no_zeros_in_input
 
     def __setstate__(self, dct):
@@ -1462,59 +1487,12 @@ class Prod(CAReduce):
             self.no_zeros_in_input = False
 
     def __eq__(self, other):
-        return (type(self) == type(other) and
-                self.scalar_op == other.scalar_op
-                and self.axis == other.axis
-                and self.dtype == other.dtype
+        return (CAReduceDtype.__eq__(self, other)
                 and self.no_zeros_in_input == other.no_zeros_in_input)
 
     def __hash__(self):
-        return (CAReduce.__hash__(self) ^
-                hash(self.no_zeros_in_input) ^
-                hash(self.dtype))
-
-    def _output_dtype(self, idtype):
-        dtype = self.dtype
-        if dtype is None:
-            # we want to protect against overflow
-            return dict(
-                    int8='int64',
-                    int16='int64',
-                    int32='int64',
-                    uint8='uint64',
-                    uint16='uint64',
-                    uint32='uint64',
-                    ).get(idtype, idtype)
-        elif dtype in continuous_dtypes and idtype in discrete_dtypes:
-            # Specifying a continuous output for discrete input is OK
-            return dtype
-        else:
-            # The conversion has to be considered an upcast.
-            upcasted_dtype = scalar.upcast(idtype, dtype)
-            if dtype != upcasted_dtype:
-                raise TypeError(
-                        'Cannot build Prod node with input dtype %s '
-                        'and output dtype %s, as precision would be lost. '
-                        'To correct this error, you can either:\n'
-                        '  - not specify a dtype, or\n'
-                        '  - use a dtype at least as precise as %s.\n'
-                        'If you are expecting the precision loss, you can '
-                        'use tensor.cast(..., dtype="%s"), either on your '
-                        'input, or on the output of the prod.'
-                        % (idtype, dtype, upcasted_dtype, dtype))
-            return dtype
-
-    def make_node(self, input):
-        # We need to redefine make_node so that, if self.dtype is None,
-        # we can infer what dtype should be, and create a node from an Op
-        # of the appropriate dtype.
-        dtype = self._output_dtype(input.dtype)
-        if dtype == self.dtype:
-            # Don't build another instance
-            op = self
-        else:
-            op = self.__class__(axis=self.axis, dtype=dtype)
-        return CAReduce.make_node(op, input)
+        return (CAReduceDtype.__hash__(self) ^
+                hash(self.no_zeros_in_input))
 
     def grad(self, inp, grads):
         '''
@@ -1666,47 +1644,9 @@ class MulWithoutZeros(scalar.BinaryScalarOp):
         return (1,)
 
 mul_without_zeros = MulWithoutZeros(scalar.upcast_out, name = 'mul_without_zeros')
 
-class ProdWithoutZeros(CAReduce):
+class ProdWithoutZeros(CAReduceDtype):
     def __init__(self, axis=None, dtype=None):
-        CAReduce.__init__(self, mul_without_zeros, axis)
-        self.dtype = dtype
-
-    def __eq__(self, other):
-        return CAReduce.__eq__(self, other) and self.dtype == other.dtype
-
-    def __hash__(self):
-        return CAReduce.__hash__(self) ^ hash(self.dtype)
-
-    def _output_dtype(self, idtype):
-        dtype = self.dtype
-        if dtype is None:
-            # we want to protect against overflow
-            return dict(
-                    int8='int64',
-                    int16='int64',
-                    int32='int64',
-                    uint8='uint64',
-                    uint16='uint64',
-                    uint32='uint64',
-                    ).get(idtype, idtype)
-        elif dtype in continuous_dtypes and idtype in discrete_dtypes:
-            # Specifying a continuous output for discrete input is OK
-            return dtype
-        else:
-            # The conversion has to be considered an upcast.
-            upcasted_dtype = scalar.upcast(idtype, dtype)
-            if dtype != upcasted_dtype:
-                raise TypeError(
-                        'Cannot build ProdWithoutZeros node with input dtype '
-                        '%s and output dtype %s, as precision would be lost. '
-                        'To correct this error, you can either:\n'
-                        '  - not specify a dtype, or\n'
-                        '  - use a dtype at least as precise as %s.\n'
-                        'If you are expecting the precision loss, you can '
-                        'use tensor.cast(..., dtype="%s"), either on your '
-                        'input, or on the output of the prod_without_zeros.'
-                        % (idtype, dtype, upcasted_dtype, dtype))
-            return dtype
+        CAReduceDtype.__init__(self, mul_without_zeros, axis=axis, dtype=dtype)
 
     def make_node(self, input):
         # We need to redefine make_node so that, if self.dtype is None,
...
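Taken together, the factored-out inference rule is small enough to restate standalone. Below is a self-contained sketch of the same logic, with numpy.promote_types standing in for theano.scalar.upcast (an approximation: Theano applies its own casting rules) and hypothetical names for the helpers:

    import numpy

    # Default accumulator dtypes, mirroring CAReduceDtype._output_dtype:
    # integers narrower than 64 bits are widened to avoid overflow.
    DEFAULT_ACC_DTYPE = {
        'int8': 'int64', 'int16': 'int64', 'int32': 'int64',
        'uint8': 'uint64', 'uint16': 'uint64', 'uint32': 'uint64',
    }
    CONTINUOUS = ('float32', 'float64', 'complex64', 'complex128')

    def output_dtype(idtype, dtype=None):
        if dtype is None:
            # No dtype requested: widen small integers, keep everything else.
            return DEFAULT_ACC_DTYPE.get(idtype, idtype)
        if dtype in CONTINUOUS and idtype not in CONTINUOUS:
            # A continuous output for a discrete input is always accepted.
            return dtype
        upcasted = numpy.promote_types(idtype, dtype).name
        if dtype != upcasted:
            # Mirrors the TypeError raised in the diff above.
            raise TypeError('Cannot reduce %s into %s: precision would be lost'
                            % (idtype, dtype))
        return dtype

    print(output_dtype('int8'))              # 'int64'
    print(output_dtype('uint32'))            # 'uint64'
    print(output_dtype('int64', 'float32'))  # 'float32' (continuous for discrete)
    # output_dtype('int32', 'int16') raises TypeError.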