提交 648ce3da authored 作者: Frederic Bastien's avatar Frederic Bastien

Add test for DimShuffle of Tensor.

上级 f0b7dd12
......@@ -37,6 +37,24 @@ class test_DimShuffle(unittest.TestCase):
f = copy(linker).accept(Env([x], [e.shape])).make_function()
assert all(f(numpy.ones(xsh))) == all(zsh)
# Test when we drop a axis that is not broadcastable
ib = [False, True, False]
x = TensorType('float64', ib)('x')
self.assertRaises(ValueError, DimShuffle, ib, shuffle)
# Test when we drop a axis that don't have shape 1
ib = [True, True, False]
x = TensorType('float64', ib)('x')
e = DimShuffle(ib, (1, 2))(x)
f = copy(linker).accept(Env([x], [e.shape])).make_function()
self.assertRaises(TypeError, f, numpy.ones((2, 1, 4)))
# Test that we can't take a dimensions multiple time
xsh, shuffle, zsh = ((1, 1, 4), (0, 1, 2, 0), (1, 4))
ib = [False, True, False]
x = TensorType('float64', ib)('x')
self.assertRaises(ValueError, DimShuffle, ib, shuffle)
def test_perform(self):
self.with_linker(gof.PerformLinker())
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论