提交 f815e119 authored 作者: Nicolas Bouchard's avatar Nicolas Bouchard

Fix bug in grad for axis=-1.

上级 e0ad0fd7
...@@ -309,7 +309,7 @@ class RepeatOp(theano.Op): ...@@ -309,7 +309,7 @@ class RepeatOp(theano.Op):
if self.axis >= 0: if self.axis >= 0:
axis = self.axis + 1 axis = self.axis + 1
else: else:
axis = self.axis + x.ndim axis = self.axis + x.ndim + 1
shape = [x.shape[k] for k in range(x.ndim)] shape = [x.shape[k] for k in range(x.ndim)]
shape.insert(axis, repeats) shape.insert(axis, repeats)
......
...@@ -180,11 +180,11 @@ class TestRepeatOp(utt.InferShapeTester): ...@@ -180,11 +180,11 @@ class TestRepeatOp(utt.InferShapeTester):
def test_grad(self): def test_grad(self):
for ndim in range(3)[1:]: for ndim in range(3)[1:]:
x = T.TensorType('float64', [False] * ndim)
a = np.random.random((10, ) * ndim) a = np.random.random((10, ) * ndim)
for axis in [None] + range(ndim): for axis in [None] + range(ndim):
utt.verify_grad(lambda x: RepeatOp(axis=axis)(x, 3), [a]) utt.verify_grad(lambda x: RepeatOp(axis=axis)(x, 3), [a])
utt.verify_grad(lambda x: RepeatOp(axis=-1)(x, 3), [a])
class TestBartlett(utt.InferShapeTester): class TestBartlett(utt.InferShapeTester):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论