提交 fa8c474c authored 作者: Frederic's avatar Frederic

Allow to test GPUADimshuffle with float16

上级 851e3dd1
......@@ -27,6 +27,7 @@ def FunctionGraph(i, o):
class test_DimShuffle(unittest_tools.InferShapeTester):
op = DimShuffle
type = TensorType
dtype = theano.config.floatX
def with_linker(self, linker):
for xsh, shuffle, zsh in [((2, 3), (1, 'x', 0), (3, 1, 2)),
......@@ -40,25 +41,25 @@ class test_DimShuffle(unittest_tools.InferShapeTester):
((1, 1, 1), (), ()),
((1,), ('x', 'x'), (1, 1))]:
ib = [(entry == 1) for entry in xsh]
x = self.type('float64', ib)('x')
x = self.type(self.dtype, ib)('x')
e = self.op(ib, shuffle)(x)
f = copy(linker).accept(FunctionGraph([x], [e])).make_function()
assert f(numpy.ones(xsh)).shape == zsh
assert f(numpy.ones(xsh, dtype=self.dtype)).shape == zsh
# test that DimShuffle.infer_shape work correctly
x = self.type('float64', ib)('x')
x = self.type(self.dtype, ib)('x')
e = self.op(ib, shuffle)(x)
f = copy(linker).accept(FunctionGraph([x],
[e.shape])).make_function()
assert all(f(numpy.ones(xsh))) == all(zsh)
assert all(f(numpy.ones(xsh, dtype=self.dtype))) == all(zsh)
# Test when we drop a axis that is not broadcastable
ib = [False, True, False]
x = self.type('float64', ib)('x')
x = self.type(self.dtype, ib)('x')
self.assertRaises(ValueError, self.op, ib, shuffle)
# Test when we drop a axis that don't have shape 1
ib = [True, True, False]
x = self.type('float64', ib)('x')
x = self.type(self.dtype, ib)('x')
e = self.op(ib, (1, 2))(x)
f = copy(linker).accept(FunctionGraph([x], [e.shape])).make_function()
self.assertRaises(TypeError, f, numpy.ones((2, 1, 4)))
......@@ -66,7 +67,7 @@ class test_DimShuffle(unittest_tools.InferShapeTester):
# Test that we can't take a dimensions multiple time
xsh, shuffle, zsh = ((1, 1, 4), (0, 1, 2, 0), (1, 4))
ib = [False, True, False]
x = self.type('float64', ib)('x')
x = self.type(self.dtype, ib)('x')
self.assertRaises(ValueError, DimShuffle, ib, shuffle)
def test_perform(self):
......@@ -89,15 +90,15 @@ class test_DimShuffle(unittest_tools.InferShapeTester):
((1, 1, 1), ()),
((1,), ('x', 'x'))]:
ib = [(entry == 1) for entry in xsh]
adtens = self.type('float64', ib)('x')
adtens_val = numpy.ones(xsh)
adtens = self.type(self.dtype, ib)('x')
adtens_val = numpy.ones(xsh, dtype=self.dtype)
self._compile_and_check([adtens],
[self.op(ib, shuffle)(adtens)],
[adtens_val], self.op,
warn=False)
def test_too_big_rank(self):
x = self.type('float64', broadcastable=())()
x = self.type(self.dtype, broadcastable=())()
y = x.dimshuffle(('x',) * (numpy.MAXDIMS + 1))
self.assertRaises(ValueError, y.eval, {x: 0})
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论