提交 10855815 authored 作者: Dustin Webb's avatar Dustin Webb 提交者: Amjad Almahairi

Added more test cases for logical and and or.

上级 83d3075f
......@@ -3932,29 +3932,47 @@ class T_useless_elemwise(unittest.TestCase):
assert isinstance(elem.inputs[0], T.TensorConstant)
assert T.extract_constant(elem.inputs[0]) == val
def assert_identity(self, f):
topo = f.maker.fgraph.toposort()
assert topo[0].op == deep_copy_op
x_val = numpy.random.randint(256)
assert f(x_val) == x_val
def test_and(self):
x = T.scalar('x', dtype='int64')
func = theano.function([x], T.and_(x, 0), mode=self.mode)
self.assert_eqs_const(func.maker.fgraph.toposort(), 0)
f = theano.function([x], T.and_(x, 0), mode=self.mode)
self.assert_eqs_const(f.maker.fgraph.toposort(), 0)
f = theano.function([x], T.and_(0, x), mode=self.mode)
self.assert_eqs_const(f.maker.fgraph.toposort(), 0)
func = theano.function([x], T.and_(0, x), mode=self.mode)
self.assert_eqs_const(func.maker.fgraph.toposort(), 0)
f = theano.function([x], T.and_(x, 1), mode=self.mode)
self.assert_identity(f)
f = theano.function([x], T.and_(1, x), mode=self.mode)
self.assert_identity(f)
def test_or(self):
x = T.scalar('x', dtype='int64')
func = theano.function([x], T.or_(x, 1), mode=self.mode)
self.assert_eqs_const(func.maker.fgraph.toposort(), 1)
f = theano.function([x], T.or_(x, 1), mode=self.mode)
self.assert_eqs_const(f.maker.fgraph.toposort(), 1)
f = theano.function([x], T.or_(1, x), mode=self.mode)
self.assert_eqs_const(f.maker.fgraph.toposort(), 1)
f = theano.function([x], T.or_(x, 0), mode=self.mode)
self.assert_identity(f)
func = theano.function([x], T.or_(1, x), mode=self.mode)
self.assert_eqs_const(func.maker.fgraph.toposort(), 1)
f = theano.function([x], T.or_(0, x), mode=self.mode)
self.assert_identity(f)
def test_xor(self):
x = T.scalar('x', dtype='int64')
func = theano.function([x], T.xor(x, x), mode=self.mode)
self.assert_eqs_const(func.maker.fgraph.toposort(), 0)
f = theano.function([x], T.xor(x, x), mode=self.mode)
self.assert_eqs_const(f.maker.fgraph.toposort(), 0)
class T_cast_cast(unittest.TestCase):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论