提交 547255d5 authored 作者: Amjad Almahairi's avatar Amjad Almahairi

fix tests

上级 c44b8b5c
......@@ -3228,6 +3228,63 @@ class Test_local_useless_elemwise_comparison(unittest.TestCase):
f = theano.function([x], T.ge(x, x), mode=mode)
assert_eqs_const(f.maker.fgraph.toposort(), 1)
def assert_identity(self, f):
topo = f.maker.fgraph.toposort()
assert topo[0].op == deep_copy_op
x_val = -128
assert f(x_val) == x_val
x_val = -1
assert f(x_val) == x_val
x_val = 0
assert f(x_val) == x_val
x_val = 1
assert f(x_val) == x_val
x_val = 127
assert f(x_val) == x_val
x_val = numpy.random.randint(255)-128
assert f(x_val) == x_val
def test_and(self):
mode = theano.compile.get_default_mode().including('canonicalize')
x = T.scalar('x', dtype='int8')
f = theano.function([x], T.and_(x, 0), mode=mode)
assert_eqs_const(f.maker.fgraph.toposort(), 0)
f = theano.function([x], T.and_(0, x), mode=mode)
assert_eqs_const(f.maker.fgraph.toposort(), 0)
f = theano.function([x], T.and_(x, 1), mode=mode)
self.assert_identity(f)
f = theano.function([x], T.and_(1, x), mode=mode)
self.assert_identity(f)
def test_or(self):
mode = theano.compile.get_default_mode().including('canonicalize')
x = T.scalar('x', dtype='int8')
f = theano.function([x], T.or_(x, 1), mode=mode)
assert_eqs_const(f.maker.fgraph.toposort(), 1)
f = theano.function([x], T.or_(1, x), mode=mode)
assert_eqs_const(f.maker.fgraph.toposort(), 1)
f = theano.function([x], T.or_(x, 0), mode=mode)
self.assert_identity(f)
f = theano.function([x], T.or_(0, x), mode=mode)
self.assert_identity(f)
def test_xor(self):
mode = theano.compile.get_default_mode().including('canonicalize')
x = T.scalar('x', dtype='int8')
f = theano.function([x], T.xor(x, x), mode=mode)
assert_eqs_const(f.maker.fgraph.toposort(), 0)
class Test_local_useless_alloc(unittest.TestCase):
def setUp(self):
......@@ -4001,58 +4058,6 @@ class T_useless_elemwise(unittest.TestCase):
assert len(topo) == 1
assert topo[0].op == deep_copy_op
def assert_identity(self, f):
topo = f.maker.fgraph.toposort()
assert topo[0].op == deep_copy_op
x_val = -128
assert f(x_val) == x_val
x_val = -1
assert f(x_val) == x_val
x_val = 0
assert f(x_val) == x_val
x_val = 1
assert f(x_val) == x_val
x_val = 127
assert f(x_val) == x_val
x_val = numpy.random.randint(255)-128
assert f(x_val) == x_val
def test_and(self):
x = T.scalar('x', dtype='int8')
f = theano.function([x], T.and_(x, 0), mode=self.mode)
assert_eqs_const(f.maker.fgraph.toposort(), 0)
f = theano.function([x], T.and_(0, x), mode=self.mode)
assert_eqs_const(f.maker.fgraph.toposort(), 0)
f = theano.function([x], T.and_(x, 1), mode=self.mode)
self.assert_identity(f)
f = theano.function([x], T.and_(1, x), mode=self.mode)
self.assert_identity(f)
def test_or(self):
x = T.scalar('x', dtype='int8')
f = theano.function([x], T.or_(x, 1), mode=self.mode)
assert_eqs_const(f.maker.fgraph.toposort(), 1)
f = theano.function([x], T.or_(1, x), mode=self.mode)
assert_eqs_const(f.maker.fgraph.toposort(), 1)
f = theano.function([x], T.or_(x, 0), mode=self.mode)
self.assert_identity(f)
f = theano.function([x], T.or_(0, x), mode=self.mode)
self.assert_identity(f)
def test_xor(self):
x = T.scalar('x', dtype='int8')
f = theano.function([x], T.xor(x, x), mode=self.mode)
assert_eqs_const(f.maker.fgraph.toposort(), 0)
class T_cast_cast(unittest.TestCase):
def setUp(self):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论