提交 0c67f88f authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Merge pull request #6 from carriepl/fix_advincsubtensor_grad

Add test case for adv_inc_subtensor with broadcasting of the increment
......@@ -1328,20 +1328,24 @@ class TestAdvancedSubtensor(unittest.TestCase):
if inplace_increment is None:
raise inplace_increment_missing
a = inc_subtensor(self.m[self.ix1, self.ix12], 2.1)
inc = dscalar()
a = inc_subtensor(self.m[self.ix1, self.ix12], inc)
g_inc = tensor.grad(a.sum(), inc)
assert a.type == self.m.type, (a.type, self.m.type)
f = theano.function([self.m, self.ix1, self.ix12], a,
f = theano.function([self.m, self.ix1, self.ix12, inc], [a, g_inc],
allow_input_downcast=True)
aval = f([[.4, .9, .1],
[5, 6, 7],
[.5, .3, .15]],
[1, 2, 1],
[0, 1, 0])
aval, gval = f([[.4, .9, .1],
[5, 6, 7],
[.5, .3, .15]],
[1, 2, 1],
[0, 1, 0],
2.1)
assert numpy.allclose(aval,
[[.4, .9, .1],
[5 + 2.1 * 2, 6, 7],
[.5, .3 + 2.1, .15]]), aval
assert numpy.allclose(gval, 3.0), gval
def test_inc_adv_subtensor_with_index_broadcasting(self):
if inplace_increment is None:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论