提交 e8b091ba authored 作者: Frederic Bastien's avatar Frederic Bastien

added a disabled version of Mean op. That is to make the graph of compiled fct…

added a disabled version of Mean op. That is to make the graph of compiled fct easier to read instead of having multiple op for this fct. Need to make the gradient work before enabling by default.
上级 4f6c4303
......@@ -1738,8 +1738,36 @@ def prod(input, axis = None):
"""WRITEME"""
return elemwise.Prod(axis)(input)
class Mean(elemwise.CAReduce):
def __init__(self, axis = None):
elemwise.CAReduce.__init__(self, scal.add, axis)
def __str__(self):
if self.axis is not None:
return "Mean{%s}" % (", ".join(str(x) for x in self.axis))
else:
return "Mean"
def _output_dtype(self, idtype):
# we want to protect against overflow
return 'float64'
def perform(self, node, (input, ), (output, )):
ret = elemwise.CAReduce.perform(self,node,(input,),(output,))
output[0]=numpy.asarray(output[0]/len(input))
def c_code(self, node, name, inames, onames, sub):
ret = elemwise.CAReduce.c_code(self, node, name, inames, onames, sub)
return ret + """
*((double *)PyArray_DATA(%s)) /= PyArray_SIZE(%s);
"""%(onames[0],inames[0])
#TODO: implement the grad. When done and tested, you can make this the default version.
# def grad(self, (x,), (gout,)):
# import pdb;pdb.set_trace()
# return grad(mean(x, self.axis, op=False),[x])
@constructor
def mean(input, axis = None):
def mean(input, axis = None, op = False):
"""Compute the mean value along the given axis of a tensor `input`
:param axis: compute the mean along this axis of the tensor. None means all axes (like
......@@ -1747,6 +1775,9 @@ def mean(input, axis = None):
:type axis: None or int or (list of int) (see `Sum`)
"""
if op:
return Mean(axis)(input)
if str(input.dtype).startswith('int'):
# we need to cast eventually anyway, and this helps
# to prevents overflow
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论