提交 563e4adf authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Add dtype keyword to tensor.Prod.

Also add ".prod()" method to tensor variables.
上级 1005832b
...@@ -1422,6 +1422,10 @@ class _tensor_py_operators: ...@@ -1422,6 +1422,10 @@ class _tensor_py_operators:
"""See `theano.tensor.sum`""" """See `theano.tensor.sum`"""
return sum(self, axis=axis, dtype=dtype) return sum(self, axis=axis, dtype=dtype)
def prod(self, axis=None, dtype=None):
"""See `theano.tensor.prod`"""
return prod(self, axis=axis, dtype=dtype)
def norm(self, L, axis=None): def norm(self, L, axis=None):
if L==0: if L==0:
raise NotImplementedError() raise NotImplementedError()
...@@ -2631,9 +2635,13 @@ pprint.assign(Sum(), printing.FunctionPrinter('sum')) ...@@ -2631,9 +2635,13 @@ pprint.assign(Sum(), printing.FunctionPrinter('sum'))
@constructor
def prod(input, axis=None, dtype=None):
    """
    Returns the Product of a tensor's elements along the given axis(es).

    For full documentation see ``tensor.elemwise.Prod``.
    """
    # Delegate to the Prod op; dtype=None lets the op infer the output
    # dtype from the input's dtype at make_node time.
    op = elemwise.Prod(axis, dtype=dtype)
    return op(input)
class Mean(elemwise.CAReduce): class Mean(elemwise.CAReduce):
def __init__(self, axis = None): def __init__(self, axis = None):
......
...@@ -1450,9 +1450,9 @@ class Prod(CAReduce): ...@@ -1450,9 +1450,9 @@ class Prod(CAReduce):
difference that this defines the gradient of prod wrt its tensor difference that this defines the gradient of prod wrt its tensor
input. input.
""" """
def __init__(self, axis=None, no_zeros_in_input=False): def __init__(self, axis=None, dtype=None, no_zeros_in_input=False):
CAReduce.__init__(self, scalar.mul, axis) CAReduce.__init__(self, scalar.mul, axis)
self.dtype = dtype
self.no_zeros_in_input = no_zeros_in_input self.no_zeros_in_input = no_zeros_in_input
def __setstate__(self, dct): def __setstate__(self, dct):
...@@ -1462,24 +1462,59 @@ class Prod(CAReduce): ...@@ -1462,24 +1462,59 @@ class Prod(CAReduce):
self.no_zeros_in_input = False self.no_zeros_in_input = False
def __eq__(self, other): def __eq__(self, other):
return type(self) == type(other) and self.scalar_op == other.scalar_op and self.axis == other.axis and self.no_zeros_in_input == other.no_zeros_in_input return (type(self) == type(other) and
self.scalar_op == other.scalar_op
and self.axis == other.axis
and self.dtype == other.dtype
and self.no_zeros_in_input == other.no_zeros_in_input)
def __hash__(self): def __hash__(self):
if self.axis is None: return (CAReduce.__hash__(self) ^
return hash(self.scalar_op) ^ hash(self.no_zeros_in_input) hash(self.no_zeros_in_input) ^
else: hash(self.dtype))
return hash(self.scalar_op) ^ hash(tuple(self.axis)) ^ hash(self.no_zeros_in_input)
def _output_dtype(self, idtype): def _output_dtype(self, idtype):
# we want to protect against overflow dtype = self.dtype
return dict( if dtype is None:
int8='int64', # we want to protect against overflow
int16='int64', return dict(
int32='int64', int8='int64',
uint8='uint64', int16='int64',
uint16='uint64', int32='int64',
uint32='uint64', uint8='uint64',
).get(idtype, idtype) uint16='uint64',
uint32='uint64',
).get(idtype, idtype)
elif dtype in continuous_dtypes and idtype in discrete_dtypes:
# Specifying a continuous output for discrete input is OK
return dtype
else:
# The conversion has to be considered an upcast.
upcasted_dtype = scalar.upcast(idtype, dtype)
if dtype != upcasted_dtype:
raise TypeError(
'Cannot build Prod node with input dtype %s '
'and output dtype %s, as precision would be lost. '
'To correct this error, you can either:\n'
' - not specify a dtype, or\n'
' - use a dtype at least as precise as %s.\n'
'If you are expecting the precision loss, you can '
'use tensor.cast(..., dtype="%s"), either on your '
'input, or on the output of the prod.'
% (idtype, dtype, upcasted_dtype, dtype))
return dtype
def make_node(self, input):
# We need to redefine make_node so that, if self.dtype is None,
# we can infer what dtype should be, and create a node from an Op
# of the appropriate dtype.
dtype = self._output_dtype(input.dtype)
if dtype == self.dtype:
# Don't build another instance
op = self
else:
op = self.__class__(axis=self.axis, dtype=dtype)
return CAReduce.make_node(op, input)
def grad(self, inp, grads): def grad(self, inp, grads):
''' '''
...@@ -1632,19 +1667,58 @@ class MulWithoutZeros(scalar.BinaryScalarOp): ...@@ -1632,19 +1667,58 @@ class MulWithoutZeros(scalar.BinaryScalarOp):
# Scalar op instance used by ProdWithoutZeros below.
mul_without_zeros = MulWithoutZeros(scalar.upcast_out, name='mul_without_zeros')
class ProdWithoutZeros(CAReduce): class ProdWithoutZeros(CAReduce):
def __init__(self, axis = None): def __init__(self, axis=None, dtype=None):
CAReduce.__init__(self, mul_without_zeros, axis) CAReduce.__init__(self, mul_without_zeros, axis)
self.dtype = dtype
def __eq__(self, other):
return CAReduce.__eq__(self, other) and self.dtype == other.dtype
def __hash__(self):
return CAReduce.__hash__(self) ^ hash(self.dtype)
def _output_dtype(self, idtype): def _output_dtype(self, idtype):
# we want to protect against overflow dtype = self.dtype
return dict( if dtype is None:
int8='int32', # we want to protect against overflow
int16='int32', return dict(
int32='int64', int8='int64',
uint8='uint32', int16='int64',
uint16='uint32', int32='int64',
uint32='uint64', uint8='uint64',
).get(idtype, idtype) uint16='uint64',
uint32='uint64',
).get(idtype, idtype)
elif dtype in continuous_dtypes and idtype in discrete_dtypes:
# Specifying a continuous output for discrete input is OK
return dtype
else:
# The conversion has to be considered an upcast.
upcasted_dtype = scalar.upcast(idtype, dtype)
if dtype != upcasted_dtype:
raise TypeError(
'Cannot build ProdWithoutZeros node with input dtype '
' %s and output dtype %s, as precision would be lost. '
'To correct this error, you can either:\n'
' - not specify a dtype, or\n'
' - use a dtype at least as precise as %s.\n'
'If you are expecting the precision loss, you can '
'use tensor.cast(..., dtype="%s"), either on your '
'input, or on the output of the prod_without_zeros.'
% (idtype, dtype, upcasted_dtype, dtype))
return dtype
def make_node(self, input):
# We need to redefine make_node so that, if self.dtype is None,
# we can infer what dtype should be, and create a node from an Op
# of the appropriate dtype.
dtype = self._output_dtype(input.dtype)
if dtype == self.dtype:
# Don't build another instance
op = self
else:
op = self.__class__(axis=self.axis, dtype=dtype)
return CAReduce.make_node(op, input)
def __str__(self): def __str__(self):
if self.axis is None: if self.axis is None:
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论