Commit d87d44ce, authored by ricardoV94, committed by Ricardo Vieira

Use IncSubtensor in gradient of RepeatOp

Parent 63c513e1
...
@@ -740,18 +740,15 @@ class Repeat(Op):
         (gz,) = gout
         axis = self.axis
-        # To sum the gradients that belong to the same repeated x,
-        # We create a repeated eye and dot product it with the gradient.
+        # Use IncSubtensor to sum the gradients that belong to the repeated entries of x
         axis_size = x.shape[axis]
-        repeated_eye = repeat(
-            ptb.eye(axis_size), repeats, axis=0
-        )  # A sparse repeat would be neat
+        repeated_arange = repeat(ptb.arange(axis_size), repeats, axis=0)
 
-        # Place gradient axis at end for dot product
-        gx = ptb.moveaxis(gz, axis, -1)
-        gx = gx @ repeated_eye
-        # Place gradient back into the correct axis
-        gx = ptb.moveaxis(gx, -1, axis)
+        # Move the axis to repeat to front for easier indexing
+        x_transpose = ptb.moveaxis(x, axis, 0)
+        gz_transpose = ptb.moveaxis(gz, axis, 0)
+        gx_transpose = ptb.zeros_like(x_transpose)[repeated_arange].inc(gz_transpose)
+        gx = ptb.moveaxis(gx_transpose, 0, axis)
         return [gx, DisconnectedType()()]
...
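The two gradient strategies in this diff can be compared outside the library. The sketch below is not the PyTensor code: it uses plain NumPy as a stand-in, with `np.add.at` playing the role of the indexed `.inc(...)` accumulation, and the function names `repeat_grad_dense` and `repeat_grad_scatter` are hypothetical, chosen only for this illustration. It assumes `repeats` is a 1-D integer vector with one entry per slice along `axis`.

```python
import numpy as np


def repeat_grad_dense(gz, repeats, axis, axis_size):
    """Old strategy: multiply the output gradient by a repeated identity matrix."""
    # Row k of repeated_eye is the one-hot vector of the source index of output slice k.
    repeated_eye = np.repeat(np.eye(axis_size), repeats, axis=0)
    gx = np.moveaxis(gz, axis, -1)  # put the repeated axis last for the matmul
    gx = gx @ repeated_eye          # sums all copies that came from the same entry
    return np.moveaxis(gx, -1, axis)


def repeat_grad_scatter(gz, repeats, axis, axis_size):
    """New strategy: accumulate the output gradient into zeros at repeated indices."""
    # repeated_arange[k] is the source index of output slice k.
    repeated_arange = np.repeat(np.arange(axis_size), repeats)
    gz_t = np.moveaxis(gz, axis, 0)  # put the repeated axis first for indexing
    gx_t = np.zeros((axis_size,) + gz_t.shape[1:], dtype=gz.dtype)
    # Unbuffered add, so duplicate indices accumulate instead of overwriting.
    np.add.at(gx_t, repeated_arange, gz_t)
    return np.moveaxis(gx_t, 0, axis)


# Illustrative input: x has shape (2, 3) and is repeated along axis 1.
repeats = np.array([2, 0, 3])
gz = np.arange(10.0).reshape(2, 5)  # gradient w.r.t. np.repeat(x, repeats, axis=1)

gx_dense = repeat_grad_dense(gz, repeats, axis=1, axis_size=3)
gx_scatter = repeat_grad_scatter(gz, repeats, axis=1, axis_size=3)
print(np.allclose(gx_dense, gx_scatter))
```

Under these assumptions both functions return the same gradient. The scatter version never builds the dense `(sum(repeats), axis_size)` matrix, which is presumably the motivation behind the removed "A sparse repeat would be neat" comment.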