提交 47de01bd · 作者: Ian Goodfellow

RepeatOp.{grad,connection_pattern}

上级 52c1daf3
......@@ -5,6 +5,7 @@ import theano
import basic
from theano import gof, scalar
import basic as tensor
from theano.gradient import DisconnectedType
class DiffOp(theano.Op):
......@@ -258,6 +259,10 @@ class RepeatOp(theano.Op):
z = output_storage[0]
z[0] = np.repeat(x, repeats=repeats, axis=self.axis)
def connection_pattern(self, node):
    """Declare which inputs each output is differentiably connected to.

    The single output of RepeatOp depends differentiably on ``x``
    (input 0) but not on ``repeats`` (input 1), so the gradient with
    respect to ``repeats`` is disconnected.
    """
    # One inner list per input; each inner list has one entry per output.
    x_connected = [True]
    repeats_connected = [False]
    return [x_connected, repeats_connected]
def grad(self, (x, repeats), (gz, )):
if repeats.ndim == 0:
if self.axis is None:
......@@ -271,7 +276,8 @@ class RepeatOp(theano.Op):
shape = [x.shape[k] for k in range(x.ndim)]
shape.insert(axis, repeats)
return [gz.reshape(shape, x.ndim + 1).sum(axis=axis), None]
return [gz.reshape(shape, x.ndim + 1).sum(axis=axis),
DisconnectedType()() ]
elif repeats.ndim == 1:
# For this implementation, we would need to specify the length
# of repeats in order to split gz in the right way to sum
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论