提交 01092d49 authored 作者: James Bergstra's avatar James Bergstra

one more transpose test case

上级 4e64abe0
...@@ -5293,9 +5293,11 @@ def test_transpose(): ...@@ -5293,9 +5293,11 @@ def test_transpose():
x3.transpose(), x3.transpose(),
x2.transpose(0, 1), x2.transpose(0, 1),
x3.transpose((0, 2, 1)), x3.transpose((0, 2, 1)),
tensor.transpose(x2, [0, 1]),
tensor.transpose(x3, [0, 2, 1]),
]) ])
t1, t2, t3, t1b, t2b, t3b, t2c, t3c = f(x1v, x2v, x3v) t1, t2, t3, t1b, t2b, t3b, t2c, t3c, t2d, t3d = f(x1v, x2v, x3v)
assert t1.shape == numpy.transpose(x1v).shape assert t1.shape == numpy.transpose(x1v).shape
assert t2.shape == numpy.transpose(x2v).shape assert t2.shape == numpy.transpose(x2v).shape
assert t3.shape == numpy.transpose(x3v).shape assert t3.shape == numpy.transpose(x3v).shape
...@@ -5309,6 +5311,10 @@ def test_transpose(): ...@@ -5309,6 +5311,10 @@ def test_transpose():
assert t3c.shape == (2, 4, 3) assert t3c.shape == (2, 4, 3)
assert numpy.all(t2c == x2v.transpose([0, 1])) assert numpy.all(t2c == x2v.transpose([0, 1]))
assert numpy.all(t3c == x3v.transpose([0, 2, 1])) assert numpy.all(t3c == x3v.transpose([0, 2, 1]))
assert t2d.shape == (2, 12)
assert t3d.shape == (2, 4, 3)
assert numpy.all(t2d == numpy.transpose(x2v, [0, 1]))
assert numpy.all(t3d == numpy.transpose(x3v, [0, 2, 1]))
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论