提交 80b579b8 authored 作者: Sina Honari's avatar Sina Honari

adding advanced indexing to theano tensors

上级 77c4f4d1
...@@ -1429,3 +1429,24 @@ class TestInferShape(utt.InferShapeTester): ...@@ -1429,3 +1429,24 @@ class TestInferShape(utt.InferShapeTester):
self._compile_and_check([admat, advec], self._compile_and_check([admat, advec],
[set_subtensor(admat[aivec_val, bivec_val], advec)], [set_subtensor(admat[aivec_val, bivec_val], advec)],
[admat_val, advec_val], AdvancedIncSubtensor) [admat_val, advec_val], AdvancedIncSubtensor)
def test_advanced_indexing():
# tests advanced indexing in Theano for 2D and 3D tensors
rng = numpy.random.RandomState(utt.seed_rng())
a = rng.uniform(size=(3,3))
b = theano.shared(a)
i = T.iscalar()
j = T.iscalar()
z = b[[i, j], :]
f1 = theano.function([i,j],z)
cmd = f1(0,1) == a[[0,1],:]
numpy.all(cmp)
aa = rng.uniform(size=(4,2,3))
bb = theano.shared(aa)
k = T.iscalar()
z = bb[[i, j, k],:, i:k]
f2 = theano.function([i,j,k],z)
cmd = f2(0,1,2) == aa[[0,1,2],:, 0:2]
numpy.all(cmp)
...@@ -350,6 +350,11 @@ class _tensor_py_operators: ...@@ -350,6 +350,11 @@ class _tensor_py_operators:
# argument slice(1, None, None), which is much more desirable. # argument slice(1, None, None), which is much more desirable.
# __getslice__ is deprecated in python 2.6 anyway. # __getslice__ is deprecated in python 2.6 anyway.
def equal_slices(self, s1, s2):
return (s1.start == s2.start and
s1.stop == s2.stop and
s1.step == s2.step)
def __getitem__(self, args): def __getitem__(self, args):
if not isinstance(args, tuple): if not isinstance(args, tuple):
args = args, args = args,
...@@ -375,15 +380,15 @@ class _tensor_py_operators: ...@@ -375,15 +380,15 @@ class _tensor_py_operators:
if advanced: if advanced:
if (axis is not None if (axis is not None
and all(a == slice(None) for a in args[:axis]) and all(self.qual_slices(a, slice(None)) for a in args[:axis])
and all(a == slice(None) for a in args[axis + 1:]) and all(self.equal_slices(a, slice(None)) for a in args[axis + 1:])
and isinstance(args[axis], ( and isinstance(args[axis], (
numpy.ndarray, numpy.ndarray,
list, list,
TensorVariable, TensorVariable,
TensorConstant, TensorConstant,
theano.tensor.sharedvar.TensorSharedVariable))): theano.tensor.sharedvar.TensorSharedVariable))):
return self.take(arg, axis) return self.take(args[axis], axis)
else: else:
return theano.tensor.subtensor.advanced_subtensor(self, *args) return theano.tensor.subtensor.advanced_subtensor(self, *args)
else: else:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论