提交 e0ad0fd7 authored 作者: Nicolas Bouchard's avatar Nicolas Bouchard

Add grad implementation for repeat for scalar.

上级 822a6648
......@@ -273,7 +273,6 @@ class RepeatOp(theano.Op):
Keywords arguments:
axis -- int, optional.
"""
def __init__(self, axis=None):
......@@ -302,14 +301,24 @@ class RepeatOp(theano.Op):
z = output_storage[0]
z[0] = np.repeat(x, repeats=repeats, axis=self.axis)
def grad(self, inputs, outputs_gradients):
repeats = inputs[1]
out = outputs_gradients[0]
if inputs[0].ndim != 1:
raise NotImplementedError()
if repeats.ndim != 0:
def grad(self, (x, repeats), (gz, )):
if repeats.ndim == 0:
if self.axis is None:
axis = x.ndim
else:
if self.axis >= 0:
axis = self.axis + 1
else:
axis = self.axis + x.ndim
shape = [x.shape[k] for k in range(x.ndim)]
shape.insert(axis, repeats)
return [gz.reshape(shape, x.ndim + 1).sum(axis=axis), None]
elif repeats.ndim == 1:
raise NotImplementedError()
return [out.reshape([inputs[0].shape[0], repeats]).sum(axis=1), None]
else:
raise ValueError()
def infer_shape(self, node, ins_shapes):
i0_shapes = ins_shapes[0]
......
......@@ -179,14 +179,12 @@ class TestRepeatOp(utt.InferShapeTester):
self.op_class)
def test_grad(self):
x = T.dvector('x')
a = np.random.random(50)
gf = theano.function([x], T.grad(T.sum(repeat(x, 3)), x))
for ndim in range(3)[1:]:
x = T.TensorType('float64', [False] * ndim)
a = np.random.random((10, ) * ndim)
def repeat_(a):
return RepeatOp()(a, 3)
utt.verify_grad(repeat_, [a])
for axis in [None] + range(ndim):
utt.verify_grad(lambda x: RepeatOp(axis=axis)(x, 3), [a])
class TestBartlett(utt.InferShapeTester):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论