提交 01c4db07 authored 作者: John Salvatier's avatar John Salvatier

advanced indexing and incrementing tests

上级 1bd8e47a
......@@ -3725,6 +3725,69 @@ class TestIncSubtensor1(unittest.TestCase):
self.assertRaises(TypeError,
lambda: inc_subtensor(self.v[self.adv1q], fmatrix()))
class TestAdvancedSubtensor(unittest.TestCase):
# test inc_subtensor
# also tests set_subtensor
def setUp(self):
self.s = iscalar()
self.v = fvector()
self.m = dmatrix()
self.t = ctensor3()
self.ix1 = lvector() # advanced 1d query
self.ix12 = lvector()
self.ix2 = lmatrix()
def test_cant_adv_idx_into_scalar(self):
self.assertRaises(TypeError, lambda: self.s[self.ix1])
def test_index_into_vec_w_vec(self):
a = self.v[self.ix1]
assert a.type == self.v.type
def test_index_into_vec_w_matrix(self):
a = self.v[self.ix2]
def test_inc_adv_selection(self):
a = inc_subtensor(self.v[self.ix2], self.v[self.ix2])
assert a.type == self.v.type
f = theano.function([self.v, self.ix2], a, allow_input_downcast=True)
aval = f([.4, .9, .1], [[1, 2],
[1, 2]])
assert numpy.allclose(aval, [.4, .9*3, .1 * 3])
def test_inc_adv_selection2(self):
subt = self.m[self.ix1,self.ix12]
a = inc_subtensor(subt, subt)
assert a.type == self.m.type, str(a.type) +str(a.type.broadcastable) + " " + str(self.m.type) + str(self.m.type.broadcastable)
f = theano.function([self.m, self.ix1, self.ix12], a, allow_input_downcast=True)
aval = f([[.4, .9, .1],
[5, 6, 7],
[.5, .3, .15]],
[1, 2, 1], [0,1,0])
assert numpy.allclose(aval,
[[.4, .9, .1],
[5*3, 6, 7],
[.5, .3*2, .15]])
def test_inc_adv_selection_with_broadcasting(self):
a = inc_subtensor(self.m[self.ix1,self.ix12], 2.1)
assert a.type == self.m.type
f = theano.function([self.m, self.ix1, self.ix12], a, allow_input_downcast=True)
aval = f([[.4, .9, .1],
[5, 6, 7],
[.5, .3, .15]],
[1, 2, 1], [0,1,0])
assert numpy.allclose(aval,
[[.4, .9, .1],
[5+2.1*2, 6, 7],
[.5, .3 + 2.1, .15]])
class T_Join_and_Split(unittest.TestCase):
"""
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论