提交 47de01bd authored 作者: Ian Goodfellow's avatar Ian Goodfellow

RepeatOp.{grad,connection_pattern}

上级 52c1daf3
...@@ -5,6 +5,7 @@ import theano ...@@ -5,6 +5,7 @@ import theano
import basic import basic
from theano import gof, scalar from theano import gof, scalar
import basic as tensor import basic as tensor
from theano.gradient import DisconnectedType
class DiffOp(theano.Op): class DiffOp(theano.Op):
...@@ -258,6 +259,10 @@ class RepeatOp(theano.Op): ...@@ -258,6 +259,10 @@ class RepeatOp(theano.Op):
z = output_storage[0] z = output_storage[0]
z[0] = np.repeat(x, repeats=repeats, axis=self.axis) z[0] = np.repeat(x, repeats=repeats, axis=self.axis)
def connection_pattern(self, node):
return [ [True], [False] ]
def grad(self, (x, repeats), (gz, )): def grad(self, (x, repeats), (gz, )):
if repeats.ndim == 0: if repeats.ndim == 0:
if self.axis is None: if self.axis is None:
...@@ -271,7 +276,8 @@ class RepeatOp(theano.Op): ...@@ -271,7 +276,8 @@ class RepeatOp(theano.Op):
shape = [x.shape[k] for k in range(x.ndim)] shape = [x.shape[k] for k in range(x.ndim)]
shape.insert(axis, repeats) shape.insert(axis, repeats)
return [gz.reshape(shape, x.ndim + 1).sum(axis=axis), None] return [gz.reshape(shape, x.ndim + 1).sum(axis=axis),
DisconnectedType()() ]
elif repeats.ndim == 1: elif repeats.ndim == 1:
# For this implementation, we would need to specify the length # For this implementation, we would need to specify the length
# of repeats in order to split gz in the right way to sum # of repeats in order to split gz in the right way to sum
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论