提交 aa19b945 authored 作者: Justin Bayer's avatar Justin Bayer

PEP8 formatting and correct name for test.

上级 426dd4a3
...@@ -2605,7 +2605,7 @@ class T_subtensor(unittest.TestCase): ...@@ -2605,7 +2605,7 @@ class T_subtensor(unittest.TestCase):
val = f() val = f()
self.assertTrue(numpy.allclose(val, data[idx].shape)) self.assertTrue(numpy.allclose(val, data[idx].shape))
def test_gradgrad_advanced_inc_subtensor(self): def test_grad_advanced_inc_subtensor(self):
def inc_slice(*s): def inc_slice(*s):
def just_numeric_args(a,b): def just_numeric_args(a,b):
cost = (a[s] + b).sum() cost = (a[s] + b).sum()
...@@ -2614,24 +2614,22 @@ class T_subtensor(unittest.TestCase): ...@@ -2614,24 +2614,22 @@ class T_subtensor(unittest.TestCase):
grads = cost_wrt_a.sum() + cost_wrt_b.sum() grads = cost_wrt_a.sum() + cost_wrt_b.sum()
return grads return grads
return just_numeric_args return just_numeric_args
# vector # vector
utt.verify_grad( utt.verify_grad(
inc_slice(slice(2,4,None)), inc_slice(slice(2, 4, None)),
(numpy.asarray([0,1,2,3,4,5.]), (numpy.asarray([0, 1, 2, 3, 4, 5.]), numpy.asarray([9, 9.]),))
numpy.asarray([9,9.]),))
# matrix # matrix
utt.verify_grad( utt.verify_grad(
inc_slice(slice(1,2,None), slice(None, None, None)), inc_slice(slice(1, 2, None), slice(None, None, None)),
(numpy.asarray([[0,1],[2,3],[4,5.]]), (numpy.asarray([[0, 1], [2, 3], [4, 5.]]),
numpy.asarray([[9,9.]]),)) numpy.asarray([[9, 9.]]),))
#single element #single element
utt.verify_grad( utt.verify_grad(
inc_slice(2, 1), inc_slice(2, 1),
(numpy.asarray([[0,1],[2,3],[4,5.]]), (numpy.asarray([[0, 1],[2, 3],[4, 5.]]), numpy.asarray(9.),))
numpy.asarray(9.),))
class TestIncSubtensor1(unittest.TestCase): class TestIncSubtensor1(unittest.TestCase):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论