提交 1a067453 authored 作者: Amjad Almahairi's avatar Amjad Almahairi

add more tests

上级 0a3e8fce
...@@ -4539,6 +4539,43 @@ class test_local_remove_switch_const_cond(unittest.TestCase): ...@@ -4539,6 +4539,43 @@ class test_local_remove_switch_const_cond(unittest.TestCase):
vy = numpy.array([[7, 8, 9], [10, 11, 12]], dtype=dtype2) vy = numpy.array([[7, 8, 9], [10, 11, 12]], dtype=dtype2)
assert numpy.all(f(vx, vy) == vx) assert numpy.all(f(vx, vy) == vx)
def test_left_is_right(self):
for dtype1 in ['int32', 'int64']:
x = theano.tensor.matrix('x', dtype=dtype1)
z1 = theano.tensor.switch(1, x, x)
z0 = theano.tensor.switch(0, x, x)
f1 = theano.function([x], z1, mode=self.mode)
f0 = theano.function([x], z0, mode=self.mode)
topo = f1.maker.fgraph.toposort()
assert len(topo) == 1
assert topo[0].op == deep_copy_op
topo = f0.maker.fgraph.toposort()
assert len(topo) == 1
assert topo[0].op == deep_copy_op
vx = numpy.array([[1, 2, 3], [4, 5, 6]], dtype=dtype1)
assert numpy.all(f1(vx) == vx)
assert numpy.all(f0(vx) == vx)
def test_shape_le_0(self):
for dtype1 in ['float32', 'float64']:
x = theano.tensor.matrix('x', dtype=dtype1)
z0 = theano.tensor.switch(theano.tensor.le(x.shape[0], 0), 0, x.shape[0])
f0 = theano.function([x], z0, mode=self.mode)
assert isinstance(f0.maker.fgraph.toposort()[0].op, Shape_i)
z1 = theano.tensor.switch(theano.tensor.le(x.shape[1], 0), 0, x.shape[1])
f1 = theano.function([x], z1, mode=self.mode)
assert isinstance(f1.maker.fgraph.toposort()[0].op, Shape_i)
vx = numpy.random.randn(0,5).astype(dtype1)
assert f0(vx) == 0
assert f1(vx) == 5
def test_broadcast1(self): def test_broadcast1(self):
# test switch(cst, matrix, row) # test switch(cst, matrix, row)
x = theano.tensor.matrix('x', dtype='int32') x = theano.tensor.matrix('x', dtype='int32')
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论