提交 fc688a28 authored 作者: Frederic's avatar Frederic

pep8

上级 ce0a8359
......@@ -4,6 +4,7 @@ from theano.tests import unittest_tools as utt
import theano
import theano.tensor as T
class Test_inc_subtensor(unittest.TestCase):
"""Partial testing.
......@@ -14,8 +15,10 @@ class Test_inc_subtensor(unittest.TestCase):
- indices: scalar vs slice, constant vs variable, out of bound, ...
- inplace
NOTE: these are the same tests as test_incsubtensor.py, but using the new (read: not
deprecated) inc_subtensor, set_subtensor functions.
NOTE: these are the same tests as test_incsubtensor.py, but using
the new (read: not deprecated) inc_subtensor, set_subtensor
functions.
"""
def setUp(self):
utt.seed_rng()
......@@ -30,7 +33,7 @@ class Test_inc_subtensor(unittest.TestCase):
sl2_end = T.lscalar()
sl2 = slice(sl2_end)
for do_set in [False,True]:
for do_set in [False, True]:
if do_set:
resut = T.set_subtensor(a[sl1, sl2], increment)
......@@ -39,7 +42,7 @@ class Test_inc_subtensor(unittest.TestCase):
f = theano.function([a, increment, sl2_end], resut)
val_a = numpy.ones((5,5))
val_a = numpy.ones((5, 5))
val_inc = 2.3
val_sl2_end = 2
......@@ -47,9 +50,9 @@ class Test_inc_subtensor(unittest.TestCase):
expected_result = numpy.copy(val_a)
if do_set:
expected_result[:,:val_sl2_end] = val_inc
expected_result[:, :val_sl2_end] = val_inc
else:
expected_result[:,:val_sl2_end] += val_inc
expected_result[:, :val_sl2_end] += val_inc
self.assertTrue(numpy.array_equal(result, expected_result))
......@@ -64,7 +67,7 @@ class Test_inc_subtensor(unittest.TestCase):
sl2 = slice(sl2_end)
sl3 = 2
for do_set in [True,False]:
for do_set in [True, False]:
print "Set", do_set
if do_set:
......@@ -74,7 +77,7 @@ class Test_inc_subtensor(unittest.TestCase):
f = theano.function([a, increment, sl2_end], resut)
val_a = numpy.ones((5,3,4))
val_a = numpy.ones((5, 3, 4))
val_inc = 2.3
val_sl2_end = 2
......@@ -82,39 +85,39 @@ class Test_inc_subtensor(unittest.TestCase):
result = f(val_a, val_inc, val_sl2_end)
if do_set:
expected_result[:,sl3,:val_sl2_end] = val_inc
expected_result[:, sl3, :val_sl2_end] = val_inc
else:
expected_result[:,sl3,:val_sl2_end] += val_inc
expected_result[:, sl3, :val_sl2_end] += val_inc
self.assertTrue(numpy.array_equal(result, expected_result))
def test_grad_inc_set(self):
def inc_slice(*s):
def just_numeric_args(a,b):
def just_numeric_args(a, b):
return T.inc_subtensor(a[s], b)
return just_numeric_args
def set_slice(*s):
def just_numeric_args(a,b):
def just_numeric_args(a, b):
return T.set_subtensor(a[s], b)
return just_numeric_args
for f_slice in [inc_slice, set_slice]:
# vector
utt.verify_grad(
f_slice(slice(2,4,None)),
(numpy.asarray([0,1,2,3,4,5.]),
numpy.asarray([9,9.]),))
f_slice(slice(2, 4, None)),
(numpy.asarray([0, 1, 2, 3, 4, 5.]),
numpy.asarray([9, 9.]), ))
# matrix
utt.verify_grad(
f_slice(slice(1,2,None), slice(None, None, None)),
(numpy.asarray([[0,1],[2,3],[4,5.]]),
numpy.asarray([[9,9.]]),))
f_slice(slice(1, 2, None), slice(None, None, None)),
(numpy.asarray([[0, 1], [2, 3], [4, 5.]]),
numpy.asarray([[9, 9.]]), ))
#single element
utt.verify_grad(
f_slice(2, 1),
(numpy.asarray([[0,1],[2,3],[4,5.]]),
(numpy.asarray([[0, 1], [2, 3], [4, 5.]]),
numpy.asarray(9.),))
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论