提交 ba483b3f authored 作者: Nicolas Bouchard's avatar Nicolas Bouchard

Fix bug in RepeatOp.infer_shape and redo tests.

上级 63b0d7ea
......@@ -327,10 +327,13 @@ class RepeatOp(theano.Op):
out_shape = list(i0_shapes)
if self.axis == None:
res = 0
for d in i0_shapes:
res = res + d
out_shape = (res * repeats, )
if len(i0_shapes) == 0:
out_shape = [repeats]
else:
res = 1
for d in i0_shapes:
res = res * d
out_shape = (res * repeats, )
else:
if repeats.ndim == 0:
out_shape[self.axis] = out_shape[self.axis] * repeats
......
......@@ -137,54 +137,60 @@ class TestSqueezeOp(utt.InferShapeTester):
class TestRepeatOp(utt.InferShapeTester):
nb = 5
def setUp(self):
super(TestRepeatOp, self).setUp()
self.op_class = RepeatOp
self.op = RepeatOp()
def test_repeatOp(self):
x = T.dmatrix('x')
a = np.random.random((30, 50))
for ndim in range(3):
x = T.TensorType(theano.config.floatX, [False] * ndim)()
a = np.random.random((10, ) * ndim)
r_var = T.lscalar()
r = 3
for axis in [None] + range(ndim):
f = theano.function([x, r_var], repeat(x, r_var, axis=axis))
assert np.allclose(np.repeat(a, r, axis=axis), f(a, r))
r_var = T.lvector()
r = np.random.random_integers(5, size=(10,))
for axis in range(ndim):
f = theano.function([x, r_var], repeat(x, r_var, axis=axis))
assert np.allclose(np.repeat(a, r, axis=axis), f(a, r))
for axis in [None] + range(len(a.shape)):
for repeats in range(TestRepeatOp.nb):
f = theano.function([x], repeat(x, repeats, axis=axis))
assert np.allclose(np.repeat(a, repeats, axis=axis), f(a))
def test_infer_shape(self):
x = T.dvector('x')
m = T.iscalars('m')
a = np.random.random(50)
for ndim in range(4):
x = T.TensorType(theano.config.floatX, [False] * ndim)()
a = np.random.random((10, ) * ndim)
self._compile_and_check([x, m],
[repeat(x, m)],
[a, 2],
self.op_class)
r_var = T.lscalar()
r = 3
for axis in [None] + range(ndim):
self._compile_and_check([x, r_var],
[RepeatOp(axis=axis)(x, r_var)],
[a, r],
self.op_class)
x = T.dmatrix('x')
a = np.random.random((40, 50))
for axis in range(len(a.shape)):
self._compile_and_check([x, m],
[repeat(x, m, axis=axis)],
[a, 2],
self.op_class)
m = T.lvector('m')
repeats = np.random.random_integers(5, size=(40, ))
self._compile_and_check([x, m],
[repeat(x, m, axis=0)],
[a, repeats],
self.op_class)
r_var = T.lvector()
r = np.random.random_integers(5, size=(10,))
for axis in range(ndim):
self._compile_and_check([x, r_var],
[RepeatOp(axis=axis)(x, r_var)],
[a, r],
self.op_class)
def test_grad(self):
for ndim in range(3)[1:]:
for ndim in range(3):
a = np.random.random((10, ) * ndim)
for axis in [None] + range(ndim):
utt.verify_grad(lambda x: RepeatOp(axis=axis)(x, 3), [a])
utt.verify_grad(lambda x: RepeatOp(axis=-1)(x, 3), [a])
if ndim > 0:
utt.verify_grad(lambda x: RepeatOp(axis=-1)(x, 3), [a])
class TestBartlett(utt.InferShapeTester):
......@@ -276,6 +282,7 @@ class TestFillDiagonal(utt.InferShapeTester):
self.op_class)
if __name__ == "__main__":
utt.unittest.main()
t = TestFillDiagonal('setUp')
t.setUp()
t.test_perform()
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论