提交 7e01a063 authored 作者: Frederic Bastien's avatar Frederic Bastien

Small refactorization of test to remove duplicate code.

上级 1f880996
......@@ -87,60 +87,32 @@ class Test_inc_subtensor(unittest.TestCase):
expected_result[:,sl3,:val_sl2_end] += val_inc
self.assertTrue(numpy.array_equal(result, expected_result))
def test_grad_inc(self):
a = T.dvector()
b = T.dvector()
def test_grad_inc_set(self):
def inc_slice(*s):
def just_numeric_args(a,b):
return T.inc_subtensor(a[s], b)
return just_numeric_args
# vector
utt.verify_grad(
inc_slice(slice(2,4,None)),
(numpy.asarray([0,1,2,3,4,5.]),
numpy.asarray([9,9.]),))
# matrix
utt.verify_grad(
inc_slice(slice(1,2,None), slice(None, None, None)),
(numpy.asarray([[0,1],[2,3],[4,5.]]),
numpy.asarray([[9,9.]]),))
#single element
utt.verify_grad(
inc_slice(2, 1),
(numpy.asarray([[0,1],[2,3],[4,5.]]),
numpy.asarray(9.),))
def test_grad_set(self):
a = T.dvector()
b = T.dvector()
def set_slice(*s):
def just_numeric_args(a,b):
return T.set_subtensor(a[s], b)
return just_numeric_args
# vector
utt.verify_grad(
set_slice(slice(2,4,None)),
(numpy.asarray([0,1,2,3,4,5.]),
numpy.asarray([9,9.]),))
# matrix
utt.verify_grad(
set_slice(slice(1,2,None), slice(None, None, None)),
(numpy.asarray([[0,1],[2,3],[4,5.]]),
numpy.asarray([[9,9.]]),))
#single element
utt.verify_grad(
set_slice(2, 1),
(numpy.asarray([[0,1],[2,3],[4,5.]]),
numpy.asarray(9.),))
for f_slice in [inc_slice, set_slice]:
# vector
utt.verify_grad(
f_slice(slice(2,4,None)),
(numpy.asarray([0,1,2,3,4,5.]),
numpy.asarray([9,9.]),))
# matrix
utt.verify_grad(
f_slice(slice(1,2,None), slice(None, None, None)),
(numpy.asarray([[0,1],[2,3],[4,5.]]),
numpy.asarray([[9,9.]]),))
#single element
utt.verify_grad(
f_slice(2, 1),
(numpy.asarray([[0,1],[2,3],[4,5.]]),
numpy.asarray(9.),))
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论