提交 6e4511f4 authored 作者: John Salvatier's avatar John Salvatier

fix advanced inc tests

上级 33e9dbf7
......@@ -29,7 +29,7 @@ from theano.tensor import (_shared, wvector, bvector, autocast_float_as,
horizontal_stack, vertical_stack, argmax, get_vector_length,
fscalar, zeros_like, sum, tensor3, vector, add, addbroadcast,
alloc, as_tensor_variable, tensor_from_scalar, ARange, autocast_float,
clip, constant, default, dot, inc_subtensor, set_subtensor,
clip, constant, default, dot, inc_subtensor,advanced_inc_subtensor, set_subtensor,
dmatrix, dscalar, dvector, eq, eye, fill, flatten, inverse_permutation,
tensor4, permute_row_elements, Flatten, fmatrix, fscalars, grad,
inplace, iscalar, matrix, minimum, matrices, maximum, mul, neq,
......@@ -3744,7 +3744,7 @@ class TestAdvancedSubtensor(unittest.TestCase):
def test_index_into_vec_w_vec(self):
a = self.v[self.ix1]
assert a.type == self.v.type
assert a.type == self.v.type, (a.type, self.v.type)
def test_index_into_vec_w_matrix(self):
a = self.v[self.ix2]
......@@ -3752,7 +3752,8 @@ class TestAdvancedSubtensor(unittest.TestCase):
def test_inc_adv_selection(self):
a = inc_subtensor(self.v[self.ix2], self.v[self.ix2])
assert a.type == self.v.type, (a.type,self.v.type)
typ = TensorType(self.v.type.dtype, self.ix2.type.broadcastable)
assert a.type == typ, (a.type,typ)
f = theano.function([self.v, self.ix2], a, allow_input_downcast=True)
aval = f([.4, .9, .1], [[1, 2],
[1, 2]])
......@@ -3762,7 +3763,8 @@ class TestAdvancedSubtensor(unittest.TestCase):
subt = self.m[self.ix1,self.ix12]
a = inc_subtensor(subt, subt)
assert a.type == self.m.type, (a.type, self.m.type)
typ = TensorType(self.m.type.dtype, self.ix2.type.broadcastable)
assert a.type == typ, (a.type,typ)
f = theano.function([self.m, self.ix1, self.ix12], a, allow_input_downcast=True)
aval = f([[.4, .9, .1],
[5, 6, 7],
......@@ -3771,7 +3773,7 @@ class TestAdvancedSubtensor(unittest.TestCase):
assert numpy.allclose(aval,
[[.4, .9, .1],
[5*3, 6, 7],
[.5, .3*2, .15]])
[.5, .3*2, .15]]), aval
def test_inc_adv_selection_with_broadcasting(self):
a = inc_subtensor(self.m[self.ix1,self.ix12], 2.1)
......@@ -3786,7 +3788,7 @@ class TestAdvancedSubtensor(unittest.TestCase):
assert numpy.allclose(aval,
[[.4, .9, .1],
[5+2.1*2, 6, 7],
[.5, .3 + 2.1, .15]])
[.5, .3 + 2.1, .15]]), aval
class T_Join_and_Split(unittest.TestCase):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论