提交 6b9267e8 authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Sum gradient over dimensions where the increment was broadcasted

上级 a4cb2200
......@@ -1543,32 +1543,42 @@ class IncSubtensor(Op):
else:
gx = g_output
gy = Subtensor(idx_list=self.idx_list)(g_output, *idx_list)
if gy.broadcastable != y.broadcastable:
y_dim_added = gy.ndim - y.ndim
y_broad = (True,) * y_dim_added + y.broadcastable
assert sum(gy.broadcastable) < sum(y_broad)
axis_to_sum = []
for i in range(gy.ndim):
if gy.broadcastable[i] is False and y_broad[i] is True:
axis_to_sum.append(i)
elif (gy.broadcastable[i] is True and
y_broad[i] is False):
# This mean that Theano where able to infer that
# gy.shape[i] is 1, so y.shape[i] is 1, but we
# didn't know it. It is fine.
pass
else:
assert gy.broadcastable[i] == y_broad[i]
gy = gy.sum(axis=axis_to_sum, keepdims=True)
if gy.ndim != y.ndim:
assert gy.ndim > y.ndim
for i in range(y_dim_added):
assert gy.broadcastable[i]
gy = gy.dimshuffle(*range(y_dim_added, gy.ndim))
assert gy.broadcastable == y.broadcastable
gy = _sum_grad_over_bcasted_dims(y, gy)
return [gx, gy] + [DisconnectedType()()] * len(idx_list)
def _sum_grad_over_bcasted_dims(x, gx):
"""Sum of gx over dimensions to reproduce x.broadcastable.
This is useful to sum gradients over certain dimensions when
x has been broadcasted, and we need to sum the gradient contributions
over all duplications.
"""
if gx.broadcastable != x.broadcastable:
x_dim_added = gx.ndim - x.ndim
x_broad = (True,) * x_dim_added + x.broadcastable
assert sum(gx.broadcastable) < sum(x_broad)
axis_to_sum = []
for i in range(gx.ndim):
if gx.broadcastable[i] is False and x_broad[i] is True:
axis_to_sum.append(i)
elif (gx.broadcastable[i] is True and
x_broad[i] is False):
# This means that Theano was able to infer that
# gx.shape[i] is 1, so x.shape[i] is 1, but we
# didn't know it. It is fine.
pass
else:
assert gx.broadcastable[i] == x_broad[i]
gx = gx.sum(axis=axis_to_sum, keepdims=True)
if gx.ndim != x.ndim:
assert gx.ndim > x.ndim
for i in range(x_dim_added):
assert gx.broadcastable[i]
gx = gx.dimshuffle(*range(x_dim_added, gx.ndim))
assert gx.broadcastable == x.broadcastable
return gx
#########################
# Advanced indexing
......@@ -2183,6 +2193,9 @@ class AdvancedIncSubtensor(Op):
else:
gx = outgrad
gy = advanced_subtensor(outgrad, *idxs)
# Make sure to sum gy over the dimensions of y that have been
# added or broadcasted
gy = _sum_grad_over_bcasted_dims(y, gy)
return [gx, gy] + \
[DisconnectedType()() for _ in idxs]
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论