提交 d87d44ce authored 作者: ricardoV94's avatar ricardoV94 提交者: Ricardo Vieira

Use IncSubtensor in gradient of RepeatOp

上级 63c513e1
......@@ -740,18 +740,15 @@ class Repeat(Op):
(gz,) = gout
axis = self.axis
# To sum the gradients that belong to the same repeated x,
# We create a repeated eye and dot product it with the gradient.
# Use IncSubtensor to sum the gradients that belong to the repeated entries of x
axis_size = x.shape[axis]
repeated_eye = repeat(
ptb.eye(axis_size), repeats, axis=0
) # A sparse repeat would be neat
# Place gradient axis at end for dot product
gx = ptb.moveaxis(gz, axis, -1)
gx = gx @ repeated_eye
# Place gradient back into the correct axis
gx = ptb.moveaxis(gx, -1, axis)
repeated_arange = repeat(ptb.arange(axis_size), repeats, axis=0)
# Move the axis to repeat to front for easier indexing
x_transpose = ptb.moveaxis(x, axis, 0)
gz_transpose = ptb.moveaxis(gz, axis, 0)
gx_transpose = ptb.zeros_like(x_transpose)[repeated_arange].inc(gz_transpose)
gx = ptb.moveaxis(gx_transpose, 0, axis)
return [gx, DisconnectedType()()]
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论