提交 4befd3b9 authored 作者: Frederic Bastien's avatar Frederic Bastien

Test CAReduce with 'or', 'and' and 'xor' logical operation.

上级 90c06f14
...@@ -172,7 +172,7 @@ class test_CAReduce(unittest.TestCase): ...@@ -172,7 +172,7 @@ class test_CAReduce(unittest.TestCase):
def setUp(self): def setUp(self):
unittest_tools.seed_rng() unittest_tools.seed_rng()
def with_linker(self, linker, scalar_op = add): def with_linker(self, linker, scalar_op = add, dtype="floatX"):
for xsh, tosum in [((5, 6), None), for xsh, tosum in [((5, 6), None),
((5, 6), (0, 1)), ((5, 6), (0, 1)),
((5, 6), (0, )), ((5, 6), (0, )),
...@@ -188,11 +188,17 @@ class test_CAReduce(unittest.TestCase): ...@@ -188,11 +188,17 @@ class test_CAReduce(unittest.TestCase):
((5, 0), ()), ((5, 0), ()),
((), None), ((), None),
((), ())]: ((), ())]:
x = TensorType('float64', [(entry == 1) for entry in xsh])('x') if dtype == "floatX":
dtype = theano.config.floatX
x = TensorType(dtype, [(entry == 1) for entry in xsh])('x')
e = CAReduce(scalar_op, axis = tosum)(x) e = CAReduce(scalar_op, axis = tosum)(x)
if tosum is None: tosum = range(len(xsh)) if tosum is None: tosum = range(len(xsh))
f = copy(linker).accept(Env([x], [e])).make_function() f = copy(linker).accept(Env([x], [e])).make_function()
xv = numpy.asarray(numpy.random.rand(*xsh)) xv = numpy.asarray(numpy.random.rand(*xsh))
if dtype.startswith('float'):
xv = numpy.asarray(xv,dtype=dtype)
else:
xv = numpy.asarray(xv<0.5,dtype=dtype)
zv = xv zv = xv
numpy_raised = False numpy_raised = False
if len(tosum)>1 and any([a<0 for a in tosum]): if len(tosum)>1 and any([a<0 for a in tosum]):
...@@ -224,10 +230,17 @@ class test_CAReduce(unittest.TestCase): ...@@ -224,10 +230,17 @@ class test_CAReduce(unittest.TestCase):
numpy_raised=True numpy_raised=True
elif scalar_op == or_: elif scalar_op == or_:
for axis in reversed(sorted(tosum)): for axis in reversed(sorted(tosum)):
zv = numpy.any(zv, axis) zv = numpy.bitwise_or.reduce(zv, axis)
elif scalar_op == and_: elif scalar_op == and_:
for axis in reversed(sorted(tosum)): for axis in reversed(sorted(tosum)):
zv = numpy.all(zv, axis) zv = numpy.bitwise_and.reduce(zv, axis)
elif scalar_op == xor:
# There is no identity value for the xor function
# So we can't support shape of dimensions 0.
if numpy.prod(zv.shape)==0:
continue
for axis in reversed(sorted(tosum)):
zv = numpy.bitwise_xor.reduce(zv, axis)
else: else:
raise Exception("Test for CAReduce with scalar_op %s not implemented"%str(scalar_op)) raise Exception("Test for CAReduce with scalar_op %s not implemented"%str(scalar_op))
if scalar_op in [maximum,minimum] and numpy_raised: if scalar_op in [maximum,minimum] and numpy_raised:
...@@ -238,13 +251,16 @@ class test_CAReduce(unittest.TestCase): ...@@ -238,13 +251,16 @@ class test_CAReduce(unittest.TestCase):
else: else:
self.fail() self.fail()
else: else:
#numpy.{all,any} return bool type.
if scalar_op in [and_, or_]:
zv = numpy.asarray(zv, dtype=dtype)
self.assertTrue((numpy.abs(f(xv) - zv) < 1e-10).all()) self.assertTrue((numpy.abs(f(xv) - zv) < 1e-10).all())
#test CAReduce.infer_shape #test CAReduce.infer_shape
#the Shape op don't implement c_code! #the Shape op don't implement c_code!
if isinstance(linker,gof.PerformLinker): if isinstance(linker,gof.PerformLinker):
x = TensorType('float64', [(entry == 1) for entry in xsh])('x') x = TensorType(dtype, [(entry == 1) for entry in xsh])('x')
e = CAReduce(scalar_op, axis = tosum)(x) e = CAReduce(scalar_op, axis = tosum)(x)
if tosum is None: tosum = range(len(xsh)) if tosum is None: tosum = range(len(xsh))
f = copy(linker).accept(Env([x], [e.shape])).make_function() f = copy(linker).accept(Env([x], [e.shape])).make_function()
...@@ -256,20 +272,18 @@ class test_CAReduce(unittest.TestCase): ...@@ -256,20 +272,18 @@ class test_CAReduce(unittest.TestCase):
self.with_linker(gof.PerformLinker(), mul) self.with_linker(gof.PerformLinker(), mul)
self.with_linker(gof.PerformLinker(), maximum) self.with_linker(gof.PerformLinker(), maximum)
self.with_linker(gof.PerformLinker(), minimum) self.with_linker(gof.PerformLinker(), minimum)
#need other dtype then real self.with_linker(gof.PerformLinker(), or_, dtype='int8')
#self.with_linker(gof.PerformLinker(), or_) self.with_linker(gof.PerformLinker(), and_, dtype='int8')
#self.with_linker(gof.PerformLinker(), and_) self.with_linker(gof.PerformLinker(), xor, dtype='int8')
def test_c(self): def test_c(self):
self.with_linker(gof.CLinker(), add) self.with_linker(gof.CLinker(), add)
self.with_linker(gof.CLinker(), mul) self.with_linker(gof.CLinker(), mul)
self.with_linker(gof.CLinker(), maximum) self.with_linker(gof.CLinker(), maximum)
self.with_linker(gof.CLinker(), minimum) self.with_linker(gof.CLinker(), minimum)
self.with_linker(gof.CLinker(), or_, dtype='int8')
#need other dtype then real self.with_linker(gof.CLinker(), and_, dtype='int8')
#no c_code for or_, and_ self.with_linker(gof.CLinker(), xor, dtype='int8')
#self.with_linker(gof.CLinker(), or_)
#self.with_linker(gof.CLinker(), and_)
class test_Prod(unittest.TestCase): class test_Prod(unittest.TestCase):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论