提交 e67a5ba0 作者: Pascal Lamblin

Merge pull request #2363 from lamblin/fix_advincsubtensor_grad

Fix advincsubtensor grad crash
...@@ -33,7 +33,7 @@ class NullType(Type): ...@@ -33,7 +33,7 @@ class NullType(Type):
def __eq__(self, other): def __eq__(self, other):
return type(self) == type(other) return type(self) == type(other)
def __hash__(self, other): def __hash__(self):
return hash(type(self)) return hash(type(self))
def __str__(self): def __str__(self):
......
...@@ -1544,32 +1544,42 @@ class IncSubtensor(Op): ...@@ -1544,32 +1544,42 @@ class IncSubtensor(Op):
else: else:
gx = g_output gx = g_output
gy = Subtensor(idx_list=self.idx_list)(g_output, *idx_list) gy = Subtensor(idx_list=self.idx_list)(g_output, *idx_list)
if gy.broadcastable != y.broadcastable: gy = _sum_grad_over_bcasted_dims(y, gy)
y_dim_added = gy.ndim - y.ndim
y_broad = (True,) * y_dim_added + y.broadcastable
assert sum(gy.broadcastable) < sum(y_broad)
axis_to_sum = []
for i in range(gy.ndim):
if gy.broadcastable[i] is False and y_broad[i] is True:
axis_to_sum.append(i)
elif (gy.broadcastable[i] is True and
y_broad[i] is False):
# This mean that Theano where able to infer that
# gy.shape[i] is 1, so y.shape[i] is 1, but we
# didn't know it. It is fine.
pass
else:
assert gy.broadcastable[i] == y_broad[i]
gy = gy.sum(axis=axis_to_sum, keepdims=True)
if gy.ndim != y.ndim:
assert gy.ndim > y.ndim
for i in range(y_dim_added):
assert gy.broadcastable[i]
gy = gy.dimshuffle(*range(y_dim_added, gy.ndim))
assert gy.broadcastable == y.broadcastable
return [gx, gy] + [DisconnectedType()()] * len(idx_list) return [gx, gy] + [DisconnectedType()()] * len(idx_list)
def _sum_grad_over_bcasted_dims(x, gx):
"""Sum of gx over dimensions to reproduce x.broadcastable.
This is useful to sum gradients over certain dimensions when
x has been broadcasted, and we need to sum the gradient contributions
over all duplications.
"""
if gx.broadcastable != x.broadcastable:
x_dim_added = gx.ndim - x.ndim
x_broad = (True,) * x_dim_added + x.broadcastable
assert sum(gx.broadcastable) < sum(x_broad)
axis_to_sum = []
for i in range(gx.ndim):
if gx.broadcastable[i] is False and x_broad[i] is True:
axis_to_sum.append(i)
elif (gx.broadcastable[i] is True and
x_broad[i] is False):
# This means that Theano was able to infer that
# gx.shape[i] is 1, so x.shape[i] is 1, but we
# didn't know it. It is fine.
pass
else:
assert gx.broadcastable[i] == x_broad[i]
gx = gx.sum(axis=axis_to_sum, keepdims=True)
if gx.ndim != x.ndim:
assert gx.ndim > x.ndim
for i in range(x_dim_added):
assert gx.broadcastable[i]
gx = gx.dimshuffle(*range(x_dim_added, gx.ndim))
assert gx.broadcastable == x.broadcastable
return gx
######################### #########################
# Advanced indexing # Advanced indexing
...@@ -2247,6 +2257,9 @@ class AdvancedIncSubtensor(Op): ...@@ -2247,6 +2257,9 @@ class AdvancedIncSubtensor(Op):
else: else:
gx = outgrad gx = outgrad
gy = advanced_subtensor(outgrad, *idxs) gy = advanced_subtensor(outgrad, *idxs)
# Make sure to sum gy over the dimensions of y that have been
# added or broadcasted
gy = _sum_grad_over_bcasted_dims(y, gy)
return [gx, gy] + \ return [gx, gy] + \
[DisconnectedType()() for _ in idxs] [DisconnectedType()() for _ in idxs]
......
...@@ -1328,20 +1328,24 @@ class TestAdvancedSubtensor(unittest.TestCase): ...@@ -1328,20 +1328,24 @@ class TestAdvancedSubtensor(unittest.TestCase):
if inplace_increment is None: if inplace_increment is None:
raise inplace_increment_missing raise inplace_increment_missing
a = inc_subtensor(self.m[self.ix1, self.ix12], 2.1) inc = dscalar()
a = inc_subtensor(self.m[self.ix1, self.ix12], inc)
g_inc = tensor.grad(a.sum(), inc)
assert a.type == self.m.type, (a.type, self.m.type) assert a.type == self.m.type, (a.type, self.m.type)
f = theano.function([self.m, self.ix1, self.ix12], a, f = theano.function([self.m, self.ix1, self.ix12, inc], [a, g_inc],
allow_input_downcast=True) allow_input_downcast=True)
aval = f([[.4, .9, .1], aval, gval = f([[.4, .9, .1],
[5, 6, 7], [5, 6, 7],
[.5, .3, .15]], [.5, .3, .15]],
[1, 2, 1], [1, 2, 1],
[0, 1, 0]) [0, 1, 0],
2.1)
assert numpy.allclose(aval, assert numpy.allclose(aval,
[[.4, .9, .1], [[.4, .9, .1],
[5 + 2.1 * 2, 6, 7], [5 + 2.1 * 2, 6, 7],
[.5, .3 + 2.1, .15]]), aval [.5, .3 + 2.1, .15]]), aval
assert numpy.allclose(gval, 3.0), gval
def test_inc_adv_subtensor_with_index_broadcasting(self): def test_inc_adv_subtensor_with_index_broadcasting(self):
if inplace_increment is None: if inplace_increment is None:
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论