提交 9af84336 authored 作者: nouiz's avatar nouiz

Merge pull request #698 from jaberg/subtensor_newaxis

adding support for numpy.newaxis in basic indexing
......@@ -1543,7 +1543,7 @@ class _tensor_py_operators:
advanced = False
for arg in args:
try:
Subtensor.convert(arg)
arg == numpy.newaxis or Subtensor.convert(arg)
except AdvancedIndexingError:
advanced = True
break
......@@ -1559,8 +1559,30 @@ class _tensor_py_operators:
else:
return AdvancedSubtensor()(self, *args)
else:
return Subtensor(args)(self, *Subtensor.collapse(args,
lambda entry: isinstance(entry, Variable)))
if numpy.newaxis in args:
# None (aka np.newaxis) in numpy indexing means to add a
# broadcastable dimension, which theano traditionally did with
# the dimshuffle op. The following code converts numpy-style
# indexing on self to traditional [read: implemented] theano
# indexing on a dimshuffled view of self.
counter = 0
pattern = []
new_args = []
for arg in args:
if arg == numpy.newaxis:
pattern.append('x')
new_args.append(slice(None, None, None))
else:
pattern.append(counter)
counter += 1
new_args.append(arg)
view = self.dimshuffle(pattern)
rval = view.__getitem__(tuple(new_args))
return rval
else:
return Subtensor(args)(self, *Subtensor.collapse(args,
lambda entry: isinstance(entry, Variable)))
#COPYING
def copy(self):
......
......@@ -2531,6 +2531,49 @@ class T_subtensor(unittest.TestCase, utt.TestOptimizationMixin):
self.assertTrue(tval.shape == ())
self.assertTrue(numpy.all(tval == 0))
def test_newaxis(self):
"""
newaxis support comes from logic in the __getitem__ of TensorType
Variables, which currently inserts dimshuffle to get the right number
of dimensions, and adjusts the slice tuple accordingly.
So testing is done via square-bracket notation rather than direct
interaction with the Subtensor Op (which has no support of its own for
newaxis).
"""
newaxis = numpy.newaxis
n = self.shared(numpy.asarray(range(24), dtype=self.dtype).reshape((2,3,4)))
assert n.ndim == 3
n4 = n[newaxis, :, :, :]
assert n4.broadcastable == (True, False, False, False), n4
n4 = n[:, newaxis, :, :]
assert n4.broadcastable == (False, True, False, False), n4
n4 = n[:, :, newaxis, :]
assert n4.broadcastable == (False, False, True, False), n4
n4 = n[:, :, :, newaxis]
assert n4.broadcastable == (False, False, False, True), n4
n3 = n.flatten()[newaxis, :, newaxis]
assert n3.broadcastable == (True, False, True), n3
s = cscalar()
s1 = s[newaxis]
assert s1.broadcastable == (True,), s1
vs1, vn3, vn4 = theano.function([s], [s1, n3, n4])(-2.0)
assert numpy.all(vs1 == [-2.0])
assert numpy.all(vn3
== numpy.arange(24)[newaxis, :, newaxis])
assert numpy.all(vn4
== numpy.arange(24).reshape((2, 3, 4))[:, :, :, newaxis])
def test_grad_1d(self):
subi = 0
data = numpy.asarray(rand(2,3), dtype=self.dtype)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论