提交 9e4c0e48 authored 作者: Michael Osthege's avatar Michael Osthege 提交者: Thomas Wiecki

Simplify asserts

上级 4116a35d
......@@ -1667,18 +1667,15 @@ class TestJoinAndSplit:
a = self.shared(a_val, shape=(None, None, 1))
b = self.shared(b_val, shape=(1, None, 1))
c = self.join_op(1, a, b)
assert c.type.shape[0] == 1 and c.type.shape[2] == 1
assert c.type.shape[1] != 1
assert c.type.shape == (1, None, 1)
# Opt can remplace the int by an PyTensor constant
c = self.join_op(constant(1), a, b)
assert c.type.shape[0] == 1 and c.type.shape[2] == 1
assert c.type.shape[1] != 1
assert c.type.shape == (1, None, 1)
# In case futur opt insert other useless stuff
c = self.join_op(cast(constant(1), dtype="int32"), a, b)
assert c.type.shape[0] == 1 and c.type.shape[2] == 1
assert c.type.shape[1] != 1
assert c.type.shape == (1, None, 1)
f = function([], c, mode=self.mode)
topo = f.maker.fgraph.toposort()
......@@ -1783,15 +1780,21 @@ class TestJoinAndSplit:
c = TensorType(dtype=self.floatX, shape=(1, None, None, None, None, None))()
d = TensorType(dtype=self.floatX, shape=(1, None, 1, 1, None, 1))()
e = TensorType(dtype=self.floatX, shape=(1, None, 1, None, None, 1))()
f = self.join_op(0, a, b, c, d, e)
fb = tuple(s == 1 for s in f.type.shape)
assert not fb[0] and fb[1] and fb[2] and fb[3] and not fb[4] and fb[5]
assert f.type.shape == (5, 1, 1, 1, None, 1)
assert fb == (False, True, True, True, False, True)
g = self.join_op(1, a, b, c, d, e)
gb = tuple(s == 1 for s in g.type.shape)
assert gb[0] and not gb[1] and gb[2] and gb[3] and not gb[4] and gb[5]
assert g.type.shape == (1, None, 1, 1, None, 1)
assert gb == (True, False, True, True, False, True)
h = self.join_op(4, a, b, c, d, e)
hb = tuple(s == 1 for s in h.type.shape)
assert hb[0] and hb[1] and hb[2] and hb[3] and not hb[4] and hb[5]
assert h.type.shape == (1, 1, 1, 1, None, 1)
assert hb == (True, True, True, True, False, True)
f = function([a, b, c, d, e], f, mode=self.mode)
topo = f.maker.fgraph.toposort()
......@@ -1903,7 +1906,7 @@ class TestJoinAndSplit:
rng = np.random.default_rng(seed=utt.fetch_seed())
v = self.shared(rng.random(4).astype(self.floatX))
m = self.shared(rng.random((4, 4)).astype(self.floatX))
with pytest.raises(TypeError):
with pytest.raises(TypeError, match="same number of dimensions"):
self.join_op(0, v, m)
def test_split_0elem(self):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论