Commit 41747183 authored by Pascal Lamblin

Factor changes into a CAReduceDtype Op.

Parent 9c55d4aa
......@@ -1319,23 +1319,27 @@ class Any(CAReduce):
return "Any{%s}" % ", ".join(map(str, self.axis))
class Sum(CAReduce):
class CAReduceDtype(CAReduce):
"""
Sums all the values of a tensor along the specified axis(es).
Reduces a scalar operation along the specified axis(es).
Equivalent to CAReduce(scalar.add, axis=axis), with the
difference that this defines the gradient of sum wrt its tensor
input.
This subclass of CAReduce accepts an additional "dtype" parameter,
that specifies which dtype will be used for the accumulation.
If no dtype is provided, one will be inferred so as not to lose
too much precision.
"""
def __init__(self, axis=None, dtype=None):
def __init__(self, scalar_op, axis=None, dtype=None):
"""
Constructor.
Usage: CAReduceDtype(scalar_op, axis=None, dtype=None)
:param axis: Axis(es) along which the tensor should be summed
(use None to sum over all axes, and a list or tuple to sum along more
than one axis).
:param scalar_op: a binary scalar op with only one output.
It must be commutative and associative.
:axis: - the dimension along which we want to reduce
- list of dimensions that we want to reduce
- if None, all dimensions are reduced
:param dtype: The dtype of the internal accumulator and returned
tensor. If None, then we use the default dtype which is the same as the
......@@ -1348,14 +1352,8 @@ class Sum(CAReduce):
uses the default machine integer while we always use 64 bit integers to
avoid platform-dependent behavior).
IMPORTANT: If you use a custom dtype (!= None), it is strongly advised
to set `config.on_opt_error` to 'raise' and to run your code in
DebugMode at least once. This is because some optimizations may not
currently be able to properly deal with such custom dtypes. Also please
note that using a custom dtype may prevent some optimizations from
being applied.
"""
CAReduce.__init__(self, scalar.add, axis)
CAReduce.__init__(self, scalar_op, axis=axis)
self.dtype = dtype
def __eq__(self, other):
......@@ -1384,15 +1382,15 @@ class Sum(CAReduce):
upcasted_dtype = scalar.upcast(idtype, dtype)
if dtype != upcasted_dtype:
raise TypeError(
'Cannot build Sum node with input dtype %s '
'Cannot build %s node with input dtype %s '
'and output dtype %s, as precision would be lost. '
'To correct this error, you can either:\n'
' - not specify a dtype, or\n'
' - use a dtype at least as precise as %s.\n'
'If you are expecting the precision loss, you can '
'use tensor.cast(..., dtype="%s"), either on your '
'input, or on the output of the sum.'
% (idtype, dtype, upcasted_dtype, dtype))
'input, or on the output of the reduce operation.'
% (self, idtype, dtype, upcasted_dtype, dtype))
return dtype
def make_node(self, input):
......@@ -1408,6 +1406,34 @@ class Sum(CAReduce):
op = self.__class__(axis=self.axis, dtype=dtype)
return CAReduce.make_node(op, input)
class Sum(CAReduceDtype):
"""
Sums all the values of a tensor along the specified axis(es).
Equivalent to CAReduceDtype(scalar.add, axis=axis, dtype=dtype),
with the difference that this defines the gradient of sum wrt its
tensor input.
"""
def __init__(self, axis=None, dtype=None):
    """
    Build a Sum reduction (scalar.add) over the given axis(es).

    :param axis: dimension(s) to reduce; None reduces every axis,
    and a list or tuple reduces several dimensions at once.
    :param dtype: dtype of the internal accumulator and of the
    returned tensor. When None, the default is the input tensor's
    dtype, except that signed integers narrower than 64 bits are
    accumulated as int64 and unsigned ones as uint64.
    """
    # All dtype/axis bookkeeping is delegated to the base class.
    CAReduceDtype.__init__(self, scalar.add, axis=axis, dtype=dtype)
def grad(self, inp, grads):
x, = inp
gz, = grads
......@@ -1442,7 +1468,7 @@ class Sum(CAReduce):
return "Sum{%s}" % ", ".join(map(str, self.axis))
class Prod(CAReduce):
class Prod(CAReduceDtype):
"""
Multiplies all the values of a tensor along the specified axis(es).
......@@ -1451,8 +1477,7 @@ class Prod(CAReduce):
input.
"""
def __init__(self, axis=None, dtype=None, no_zeros_in_input=False):
    """
    Constructor.

    :param axis: axis(es) along which the product is computed
    (None means all axes; a list or tuple reduces several axes).
    :param dtype: dtype of the internal accumulator and returned
    tensor; if None, CAReduceDtype infers one that avoids losing
    precision (integers narrower than 64 bits are widened).
    :param no_zeros_in_input: caller's promise that the input holds
    no zeros (presumably used to pick a cheaper gradient formula —
    confirm against Prod.grad).
    """
    # Single delegation to CAReduceDtype; it stores scalar_op, axis
    # and dtype, so no extra assignments are needed here.
    CAReduceDtype.__init__(self, scalar.mul, axis=axis, dtype=dtype)
    self.no_zeros_in_input = no_zeros_in_input
def __setstate__(self, dct):
......@@ -1462,59 +1487,12 @@ class Prod(CAReduce):
self.no_zeros_in_input = False
def __eq__(self, other):
    """
    Equality of two Prod ops: the base-class attributes (type,
    scalar_op, axis, dtype) must match, plus the Prod-specific
    no_zeros_in_input flag.
    """
    return (CAReduceDtype.__eq__(self, other)
            and self.no_zeros_in_input == other.no_zeros_in_input)
def __hash__(self):
    """
    Hash consistent with __eq__: the base-class hash (which covers
    scalar_op, axis and dtype) xor-ed with the extra flag.
    """
    # _output_dtype and make_node are inherited from CAReduceDtype;
    # they must not be duplicated here.
    return (CAReduceDtype.__hash__(self) ^
            hash(self.no_zeros_in_input))
def grad(self, inp, grads):
'''
......@@ -1666,47 +1644,9 @@ class MulWithoutZeros(scalar.BinaryScalarOp):
return (1,)
mul_without_zeros = MulWithoutZeros(scalar.upcast_out, name = 'mul_without_zeros')
class ProdWithoutZeros(CAReduceDtype):
    """
    Reduce with the `mul_without_zeros` scalar op: like Prod, but
    (judging by the op's name) zero entries are skipped rather than
    annihilating the product — confirm against MulWithoutZeros.

    dtype inference, __eq__ and __hash__ are all inherited from
    CAReduceDtype; this class only picks the scalar op.
    """

    def __init__(self, axis=None, dtype=None):
        """
        :param axis: axis(es) to reduce (None means all axes).
        :param dtype: accumulator/output dtype; None lets
        CAReduceDtype infer a precision-safe default.
        """
        CAReduceDtype.__init__(self, mul_without_zeros,
                               axis=axis, dtype=dtype)
def make_node(self, input):
# We need to redefine make_node so that, if self.dtype is None,
......
Markdown is supported
0%
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Register or sign in to comment