提交 06477d5a authored 作者: Frederic Bastien's avatar Frederic Bastien

make CAReduce support scalar.minimum.

上级 dcbb5804
...@@ -998,11 +998,19 @@ class CAReduce(Op): ...@@ -998,11 +998,19 @@ class CAReduce(Op):
if hasattr(self.scalar_op,'identity'): if hasattr(self.scalar_op,'identity'):
identity = self.scalar_op.identity identity = self.scalar_op.identity
elif self.scalar_op == scalar.maximum: elif self.scalar_op in [scalar.maximum, scalar.minimum]:
if input.type.dtype in ["float32","float64"]: if self.scalar_op == scalar.maximum:
identity = "-__builtin_inf()" if input.type.dtype in ["float32","float64"]:
else: identity = "-__builtin_inf()"
identity = "NPY_MIN_"+str(input.type.dtype).upper() scal_name = 'maximum'
else:
identity = "NPY_MIN_"+str(input.type.dtype).upper()
scal_name = 'minimum'
if self.scalar_op == scalar.minimum:
if input.type.dtype in ["float32","float64"]:
identity = "__builtin_inf()"
else:
identity = "NPY_MAX_"+str(input.type.dtype).upper()
fail = sub["fail"] fail = sub["fail"]
pattern=[0]*len(node.inputs[0].broadcastable) pattern=[0]*len(node.inputs[0].broadcastable)
axis = self.axis axis = self.axis
...@@ -1014,7 +1022,7 @@ class CAReduce(Op): ...@@ -1014,7 +1022,7 @@ class CAReduce(Op):
alloc += """ alloc += """
for(int i=0;i<%(iname)s->nd;i++){ for(int i=0;i<%(iname)s->nd;i++){
if(PyArray_DIMS(%(iname)s)[i]==0 && tosum[i]){ if(PyArray_DIMS(%(iname)s)[i]==0 && tosum[i]){
PyErr_Format(PyExc_ValueError, "Input of CAReduce{maximum} has zero-size on axis %%d",i); PyErr_Format(PyExc_ValueError, "Input of CAReduce{%(scal_name)s} has zero-size on axis %%d",i);
%(fail)s; %(fail)s;
} }
} }
......
...@@ -195,6 +195,12 @@ class test_CAReduce(unittest.TestCase): ...@@ -195,6 +195,12 @@ class test_CAReduce(unittest.TestCase):
zv = numpy.maximum.reduce(zv, axis) zv = numpy.maximum.reduce(zv, axis)
except ValueError: except ValueError:
numpy_raised=True numpy_raised=True
elif scalar_op == minimum:
try:
for axis in reversed(sorted(tosum)):
zv = numpy.minimum.reduce(zv, axis)
except ValueError:
numpy_raised=True
elif scalar_op == or_: elif scalar_op == or_:
for axis in reversed(sorted(tosum)): for axis in reversed(sorted(tosum)):
zv = numpy.any(zv, axis) zv = numpy.any(zv, axis)
...@@ -203,7 +209,7 @@ class test_CAReduce(unittest.TestCase): ...@@ -203,7 +209,7 @@ class test_CAReduce(unittest.TestCase):
zv = numpy.all(zv, axis) zv = numpy.all(zv, axis)
else: else:
raise Exception("Test for CAReduce with scalar_op %s not implemented"%str(scalar_op)) raise Exception("Test for CAReduce with scalar_op %s not implemented"%str(scalar_op))
if scalar_op == maximum and numpy_raised: if scalar_op in [maximum,minimum] and numpy_raised:
try: try:
f(xv) f(xv)
except ValueError: except ValueError:
...@@ -221,13 +227,14 @@ class test_CAReduce(unittest.TestCase): ...@@ -221,13 +227,14 @@ class test_CAReduce(unittest.TestCase):
e = CAReduce(scalar_op, axis = tosum)(x) e = CAReduce(scalar_op, axis = tosum)(x)
if tosum is None: tosum = range(len(xsh)) if tosum is None: tosum = range(len(xsh))
f = copy(linker).accept(Env([x], [e.shape])).make_function() f = copy(linker).accept(Env([x], [e.shape])).make_function()
if not(scalar_op == maximum and ((xsh==() or numpy.prod(xsh)==0))): if not(scalar_op in [maximum,minimum] and ((xsh==() or numpy.prod(xsh)==0))):
assert all(f(xv)== zv.shape) assert all(f(xv)== zv.shape)
def test_perform(self): def test_perform(self):
self.with_linker(gof.PerformLinker(), add) self.with_linker(gof.PerformLinker(), add)
self.with_linker(gof.PerformLinker(), mul) self.with_linker(gof.PerformLinker(), mul)
self.with_linker(gof.PerformLinker(), maximum) self.with_linker(gof.PerformLinker(), maximum)
self.with_linker(gof.PerformLinker(), minimum)
#need other dtype then real #need other dtype then real
#self.with_linker(gof.PerformLinker(), or_) #self.with_linker(gof.PerformLinker(), or_)
#self.with_linker(gof.PerformLinker(), and_) #self.with_linker(gof.PerformLinker(), and_)
...@@ -236,6 +243,7 @@ class test_CAReduce(unittest.TestCase): ...@@ -236,6 +243,7 @@ class test_CAReduce(unittest.TestCase):
self.with_linker(gof.CLinker(), add) self.with_linker(gof.CLinker(), add)
self.with_linker(gof.CLinker(), mul) self.with_linker(gof.CLinker(), mul)
self.with_linker(gof.CLinker(), maximum) self.with_linker(gof.CLinker(), maximum)
self.with_linker(gof.CLinker(), minimum)
#need other dtype then real #need other dtype then real
#no c_code for or_, and_ #no c_code for or_, and_
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论