提交 ba483b3f authored 作者: Nicolas Bouchard's avatar Nicolas Bouchard

Fix bug in RepeatOp.infer_shape and redo tests.

上级 63b0d7ea
...@@ -327,10 +327,13 @@ class RepeatOp(theano.Op): ...@@ -327,10 +327,13 @@ class RepeatOp(theano.Op):
out_shape = list(i0_shapes) out_shape = list(i0_shapes)
if self.axis == None: if self.axis == None:
res = 0 if len(i0_shapes) == 0:
for d in i0_shapes: out_shape = [repeats]
res = res + d else:
out_shape = (res * repeats, ) res = 1
for d in i0_shapes:
res = res * d
out_shape = (res * repeats, )
else: else:
if repeats.ndim == 0: if repeats.ndim == 0:
out_shape[self.axis] = out_shape[self.axis] * repeats out_shape[self.axis] = out_shape[self.axis] * repeats
......
...@@ -137,54 +137,60 @@ class TestSqueezeOp(utt.InferShapeTester): ...@@ -137,54 +137,60 @@ class TestSqueezeOp(utt.InferShapeTester):
class TestRepeatOp(utt.InferShapeTester): class TestRepeatOp(utt.InferShapeTester):
nb = 5
def setUp(self): def setUp(self):
super(TestRepeatOp, self).setUp() super(TestRepeatOp, self).setUp()
self.op_class = RepeatOp self.op_class = RepeatOp
self.op = RepeatOp() self.op = RepeatOp()
def test_repeatOp(self): def test_repeatOp(self):
x = T.dmatrix('x') for ndim in range(3):
a = np.random.random((30, 50)) x = T.TensorType(theano.config.floatX, [False] * ndim)()
a = np.random.random((10, ) * ndim)
r_var = T.lscalar()
r = 3
for axis in [None] + range(ndim):
f = theano.function([x, r_var], repeat(x, r_var, axis=axis))
assert np.allclose(np.repeat(a, r, axis=axis), f(a, r))
r_var = T.lvector()
r = np.random.random_integers(5, size=(10,))
for axis in range(ndim):
f = theano.function([x, r_var], repeat(x, r_var, axis=axis))
assert np.allclose(np.repeat(a, r, axis=axis), f(a, r))
for axis in [None] + range(len(a.shape)):
for repeats in range(TestRepeatOp.nb):
f = theano.function([x], repeat(x, repeats, axis=axis))
assert np.allclose(np.repeat(a, repeats, axis=axis), f(a))
def test_infer_shape(self): def test_infer_shape(self):
x = T.dvector('x') for ndim in range(4):
m = T.iscalars('m') x = T.TensorType(theano.config.floatX, [False] * ndim)()
a = np.random.random(50) a = np.random.random((10, ) * ndim)
self._compile_and_check([x, m], r_var = T.lscalar()
[repeat(x, m)], r = 3
[a, 2], for axis in [None] + range(ndim):
self.op_class) self._compile_and_check([x, r_var],
[RepeatOp(axis=axis)(x, r_var)],
[a, r],
self.op_class)
x = T.dmatrix('x') r_var = T.lvector()
a = np.random.random((40, 50)) r = np.random.random_integers(5, size=(10,))
for axis in range(len(a.shape)):
self._compile_and_check([x, m], for axis in range(ndim):
[repeat(x, m, axis=axis)], self._compile_and_check([x, r_var],
[a, 2], [RepeatOp(axis=axis)(x, r_var)],
self.op_class) [a, r],
self.op_class)
m = T.lvector('m')
repeats = np.random.random_integers(5, size=(40, ))
self._compile_and_check([x, m],
[repeat(x, m, axis=0)],
[a, repeats],
self.op_class)
def test_grad(self): def test_grad(self):
for ndim in range(3)[1:]: for ndim in range(3):
a = np.random.random((10, ) * ndim) a = np.random.random((10, ) * ndim)
for axis in [None] + range(ndim): for axis in [None] + range(ndim):
utt.verify_grad(lambda x: RepeatOp(axis=axis)(x, 3), [a]) utt.verify_grad(lambda x: RepeatOp(axis=axis)(x, 3), [a])
utt.verify_grad(lambda x: RepeatOp(axis=-1)(x, 3), [a]) if ndim > 0:
utt.verify_grad(lambda x: RepeatOp(axis=-1)(x, 3), [a])
class TestBartlett(utt.InferShapeTester): class TestBartlett(utt.InferShapeTester):
...@@ -276,6 +282,7 @@ class TestFillDiagonal(utt.InferShapeTester): ...@@ -276,6 +282,7 @@ class TestFillDiagonal(utt.InferShapeTester):
self.op_class) self.op_class)
if __name__ == "__main__": if __name__ == "__main__":
utt.unittest.main()
t = TestFillDiagonal('setUp') t = TestFillDiagonal('setUp')
t.setUp() t.setUp()
t.test_perform() t.test_perform()
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论