提交 0fe5fdb1 authored 作者: Frédéric Bastien's avatar Frédéric Bastien

Merge pull request #2537 from carriepl/test_advincsubtensor1_grad

[CRASH] Fix crash with broadcasted increment in AdvancedIncSubtensor1 and added test
......@@ -1971,6 +1971,7 @@ class AdvancedIncSubtensor1(Op):
else:
gx = g_output
gy = advanced_subtensor1(g_output, idx_list)
gy = _sum_grad_over_bcasted_dims(y, gy)
return [gx, gy] + [DisconnectedType()()]
......
......@@ -1347,6 +1347,28 @@ class TestAdvancedSubtensor(unittest.TestCase):
[.5, .3 + 2.1, .15]]), aval
assert numpy.allclose(gval, 3.0), gval
def test_inc_adv_subtensor1_with_broadcasting(self):
if inplace_increment is None:
raise inplace_increment_missing
inc = dscalar()
a = inc_subtensor(self.m[self.ix1], inc)
g_inc = tensor.grad(a.sum(), inc)
assert a.type == self.m.type, (a.type, self.m.type)
f = theano.function([self.m, self.ix1, inc], [a, g_inc],
allow_input_downcast=True)
aval, gval = f([[.4, .9, .1],
[5, 6, 7],
[.5, .3, .15]],
[0, 1, 0],
2.1)
assert numpy.allclose(aval,
[[.4 + 2.1 * 2, .9 + 2.1 * 2, .1 + 2.1 * 2],
[5 + 2.1, 6 + 2.1, 7 + 2.1],
[.5, .3, .15]]), aval
assert numpy.allclose(gval, 9.0), gval
def test_inc_adv_subtensor_with_index_broadcasting(self):
if inplace_increment is None:
raise inplace_increment_missing
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论