提交 e0ad0fd7 authored 作者: Nicolas Bouchard's avatar Nicolas Bouchard

Add grad implementation for repeat for scalar.

上级 822a6648
...@@ -273,7 +273,6 @@ class RepeatOp(theano.Op): ...@@ -273,7 +273,6 @@ class RepeatOp(theano.Op):
Keywords arguments: Keywords arguments:
axis -- int, optional. axis -- int, optional.
""" """
def __init__(self, axis=None): def __init__(self, axis=None):
...@@ -302,14 +301,24 @@ class RepeatOp(theano.Op): ...@@ -302,14 +301,24 @@ class RepeatOp(theano.Op):
z = output_storage[0] z = output_storage[0]
z[0] = np.repeat(x, repeats=repeats, axis=self.axis) z[0] = np.repeat(x, repeats=repeats, axis=self.axis)
def grad(self, inputs, outputs_gradients): def grad(self, (x, repeats), (gz, )):
repeats = inputs[1] if repeats.ndim == 0:
out = outputs_gradients[0] if self.axis is None:
if inputs[0].ndim != 1: axis = x.ndim
raise NotImplementedError() else:
if repeats.ndim != 0: if self.axis >= 0:
axis = self.axis + 1
else:
axis = self.axis + x.ndim
shape = [x.shape[k] for k in range(x.ndim)]
shape.insert(axis, repeats)
return [gz.reshape(shape, x.ndim + 1).sum(axis=axis), None]
elif repeats.ndim == 1:
raise NotImplementedError() raise NotImplementedError()
return [out.reshape([inputs[0].shape[0], repeats]).sum(axis=1), None] else:
raise ValueError()
def infer_shape(self, node, ins_shapes): def infer_shape(self, node, ins_shapes):
i0_shapes = ins_shapes[0] i0_shapes = ins_shapes[0]
......
...@@ -179,14 +179,12 @@ class TestRepeatOp(utt.InferShapeTester): ...@@ -179,14 +179,12 @@ class TestRepeatOp(utt.InferShapeTester):
self.op_class) self.op_class)
def test_grad(self): def test_grad(self):
x = T.dvector('x') for ndim in range(3)[1:]:
a = np.random.random(50) x = T.TensorType('float64', [False] * ndim)
a = np.random.random((10, ) * ndim)
gf = theano.function([x], T.grad(T.sum(repeat(x, 3)), x))
def repeat_(a): for axis in [None] + range(ndim):
return RepeatOp()(a, 3) utt.verify_grad(lambda x: RepeatOp(axis=axis)(x, 3), [a])
utt.verify_grad(repeat_, [a])
class TestBartlett(utt.InferShapeTester): class TestBartlett(utt.InferShapeTester):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论