提交 296fabec authored 作者: Arnaud Bergeron's avatar Arnaud Bergeron

Add test to make sure that we shift things correctly for a single matrix with removed dimensions.

上级 9a9401c2
...@@ -20,8 +20,8 @@ from theano.compile import DeepCopyOp ...@@ -20,8 +20,8 @@ from theano.compile import DeepCopyOp
from theano.tensor import (MakeSlice, NotScalarConstantError, _shared, from theano.tensor import (MakeSlice, NotScalarConstantError, _shared,
as_tensor_variable, cscalar, ctensor3, dmatrix, as_tensor_variable, cscalar, ctensor3, dmatrix,
dscalar, dtensor4, dvector, fmatrix, fscalar, dscalar, dtensor4, dvector, fmatrix, fscalar,
fvector, iscalar, lmatrix, lrow, lvector, matrix, fvector, ftensor4, iscalar, lmatrix, lrow, lvector,
vector) matrix, vector)
from theano.tensor.basic import DimShuffle from theano.tensor.basic import DimShuffle
from theano.tensor.subtensor import (AdvancedIncSubtensor, from theano.tensor.subtensor import (AdvancedIncSubtensor,
AdvancedIncSubtensor1, AdvancedSubtensor, AdvancedIncSubtensor1, AdvancedSubtensor,
...@@ -1349,6 +1349,7 @@ class TestAdvancedSubtensor(unittest.TestCase): ...@@ -1349,6 +1349,7 @@ class TestAdvancedSubtensor(unittest.TestCase):
self.v = fvector() self.v = fvector()
self.m = dmatrix() self.m = dmatrix()
self.t = ctensor3() self.t = ctensor3()
self.ft4 = ftensor4()
self.ix1 = lvector() # advanced 1d query self.ix1 = lvector() # advanced 1d query
self.ix12 = lvector() self.ix12 = lvector()
...@@ -1419,11 +1420,21 @@ class TestAdvancedSubtensor(unittest.TestCase): ...@@ -1419,11 +1420,21 @@ class TestAdvancedSubtensor(unittest.TestCase):
a = inc_subtensor(subt, subt) a = inc_subtensor(subt, subt)
assert a.type == self.v.type, (a.type, self.v.type) assert a.type == self.v.type, (a.type, self.v.type)
f = theano.function([self.v, self.ix2], a, allow_input_downcast=True) f = theano.function([self.v, self.ix2], a, allow_input_downcast=True,
mode=self.mode)
aval = f([.4, .9, .1], [[1, 2], aval = f([.4, .9, .1], [[1, 2],
[1, 2]]) [1, 2]])
assert numpy.allclose(aval, [.4, .9 * 3, .1 * 3]) assert numpy.allclose(aval, [.4, .9 * 3, .1 * 3])
def test_adv_subtensor_w_int_and_matrix(self):
subt = self.ft4[0, :, self.ix2, :]
f = theano.function([self.ft4, self.ix2], subt, mode=self.mode)
ft4v = numpy.random.random((2, 3, 4, 5)).astype('float32')
ix2v = numpy.asarray([[0, 1], [1, 0]])
aval = f(ft4v, ix2v)
rval = ft4v[0, :, ix2v, :]
utt.assert_allclose(rval, aval)
def test_inc_adv_subtensor_w_2vec(self): def test_inc_adv_subtensor_w_2vec(self):
if inplace_increment is None: if inplace_increment is None:
raise inplace_increment_missing raise inplace_increment_missing
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论