提交 563e4adf authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Add dtype keyword to tensor.Prod.

Also add ".prod()" method to tensor variables.
上级 1005832b
......@@ -1422,6 +1422,10 @@ class _tensor_py_operators:
"""See `theano.tensor.sum`"""
return sum(self, axis=axis, dtype=dtype)
def prod(self, axis=None, dtype=None):
    """Method form of the module-level ``prod``; see `theano.tensor.prod`.

    :param axis: forwarded unchanged to ``prod``.
    :param dtype: forwarded unchanged to ``prod``.
    :return: whatever ``prod(self, axis=axis, dtype=dtype)`` returns.
    """
    # ``prod`` here resolves to the module-level function, not this method.
    result = prod(self, axis=axis, dtype=dtype)
    return result
def norm(self, L, axis=None):
if L==0:
raise NotImplementedError()
......@@ -2631,9 +2635,13 @@ pprint.assign(Sum(), printing.FunctionPrinter('sum'))
@constructor
def prod(input, axis=None, dtype=None):
    """
    Returns the Product of a tensor's elements along the given axis(es).

    For full documentation see ``tensor.elemwise.Prod``.

    :param input: the tensor the ``Prod`` Op is applied to.
    :param axis: passed as the first argument of ``elemwise.Prod``.
    :param dtype: passed as the ``dtype`` keyword of ``elemwise.Prod``;
        ``None`` (the default) leaves the choice to that Op.
    :return: the result of applying ``elemwise.Prod(axis, dtype=dtype)``
        to ``input``.
    """
    # Only one definition is kept: the earlier two-argument version was
    # shadowed by this one and ignored ``dtype`` entirely.
    return elemwise.Prod(axis, dtype=dtype)(input)
class Mean(elemwise.CAReduce):
def __init__(self, axis = None):
......
......@@ -1450,9 +1450,9 @@ class Prod(CAReduce):
difference that this defines the gradient of prod wrt its tensor
input.
"""
def __init__(self, axis=None, dtype=None, no_zeros_in_input=False):
    """Create a product reduction built on ``scalar.mul``.

    :param axis: forwarded to ``CAReduce.__init__``.
    :param dtype: requested output dtype, stored as ``self.dtype``;
        ``None`` means it is inferred later from the input dtype.
    :param no_zeros_in_input: stored as ``self.no_zeros_in_input``.
    """
    CAReduce.__init__(self, scalar.mul, axis)
    self.dtype = dtype
    self.no_zeros_in_input = no_zeros_in_input
def __setstate__(self, dct):
......@@ -1462,24 +1462,59 @@ class Prod(CAReduce):
self.no_zeros_in_input = False
def __eq__(self, other):
    """Return True when ``other`` is the same kind of Op, identically set up.

    Two instances are equal only if they have exactly the same type and
    the same ``scalar_op``, ``axis``, ``dtype`` and ``no_zeros_in_input``.
    """
    # ``dtype`` must take part in the comparison: Ops that differ only in
    # their output dtype are not interchangeable.
    return (type(self) == type(other) and
            self.scalar_op == other.scalar_op
            and self.axis == other.axis
            and self.dtype == other.dtype
            and self.no_zeros_in_input == other.no_zeros_in_input)
def __hash__(self):
    """Hash consistent with ``__eq__``.

    Combines the ``CAReduce`` hash with the hashes of the two attributes
    this class adds: ``no_zeros_in_input`` and ``dtype``.
    """
    return (CAReduce.__hash__(self) ^
            hash(self.no_zeros_in_input) ^
            hash(self.dtype))
def _output_dtype(self, idtype):
    """Return the dtype the product should have for an input of ``idtype``.

    :param idtype: dtype name of the input.
    :return: the output dtype name.
    :raises TypeError: if ``self.dtype`` was given explicitly and using it
        would lose precision relative to ``idtype``.
    """
    dtype = self.dtype
    if dtype is None:
        # No explicit request: widen small integer types to 64 bits to
        # protect against overflow; every other dtype is kept as is.
        return dict(
            int8='int64',
            int16='int64',
            int32='int64',
            uint8='uint64',
            uint16='uint64',
            uint32='uint64',
        ).get(idtype, idtype)

    if dtype in continuous_dtypes and idtype in discrete_dtypes:
        # Specifying a continuous output for discrete input is OK
        return dtype

    # The conversion has to be considered an upcast: the requested dtype
    # must already be the upcast of (input dtype, requested dtype).
    upcasted_dtype = scalar.upcast(idtype, dtype)
    if dtype != upcasted_dtype:
        raise TypeError(
            'Cannot build Prod node with input dtype %s '
            'and output dtype %s, as precision would be lost. '
            'To correct this error, you can either:\n'
            '  - not specify a dtype, or\n'
            '  - use a dtype at least as precise as %s.\n'
            'If you are expecting the precision loss, you can '
            'use tensor.cast(..., dtype="%s"), either on your '
            'input, or on the output of the prod.'
            % (idtype, dtype, upcasted_dtype, dtype))
    return dtype
def make_node(self, input):
    """Build the node for ``input`` using an Op whose dtype is resolved.

    ``make_node`` is redefined so that, when ``self.dtype`` is None, the
    output dtype is inferred from ``input.dtype`` and the node is created
    from an Op instance carrying that dtype.

    :param input: the variable to reduce; its ``dtype`` attribute is read.
    :return: the result of ``CAReduce.make_node`` for the chosen Op.
    :raises TypeError: propagated from ``_output_dtype`` when the
        requested dtype is not acceptable for ``input.dtype``.
    """
    dtype = self._output_dtype(input.dtype)
    if dtype == self.dtype:
        # Don't build another instance
        op = self
    else:
        # Carry every constructor setting over to the replacement Op.
        # Leaving out no_zeros_in_input would silently reset it to its
        # default of False on the Op that actually ends up in the node.
        op = self.__class__(
            axis=self.axis,
            dtype=dtype,
            no_zeros_in_input=self.no_zeros_in_input)
    return CAReduce.make_node(op, input)
def grad(self, inp, grads):
'''
......@@ -1632,19 +1667,58 @@ class MulWithoutZeros(scalar.BinaryScalarOp):
mul_without_zeros = MulWithoutZeros(scalar.upcast_out, name = 'mul_without_zeros')
class ProdWithoutZeros(CAReduce):
def __init__(self, axis=None, dtype=None):
    """Create a reduction built on ``mul_without_zeros``.

    :param axis: forwarded to ``CAReduce.__init__``.
    :param dtype: requested output dtype, stored as ``self.dtype``;
        ``None`` means it is inferred later from the input dtype.
    """
    CAReduce.__init__(self, mul_without_zeros, axis)
    self.dtype = dtype
def __eq__(self, other):
    """Equal when ``CAReduce`` considers both equal and the dtypes match."""
    base_equal = CAReduce.__eq__(self, other)
    if not base_equal:
        # Hand back the base result itself, exactly as ``and`` would.
        return base_equal
    return self.dtype == other.dtype
def __hash__(self):
    """Combine the ``CAReduce`` hash with the hash of ``self.dtype``."""
    base_hash = CAReduce.__hash__(self)
    dtype_hash = hash(self.dtype)
    return base_hash ^ dtype_hash
def _output_dtype(self, idtype):
    """Return the dtype the result should have for an input of ``idtype``.

    :param idtype: dtype name of the input.
    :return: the output dtype name.
    :raises TypeError: if ``self.dtype`` was given explicitly and using it
        would lose precision relative to ``idtype``.
    """
    dtype = self.dtype
    if dtype is None:
        # No explicit request: widen small integer types to 64 bits to
        # protect against overflow; every other dtype is kept as is.
        return dict(
            int8='int64',
            int16='int64',
            int32='int64',
            uint8='uint64',
            uint16='uint64',
            uint32='uint64',
        ).get(idtype, idtype)

    if dtype in continuous_dtypes and idtype in discrete_dtypes:
        # Specifying a continuous output for discrete input is OK
        return dtype

    # The conversion has to be considered an upcast: the requested dtype
    # must already be the upcast of (input dtype, requested dtype).
    upcasted_dtype = scalar.upcast(idtype, dtype)
    if dtype != upcasted_dtype:
        # The first two fragments used to join as "input dtype  %s"
        # (two spaces); the message now has a single space there.
        raise TypeError(
            'Cannot build ProdWithoutZeros node with input dtype '
            '%s and output dtype %s, as precision would be lost. '
            'To correct this error, you can either:\n'
            '  - not specify a dtype, or\n'
            '  - use a dtype at least as precise as %s.\n'
            'If you are expecting the precision loss, you can '
            'use tensor.cast(..., dtype="%s"), either on your '
            'input, or on the output of the prod_without_zeros.'
            % (idtype, dtype, upcasted_dtype, dtype))
    return dtype
def make_node(self, input):
    """Build the node for ``input`` using an Op whose dtype is resolved.

    The output dtype is computed from ``input.dtype`` by
    ``_output_dtype``; this matters when ``self.dtype`` is None, because
    the node must come from an Op carrying a concrete dtype.

    :param input: the variable to reduce; its ``dtype`` attribute is read.
    :return: the result of ``CAReduce.make_node`` for the chosen Op.
    """
    resolved = self._output_dtype(input.dtype)
    # Reuse this instance when it already carries the resolved dtype;
    # otherwise create a sibling Op of the same class with that dtype.
    op = (self if resolved == self.dtype
          else self.__class__(axis=self.axis, dtype=resolved))
    return CAReduce.make_node(op, input)
def __str__(self):
if self.axis is None:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论