提交 c6fc9734 authored 作者: Razvan Pascanu's avatar Razvan Pascanu

Merge pull request #314 from bayerj/feature-AdvancedIncSubtensor-grad

Feature: advanced inc subtensor grad
...@@ -5139,10 +5139,14 @@ class AdvancedIncSubtensor(Op): ...@@ -5139,10 +5139,14 @@ class AdvancedIncSubtensor(Op):
out[0] = inputs[0].copy() out[0] = inputs[0].copy()
out[0][inputs[2:]] += inputs[1] out[0][inputs[2:]] += inputs[1]
#def grad? def grad(self, inpt, output_gradients):
# grad on x is grad on output x, y = inpt[:2]
# grad on y is grad_output[idx_list] idxs = inpt[2:]
# grad on rest is None outgrad, = output_gradients
d_x_wrt_C = outgrad
d_y_wrt_C = AdvancedSubtensor(self.args)(outgrad, *idxs)
return [d_x_wrt_C, d_y_wrt_C] + [None for _ in idxs]
def R_op(self, inputs, eval_points): def R_op(self, inputs, eval_points):
if None in eval_points[:2]: if None in eval_points[:2]:
return [None] return [None]
......
...@@ -2605,6 +2605,32 @@ class T_subtensor(unittest.TestCase): ...@@ -2605,6 +2605,32 @@ class T_subtensor(unittest.TestCase):
val = f() val = f()
self.assertTrue(numpy.allclose(val, data[idx].shape)) self.assertTrue(numpy.allclose(val, data[idx].shape))
def test_grad_advanced_inc_subtensor(self):
def inc_slice(*s):
def just_numeric_args(a,b):
cost = (a[s] + b).sum()
cost_wrt_a = grad(cost, a)
cost_wrt_b = grad(cost, b)
grads = cost_wrt_a.sum() + cost_wrt_b.sum()
return grads
return just_numeric_args
# vector
utt.verify_grad(
inc_slice(slice(2, 4, None)),
(numpy.asarray([0, 1, 2, 3, 4, 5.]), numpy.asarray([9, 9.]),))
# matrix
utt.verify_grad(
inc_slice(slice(1, 2, None), slice(None, None, None)),
(numpy.asarray([[0, 1], [2, 3], [4, 5.]]),
numpy.asarray([[9, 9.]]),))
#single element
utt.verify_grad(
inc_slice(2, 1),
(numpy.asarray([[0, 1],[2, 3],[4, 5.]]), numpy.asarray(9.),))
class TestIncSubtensor1(unittest.TestCase): class TestIncSubtensor1(unittest.TestCase):
# test inc_subtensor # test inc_subtensor
......
...@@ -87,6 +87,7 @@ class Test_inc_subtensor(unittest.TestCase): ...@@ -87,6 +87,7 @@ class Test_inc_subtensor(unittest.TestCase):
expected_result[:,sl3,:val_sl2_end] += val_inc expected_result[:,sl3,:val_sl2_end] += val_inc
self.assertTrue(numpy.array_equal(result, expected_result)) self.assertTrue(numpy.array_equal(result, expected_result))
def test_grad_inc_set(self): def test_grad_inc_set(self):
def inc_slice(*s): def inc_slice(*s):
def just_numeric_args(a,b): def just_numeric_args(a,b):
...@@ -116,3 +117,4 @@ class Test_inc_subtensor(unittest.TestCase): ...@@ -116,3 +117,4 @@ class Test_inc_subtensor(unittest.TestCase):
f_slice(2, 1), f_slice(2, 1),
(numpy.asarray([[0,1],[2,3],[4,5.]]), (numpy.asarray([[0,1],[2,3],[4,5.]]),
numpy.asarray(9.),)) numpy.asarray(9.),))
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论