提交 17912cb5 authored 作者: Olivier Delalleau's avatar Olivier Delalleau

Merged

......@@ -851,9 +851,9 @@ class BinaryBitOp(BinaryScalarOp):
return [None, None]
class OR(BinaryBitOp):
identity = False
identity = 0
commutative = True
associative = False
associative = True
def impl(self, x, y):
return x | y
def c_code(self, node, name, (x, y), (z, ), sub):
......@@ -861,9 +861,9 @@ class OR(BinaryBitOp):
or_ = OR()
class XOR(BinaryBitOp):
identity = False
identity = 0
commutative = True
associative = False
associative = True
def impl(self, x, y):
return x ^ y
def c_code(self, node, name, (x, y), (z, ), sub):
......@@ -871,9 +871,9 @@ class XOR(BinaryBitOp):
xor = XOR()
class AND(BinaryBitOp):
identity = False
identity = 1
commutative = True
associative = False
associative = True
def impl(self, x, y):
return x & y
def c_code(self, node, name, (x, y), (z, ), sub):
......@@ -881,7 +881,6 @@ class AND(BinaryBitOp):
and_ = AND()
class Invert(UnaryBitOp):
identity = False
def impl(self, x):
return ~x
def c_code(self, node, name, (x,), (z, ), sub):
......
......@@ -172,7 +172,7 @@ class test_CAReduce(unittest.TestCase):
def setUp(self):
unittest_tools.seed_rng()
def with_linker(self, linker, scalar_op = add):
def with_linker(self, linker, scalar_op = add, dtype="floatX"):
for xsh, tosum in [((5, 6), None),
((5, 6), (0, 1)),
((5, 6), (0, )),
......@@ -188,11 +188,17 @@ class test_CAReduce(unittest.TestCase):
((5, 0), ()),
((), None),
((), ())]:
x = TensorType('float64', [(entry == 1) for entry in xsh])('x')
if dtype == "floatX":
dtype = theano.config.floatX
x = TensorType(dtype, [(entry == 1) for entry in xsh])('x')
e = CAReduce(scalar_op, axis = tosum)(x)
if tosum is None: tosum = range(len(xsh))
f = copy(linker).accept(Env([x], [e])).make_function()
xv = numpy.asarray(numpy.random.rand(*xsh))
if dtype.startswith('float'):
xv = numpy.asarray(xv,dtype=dtype)
else:
xv = numpy.asarray(xv<0.5,dtype=dtype)
zv = xv
numpy_raised = False
if len(tosum)>1 and any([a<0 for a in tosum]):
......@@ -224,10 +230,17 @@ class test_CAReduce(unittest.TestCase):
numpy_raised=True
elif scalar_op == or_:
for axis in reversed(sorted(tosum)):
zv = numpy.any(zv, axis)
zv = numpy.bitwise_or.reduce(zv, axis)
elif scalar_op == and_:
for axis in reversed(sorted(tosum)):
zv = numpy.all(zv, axis)
zv = numpy.bitwise_and.reduce(zv, axis)
elif scalar_op == xor:
# There is no identity value for the xor function
# So we can't support shape of dimensions 0.
if numpy.prod(zv.shape)==0:
continue
for axis in reversed(sorted(tosum)):
zv = numpy.bitwise_xor.reduce(zv, axis)
else:
raise Exception("Test for CAReduce with scalar_op %s not implemented"%str(scalar_op))
if scalar_op in [maximum,minimum] and numpy_raised:
......@@ -238,13 +251,16 @@ class test_CAReduce(unittest.TestCase):
else:
self.fail()
else:
#numpy.{all,any} return bool type.
if scalar_op in [and_, or_]:
zv = numpy.asarray(zv, dtype=dtype)
self.assertTrue((numpy.abs(f(xv) - zv) < 1e-10).all())
#test CAReduce.infer_shape
#the Shape op don't implement c_code!
if isinstance(linker,gof.PerformLinker):
x = TensorType('float64', [(entry == 1) for entry in xsh])('x')
x = TensorType(dtype, [(entry == 1) for entry in xsh])('x')
e = CAReduce(scalar_op, axis = tosum)(x)
if tosum is None: tosum = range(len(xsh))
f = copy(linker).accept(Env([x], [e.shape])).make_function()
......@@ -256,20 +272,18 @@ class test_CAReduce(unittest.TestCase):
self.with_linker(gof.PerformLinker(), mul)
self.with_linker(gof.PerformLinker(), maximum)
self.with_linker(gof.PerformLinker(), minimum)
#need other dtype then real
#self.with_linker(gof.PerformLinker(), or_)
#self.with_linker(gof.PerformLinker(), and_)
self.with_linker(gof.PerformLinker(), or_, dtype='int8')
self.with_linker(gof.PerformLinker(), and_, dtype='int8')
self.with_linker(gof.PerformLinker(), xor, dtype='int8')
def test_c(self):
self.with_linker(gof.CLinker(), add)
self.with_linker(gof.CLinker(), mul)
self.with_linker(gof.CLinker(), maximum)
self.with_linker(gof.CLinker(), minimum)
#need other dtype then real
#no c_code for or_, and_
#self.with_linker(gof.CLinker(), or_)
#self.with_linker(gof.CLinker(), and_)
self.with_linker(gof.CLinker(), or_, dtype='int8')
self.with_linker(gof.CLinker(), and_, dtype='int8')
self.with_linker(gof.CLinker(), xor, dtype='int8')
class test_Prod(unittest.TestCase):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论