提交 f5207388 authored 作者: Frédéric Bastien's avatar Frédéric Bastien 提交者: GitHub

Merge pull request #4853 from lamblin/Ellipsis

Allow indexing using Ellipsis.
from __future__ import absolute_import, print_function, division
from __future__ import absolute_import, division, print_function
import logging
import sys
import unittest
from nose.plugins.skip import SkipTest
import numpy
from nose.plugins.skip import SkipTest
from nose.tools import assert_equal
from numpy.testing import assert_array_equal
from six import StringIO
from six.moves import xrange
import theano
from theano.compat import exc_message, izip, PY3
from theano.compile import DeepCopyOp
from theano import config
from theano import gof
import theano.scalar as scal
import theano.tensor as tensor
from theano.tests import unittest_tools as utt
from theano.tensor.subtensor import (inc_subtensor, set_subtensor,
from theano import config, gof
from theano.compat import PY3, exc_message, izip
from theano.compile import DeepCopyOp
from theano.tensor import (MakeSlice, NotScalarConstantError, _shared,
as_tensor_variable, cscalar, ctensor3, dmatrix,
dscalar, dtensor4, dvector, fmatrix, fscalar,
fvector, iscalar, lmatrix, lrow, lvector, matrix,
vector)
from theano.tensor.basic import DimShuffle
from theano.tensor.subtensor import (AdvancedIncSubtensor,
AdvancedIncSubtensor1, AdvancedSubtensor,
AdvancedSubtensor1, IncSubtensor,
Subtensor, advanced_inc_subtensor,
advanced_inc_subtensor1,
advanced_set_subtensor1,
advanced_inc_subtensor,
advanced_set_subtensor,
Subtensor, IncSubtensor,
AdvancedSubtensor1, AdvancedSubtensor,
advanced_subtensor1, inplace_increment,
AdvancedIncSubtensor1,
AdvancedIncSubtensor,
get_canonical_form_slice)
from theano.tensor import (as_tensor_variable, _shared,
NotScalarConstantError,
fscalar, iscalar, dscalar, cscalar,
vector, dvector, fvector, lvector, lrow,
fmatrix, dmatrix, lmatrix, matrix,
ctensor3, dtensor4)
from theano.tensor.tests.test_basic import rand, randint_ranged, inplace_func
advanced_set_subtensor1,
advanced_subtensor1,
get_canonical_form_slice, inc_subtensor,
inplace_increment, set_subtensor)
from theano.tensor.tests.test_basic import inplace_func, rand, randint_ranged
from theano.tests import unittest_tools as utt
from theano.tests.unittest_tools import attr
if PY3:
......@@ -65,6 +66,7 @@ class T_subtensor(unittest.TestCase, utt.TestOptimizationMixin):
self.adv_incsub1 = adv_incsub1
if mode is None:
mode = theano.compile.mode.get_default_mode()
mode = mode.including("local_useless_subtensor")
self.mode = mode
self.dtype = dtype
self.type = type
......@@ -97,18 +99,18 @@ class T_subtensor(unittest.TestCase, utt.TestOptimizationMixin):
Subtensor.debug = False
utt.seed_rng()
def eval_output_and_check(self, t, list=False, mode=None):
def eval_output_and_check(self, t, op_type=None, mode=None, length=1):
if op_type is None:
op_type = self.sub
if mode is None:
mode = self.mode
f = inplace_func([], t, mode=mode)
topo = f.maker.fgraph.toposort()
topo_ = [node for node in topo if not isinstance(node.op,
self.ignore_topo)]
assert len(topo_) == 1
if not list:
assert isinstance(topo_[0].op, self.sub)
else:
assert isinstance(topo_[0].op, self.adv_sub1)
assert_equal(len(topo_), length)
if length == 1:
assert isinstance(topo_[0].op, op_type)
tval = f()
return tval
......@@ -337,6 +339,36 @@ class T_subtensor(unittest.TestCase, utt.TestOptimizationMixin):
ret = f()
assert ret.shape == (1, 1, 4)
def test_ellipsis(self):
numpy_n = numpy.arange(24, dtype=self.dtype).reshape((2, 3, 4))
n = self.shared(numpy_n)
test_cases = [
(0, self.sub, numpy.index_exp[...]),
(1, self.sub, numpy.index_exp[..., 1]),
(1, self.sub, numpy.index_exp[1, ...]),
(1, self.sub, numpy.index_exp[..., 1, 2, 3]),
(1, self.sub, numpy.index_exp[1, ..., 2, 3]),
(1, self.sub, numpy.index_exp[1, 2, 3, ...]),
(3, DimShuffle, numpy.index_exp[..., [0, 2, 3]]),
(1, DimShuffle,
numpy.index_exp[numpy.newaxis, ...])]
# The following test case is not supported by numpy before 1.9
numpy_version = [int(v) for v in numpy.version.version.split('.')[0:2]]
if numpy_version >= [1, 9]:
test_cases.append(
(1, AdvancedSubtensor,
numpy.index_exp[..., numpy.newaxis, [1, 2]]))
for length, op_type, slice_ in test_cases:
numpy_tval = numpy_n[slice_]
t = n[slice_]
self.assertTrue(isinstance(t.owner.op, op_type))
tval = self.eval_output_and_check(t,
op_type=op_type,
length=length)
assert_equal(tval.shape, numpy_tval.shape)
assert_array_equal(tval, numpy_tval)
def test_newaxis(self):
"""
newaxis support comes from logic in the __getitem__ of TensorType
......@@ -433,7 +465,7 @@ class T_subtensor(unittest.TestCase, utt.TestOptimizationMixin):
topo_ = [node for node in topo if not isinstance(node.op,
self.ignore_topo)]
if not self.fast_compile:
assert len(topo_) == 6
assert_equal(len(topo_), 6)
assert numpy.sum([isinstance(node.op, self.inc_sub)
for node in topo_]) == 1
assert numpy.sum([isinstance(node.op, self.sub)
......@@ -467,7 +499,7 @@ class T_subtensor(unittest.TestCase, utt.TestOptimizationMixin):
# We test again AdvancedSubtensor1 as we transfer data to the cpu.
self.assertTrue(isinstance(t.owner.op, tensor.AdvancedSubtensor1))
val = self.eval_output_and_check(t, list=True)
val = self.eval_output_and_check(t, op_type=self.adv_sub1)
if isinstance(idx, list):
good = data[idx]
else:
......
......@@ -471,6 +471,24 @@ class _tensor_py_operators(object):
pass
elif not isinstance(args, tuple):
args = args,
# Convert an Ellipsis if provided into an appropriate number of
# slice(None).
ellipses = [i
for i, index in enumerate(args)
if index is Ellipsis]
if len(ellipses) > 1:
raise IndexError(
"an index can only have a single Ellipsis (`...`)")
elif len(ellipses) == 1:
new_axes = sum(1
for index in args
if index is numpy.newaxis) # numpy.newaxis is None
ellipsis_at = ellipses[0]
args = list(args)
args[ellipsis_at: ellipsis_at + 1] = (
[slice(None)] * (self.ndim - (len(args) - 1 - new_axes)))
# Convert python literals to theano constants
args = theano.tensor.subtensor.make_constant(args)
# Determine if advanced indexing is needed or not
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论