提交 8d11831d authored 作者: gdesjardins's avatar gdesjardins

Added new Prod op which multiplies the contents of a tensor along a given axis.

Added tensor.prod function to go with it.
上级 03375db6
...@@ -1666,6 +1666,10 @@ def sum(input, axis = None): ...@@ -1666,6 +1666,10 @@ def sum(input, axis = None):
pprint.assign(Sum(), printing.FunctionPrinter('sum')) pprint.assign(Sum(), printing.FunctionPrinter('sum'))
@constructor
def prod(input, axis = None):
"""WRITEME"""
return elemwise.Prod(axis)(input)
@constructor @constructor
def mean(input, axis = None): def mean(input, axis = None):
......
...@@ -933,6 +933,9 @@ class Sum(CAReduce): ...@@ -933,6 +933,9 @@ class Sum(CAReduce):
int8='int32', int8='int32',
int16='int32', int16='int32',
int32='int64', int32='int64',
uint8='uint32',
uint16='uint32',
uint32='uint64',
).get(idtype, idtype) ).get(idtype, idtype)
def grad(self, (x, ), (gz, )): def grad(self, (x, ), (gz, )):
...@@ -958,4 +961,38 @@ class Sum(CAReduce): ...@@ -958,4 +961,38 @@ class Sum(CAReduce):
else: else:
return "Sum{%s}" % ", ".join(map(str, self.axis)) return "Sum{%s}" % ", ".join(map(str, self.axis))
class Prod(CAReduce):
"""
Multiplies all the values of a tensor along the specified axis(es).
Equivalent to CAReduce(scalar.prod, axis = axis), with the
difference that this defines the gradient of prod wrt its tensor
input.
"""
def __init__(self, axis = None):
CAReduce.__init__(self, scalar.mul, axis)
def _output_dtype(self, idtype):
# we want to protect against overflow
return dict(
int8='int64',
int16='int64',
int32='int64',
uint8='uint64',
uint16='uint64',
uint32='uint64',
).get(idtype, idtype)
def grad(self, (x, ), (gz, )):
if x.dtype[0:3] in ('int','uin'):
return [None]
else:
raise NotImplementedError('Will be implemented shortly')
def __str__(self):
if self.axis is None:
return "Prod"
else:
return "Prod{%s}" % ", ".join(map(str, self.axis))
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论