提交 06477d5a authored 作者: Frederic Bastien's avatar Frederic Bastien

make CAReduce support scalar.minimum.

上级 dcbb5804
......@@ -998,11 +998,19 @@ class CAReduce(Op):
if hasattr(self.scalar_op,'identity'):
identity = self.scalar_op.identity
elif self.scalar_op == scalar.maximum:
if input.type.dtype in ["float32","float64"]:
identity = "-__builtin_inf()"
else:
identity = "NPY_MIN_"+str(input.type.dtype).upper()
elif self.scalar_op in [scalar.maximum, scalar.minimum]:
if self.scalar_op == scalar.maximum:
if input.type.dtype in ["float32","float64"]:
identity = "-__builtin_inf()"
scal_name = 'maximum'
else:
identity = "NPY_MIN_"+str(input.type.dtype).upper()
scal_name = 'minimum'
if self.scalar_op == scalar.minimum:
if input.type.dtype in ["float32","float64"]:
identity = "__builtin_inf()"
else:
identity = "NPY_MAX_"+str(input.type.dtype).upper()
fail = sub["fail"]
pattern=[0]*len(node.inputs[0].broadcastable)
axis = self.axis
......@@ -1014,7 +1022,7 @@ class CAReduce(Op):
alloc += """
for(int i=0;i<%(iname)s->nd;i++){
if(PyArray_DIMS(%(iname)s)[i]==0 && tosum[i]){
PyErr_Format(PyExc_ValueError, "Input of CAReduce{maximum} has zero-size on axis %%d",i);
PyErr_Format(PyExc_ValueError, "Input of CAReduce{%(scal_name)s} has zero-size on axis %%d",i);
%(fail)s;
}
}
......
......@@ -195,6 +195,12 @@ class test_CAReduce(unittest.TestCase):
zv = numpy.maximum.reduce(zv, axis)
except ValueError:
numpy_raised=True
elif scalar_op == minimum:
try:
for axis in reversed(sorted(tosum)):
zv = numpy.minimum.reduce(zv, axis)
except ValueError:
numpy_raised=True
elif scalar_op == or_:
for axis in reversed(sorted(tosum)):
zv = numpy.any(zv, axis)
......@@ -203,7 +209,7 @@ class test_CAReduce(unittest.TestCase):
zv = numpy.all(zv, axis)
else:
raise Exception("Test for CAReduce with scalar_op %s not implemented"%str(scalar_op))
if scalar_op == maximum and numpy_raised:
if scalar_op in [maximum,minimum] and numpy_raised:
try:
f(xv)
except ValueError:
......@@ -221,13 +227,14 @@ class test_CAReduce(unittest.TestCase):
e = CAReduce(scalar_op, axis = tosum)(x)
if tosum is None: tosum = range(len(xsh))
f = copy(linker).accept(Env([x], [e.shape])).make_function()
if not(scalar_op == maximum and ((xsh==() or numpy.prod(xsh)==0))):
if not(scalar_op in [maximum,minimum] and ((xsh==() or numpy.prod(xsh)==0))):
assert all(f(xv)== zv.shape)
def test_perform(self):
self.with_linker(gof.PerformLinker(), add)
self.with_linker(gof.PerformLinker(), mul)
self.with_linker(gof.PerformLinker(), maximum)
self.with_linker(gof.PerformLinker(), minimum)
#need other dtype then real
#self.with_linker(gof.PerformLinker(), or_)
#self.with_linker(gof.PerformLinker(), and_)
......@@ -236,6 +243,7 @@ class test_CAReduce(unittest.TestCase):
self.with_linker(gof.CLinker(), add)
self.with_linker(gof.CLinker(), mul)
self.with_linker(gof.CLinker(), maximum)
self.with_linker(gof.CLinker(), minimum)
#need other dtype then real
#no c_code for or_, and_
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论