提交 33ade0c3 authored 作者: Frederic Bastien's avatar Frederic Bastien

fix CAReduce compilation with uint* and test it with uint* and complex*.

上级 530f513b
......@@ -1056,6 +1056,8 @@ class CAReduce(Op):
scal_name = 'maximum'
if input.type.dtype in ["float32","float64"]:
identity = "-__builtin_inf()"
elif input.type.dtype.startswith("uint"):
identity = "0"
else:
identity = "NPY_MIN_"+str(input.type.dtype).upper()
if self.scalar_op == scalar.minimum:
......
......@@ -195,7 +195,8 @@ class test_CAReduce(unittest.TestCase):
if tosum is None: tosum = range(len(xsh))
f = copy(linker).accept(Env([x], [e])).make_function()
xv = numpy.asarray(numpy.random.rand(*xsh))
if dtype.startswith('float'):
if not "int" in dtype:
xv = numpy.asarray(xv,dtype=dtype)
else:
xv = numpy.asarray(xv<0.5,dtype=dtype)
......@@ -245,7 +246,8 @@ class test_CAReduce(unittest.TestCase):
raise Exception("Test for CAReduce with scalar_op %s not implemented"%str(scalar_op))
if scalar_op in [maximum,minimum] and numpy_raised:
try:
f(xv)
out = f(xv)
assert out.dtype == dtype
except ValueError:
pass
else:
......@@ -254,7 +256,7 @@ class test_CAReduce(unittest.TestCase):
#numpy.{all,any} return bool type.
if scalar_op in [and_, or_]:
zv = numpy.asarray(zv, dtype=dtype)
self.assertTrue((numpy.abs(f(xv) - zv) < 1e-10).all())
self.assertTrue(numpy.allclose(f(xv), zv))
#test CAReduce.infer_shape
......@@ -268,22 +270,27 @@ class test_CAReduce(unittest.TestCase):
assert all(f(xv)== zv.shape)
def test_perform(self):
self.with_linker(gof.PerformLinker(), add)
self.with_linker(gof.PerformLinker(), mul)
self.with_linker(gof.PerformLinker(), maximum)
self.with_linker(gof.PerformLinker(), minimum)
self.with_linker(gof.PerformLinker(), or_, dtype='int8')
self.with_linker(gof.PerformLinker(), and_, dtype='int8')
self.with_linker(gof.PerformLinker(), xor, dtype='int8')
for dtype in ["floatX", "complex64", "complex128", "int8", "uint8"]:
self.with_linker(gof.PerformLinker(), add, dtype=dtype)
self.with_linker(gof.PerformLinker(), mul, dtype=dtype)
self.with_linker(gof.PerformLinker(), maximum, dtype=dtype)
self.with_linker(gof.PerformLinker(), minimum, dtype=dtype)
for dtype in ["int8", "uint8"]:
self.with_linker(gof.PerformLinker(), or_, dtype=dtype)
self.with_linker(gof.PerformLinker(), and_, dtype=dtype)
self.with_linker(gof.PerformLinker(), xor, dtype=dtype)
def test_c(self):
self.with_linker(gof.CLinker(), add)
self.with_linker(gof.CLinker(), mul)
self.with_linker(gof.CLinker(), maximum)
self.with_linker(gof.CLinker(), minimum)
self.with_linker(gof.CLinker(), or_, dtype='int8')
self.with_linker(gof.CLinker(), and_, dtype='int8')
self.with_linker(gof.CLinker(), xor, dtype='int8')
for dtype in ["floatX", "complex64", "complex128", "int8", "uint8"]:
self.with_linker(gof.CLinker(), add, dtype=dtype)
self.with_linker(gof.CLinker(), mul, dtype=dtype)
for dtype in ["floatX", "int8", "uint8"]:
self.with_linker(gof.CLinker(), minimum, dtype=dtype)
self.with_linker(gof.CLinker(), maximum, dtype=dtype)
for dtype in ["int8", "uint8"]:
self.with_linker(gof.CLinker(), or_, dtype=dtype)
self.with_linker(gof.CLinker(), and_, dtype=dtype)
self.with_linker(gof.CLinker(), xor, dtype=dtype)
class test_Prod(unittest.TestCase):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论