提交 01be2efc authored 作者: Nicolas Bouchard's avatar Nicolas Bouchard

Add more tests.

上级 ba483b3f
......@@ -327,13 +327,16 @@ class RepeatOp(theano.Op):
out_shape = list(i0_shapes)
if self.axis == None:
if len(i0_shapes) == 0:
out_shape = [repeats]
if repeats.ndim == 0:
if len(i0_shapes) == 0:
out_shape = [repeats]
else:
res = 1
for d in i0_shapes:
res = res * d
out_shape = (res * repeats, )
else:
res = 1
for d in i0_shapes:
res = res * d
out_shape = (res * repeats, )
out_shape = [theano.tensor.sum(repeats)]
else:
if repeats.ndim == 0:
out_shape[self.axis] = out_shape[self.axis] * repeats
......
......@@ -137,6 +137,9 @@ class TestSqueezeOp(utt.InferShapeTester):
class TestRepeatOp(utt.InferShapeTester):
def _possible_axis(self, ndim):
return [None] + range(ndim) + [-i for i in range(ndim)]
def setUp(self):
super(TestRepeatOp, self).setUp()
self.op_class = RepeatOp
......@@ -147,37 +150,40 @@ class TestRepeatOp(utt.InferShapeTester):
x = T.TensorType(theano.config.floatX, [False] * ndim)()
a = np.random.random((10, ) * ndim)
r_var = T.lscalar()
r = 3
for axis in [None] + range(ndim):
for axis in self._possible_axis(ndim):
r_var = T.lscalar()
r = 3
f = theano.function([x, r_var], repeat(x, r_var, axis=axis))
assert np.allclose(np.repeat(a, r, axis=axis), f(a, r))
r_var = T.lvector()
r = np.random.random_integers(5, size=(10,))
r_var = T.lvector()
if axis is None:
r = np.random.random_integers(5, size=a.size)
else:
r = np.random.random_integers(5, size=(10,))
for axis in range(ndim):
f = theano.function([x, r_var], repeat(x, r_var, axis=axis))
assert np.allclose(np.repeat(a, r, axis=axis), f(a, r))
def test_infer_shape(self):
for ndim in range(4):
x = T.TensorType(theano.config.floatX, [False] * ndim)()
a = np.random.random((10, ) * ndim)
r_var = T.lscalar()
r = 3
for axis in [None] + range(ndim):
for axis in self._possible_axis(ndim):
r_var = T.lscalar()
r = 3
self._compile_and_check([x, r_var],
[RepeatOp(axis=axis)(x, r_var)],
[a, r],
self.op_class)
r_var = T.lvector()
r = np.random.random_integers(5, size=(10,))
r_var = T.lvector()
if axis is None:
r = np.random.random_integers(5, size=a.size)
else:
r = np.random.random_integers(5, size=(10,))
for axis in range(ndim):
self._compile_and_check([x, r_var],
[RepeatOp(axis=axis)(x, r_var)],
[a, r],
......@@ -187,10 +193,8 @@ class TestRepeatOp(utt.InferShapeTester):
for ndim in range(3):
a = np.random.random((10, ) * ndim)
for axis in [None] + range(ndim):
for axis in self._possible_axis(ndim):
utt.verify_grad(lambda x: RepeatOp(axis=axis)(x, 3), [a])
if ndim > 0:
utt.verify_grad(lambda x: RepeatOp(axis=-1)(x, 3), [a])
class TestBartlett(utt.InferShapeTester):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论