提交 f5207388 authored 作者: Frédéric Bastien's avatar Frédéric Bastien 提交者: GitHub

Merge pull request #4853 from lamblin/Ellipsis

Allow indexing using Ellipsis.
from __future__ import absolute_import, print_function, division from __future__ import absolute_import, division, print_function
import logging import logging
import sys import sys
import unittest import unittest
from nose.plugins.skip import SkipTest
import numpy import numpy
from nose.plugins.skip import SkipTest
from nose.tools import assert_equal
from numpy.testing import assert_array_equal
from six import StringIO from six import StringIO
from six.moves import xrange from six.moves import xrange
import theano import theano
from theano.compat import exc_message, izip, PY3
from theano.compile import DeepCopyOp
from theano import config
from theano import gof
import theano.scalar as scal import theano.scalar as scal
import theano.tensor as tensor import theano.tensor as tensor
from theano.tests import unittest_tools as utt from theano import config, gof
from theano.tensor.subtensor import (inc_subtensor, set_subtensor, from theano.compat import PY3, exc_message, izip
from theano.compile import DeepCopyOp
from theano.tensor import (MakeSlice, NotScalarConstantError, _shared,
as_tensor_variable, cscalar, ctensor3, dmatrix,
dscalar, dtensor4, dvector, fmatrix, fscalar,
fvector, iscalar, lmatrix, lrow, lvector, matrix,
vector)
from theano.tensor.basic import DimShuffle
from theano.tensor.subtensor import (AdvancedIncSubtensor,
AdvancedIncSubtensor1, AdvancedSubtensor,
AdvancedSubtensor1, IncSubtensor,
Subtensor, advanced_inc_subtensor,
advanced_inc_subtensor1, advanced_inc_subtensor1,
advanced_set_subtensor1,
advanced_inc_subtensor,
advanced_set_subtensor, advanced_set_subtensor,
Subtensor, IncSubtensor, advanced_set_subtensor1,
AdvancedSubtensor1, AdvancedSubtensor, advanced_subtensor1,
advanced_subtensor1, inplace_increment, get_canonical_form_slice, inc_subtensor,
AdvancedIncSubtensor1, inplace_increment, set_subtensor)
AdvancedIncSubtensor, from theano.tensor.tests.test_basic import inplace_func, rand, randint_ranged
get_canonical_form_slice) from theano.tests import unittest_tools as utt
from theano.tensor import (as_tensor_variable, _shared,
NotScalarConstantError,
fscalar, iscalar, dscalar, cscalar,
vector, dvector, fvector, lvector, lrow,
fmatrix, dmatrix, lmatrix, matrix,
ctensor3, dtensor4)
from theano.tensor.tests.test_basic import rand, randint_ranged, inplace_func
from theano.tests.unittest_tools import attr from theano.tests.unittest_tools import attr
if PY3: if PY3:
...@@ -65,6 +66,7 @@ class T_subtensor(unittest.TestCase, utt.TestOptimizationMixin): ...@@ -65,6 +66,7 @@ class T_subtensor(unittest.TestCase, utt.TestOptimizationMixin):
self.adv_incsub1 = adv_incsub1 self.adv_incsub1 = adv_incsub1
if mode is None: if mode is None:
mode = theano.compile.mode.get_default_mode() mode = theano.compile.mode.get_default_mode()
mode = mode.including("local_useless_subtensor")
self.mode = mode self.mode = mode
self.dtype = dtype self.dtype = dtype
self.type = type self.type = type
...@@ -97,18 +99,18 @@ class T_subtensor(unittest.TestCase, utt.TestOptimizationMixin): ...@@ -97,18 +99,18 @@ class T_subtensor(unittest.TestCase, utt.TestOptimizationMixin):
Subtensor.debug = False Subtensor.debug = False
utt.seed_rng() utt.seed_rng()
def eval_output_and_check(self, t, list=False, mode=None): def eval_output_and_check(self, t, op_type=None, mode=None, length=1):
if op_type is None:
op_type = self.sub
if mode is None: if mode is None:
mode = self.mode mode = self.mode
f = inplace_func([], t, mode=mode) f = inplace_func([], t, mode=mode)
topo = f.maker.fgraph.toposort() topo = f.maker.fgraph.toposort()
topo_ = [node for node in topo if not isinstance(node.op, topo_ = [node for node in topo if not isinstance(node.op,
self.ignore_topo)] self.ignore_topo)]
assert len(topo_) == 1 assert_equal(len(topo_), length)
if not list: if length == 1:
assert isinstance(topo_[0].op, self.sub) assert isinstance(topo_[0].op, op_type)
else:
assert isinstance(topo_[0].op, self.adv_sub1)
tval = f() tval = f()
return tval return tval
...@@ -337,6 +339,36 @@ class T_subtensor(unittest.TestCase, utt.TestOptimizationMixin): ...@@ -337,6 +339,36 @@ class T_subtensor(unittest.TestCase, utt.TestOptimizationMixin):
ret = f() ret = f()
assert ret.shape == (1, 1, 4) assert ret.shape == (1, 1, 4)
    def test_ellipsis(self):
        """Indexing with an Ellipsis (`...`) must match numpy semantics."""
        # 2 x 3 x 4 test data, cast to the dtype under test.
        numpy_n = numpy.arange(24, dtype=self.dtype).reshape((2, 3, 4))
        n = self.shared(numpy_n)
        # Each case: (expected node count in the compiled graph,
        #             expected op type of the index expression's owner,
        #             the index expression itself).
        test_cases = [
            (0, self.sub, numpy.index_exp[...]),
            (1, self.sub, numpy.index_exp[..., 1]),
            (1, self.sub, numpy.index_exp[1, ...]),
            (1, self.sub, numpy.index_exp[..., 1, 2, 3]),
            (1, self.sub, numpy.index_exp[1, ..., 2, 3]),
            (1, self.sub, numpy.index_exp[1, 2, 3, ...]),
            (3, DimShuffle, numpy.index_exp[..., [0, 2, 3]]),
            (1, DimShuffle,
             numpy.index_exp[numpy.newaxis, ...])]

        # The following test case is not supported by numpy before 1.9.
        numpy_version = [int(v) for v in numpy.version.version.split('.')[0:2]]
        if numpy_version >= [1, 9]:
            test_cases.append(
                (1, AdvancedSubtensor,
                 numpy.index_exp[..., numpy.newaxis, [1, 2]]))

        for length, op_type, slice_ in test_cases:
            # Compare against numpy's own indexing of the same data.
            numpy_tval = numpy_n[slice_]
            t = n[slice_]
            self.assertTrue(isinstance(t.owner.op, op_type))
            tval = self.eval_output_and_check(t,
                                              op_type=op_type,
                                              length=length)
            assert_equal(tval.shape, numpy_tval.shape)
            assert_array_equal(tval, numpy_tval)
def test_newaxis(self): def test_newaxis(self):
""" """
newaxis support comes from logic in the __getitem__ of TensorType newaxis support comes from logic in the __getitem__ of TensorType
...@@ -433,7 +465,7 @@ class T_subtensor(unittest.TestCase, utt.TestOptimizationMixin): ...@@ -433,7 +465,7 @@ class T_subtensor(unittest.TestCase, utt.TestOptimizationMixin):
topo_ = [node for node in topo if not isinstance(node.op, topo_ = [node for node in topo if not isinstance(node.op,
self.ignore_topo)] self.ignore_topo)]
if not self.fast_compile: if not self.fast_compile:
assert len(topo_) == 6 assert_equal(len(topo_), 6)
assert numpy.sum([isinstance(node.op, self.inc_sub) assert numpy.sum([isinstance(node.op, self.inc_sub)
for node in topo_]) == 1 for node in topo_]) == 1
assert numpy.sum([isinstance(node.op, self.sub) assert numpy.sum([isinstance(node.op, self.sub)
...@@ -467,7 +499,7 @@ class T_subtensor(unittest.TestCase, utt.TestOptimizationMixin): ...@@ -467,7 +499,7 @@ class T_subtensor(unittest.TestCase, utt.TestOptimizationMixin):
# We test again AdvancedSubtensor1 as we transfer data to the cpu. # We test again AdvancedSubtensor1 as we transfer data to the cpu.
self.assertTrue(isinstance(t.owner.op, tensor.AdvancedSubtensor1)) self.assertTrue(isinstance(t.owner.op, tensor.AdvancedSubtensor1))
val = self.eval_output_and_check(t, list=True) val = self.eval_output_and_check(t, op_type=self.adv_sub1)
if isinstance(idx, list): if isinstance(idx, list):
good = data[idx] good = data[idx]
else: else:
......
...@@ -471,6 +471,24 @@ class _tensor_py_operators(object): ...@@ -471,6 +471,24 @@ class _tensor_py_operators(object):
pass pass
elif not isinstance(args, tuple): elif not isinstance(args, tuple):
args = args, args = args,
        # Convert an Ellipsis if provided into an appropriate number of
        # slice(None).
        ellipses = [i
                    for i, index in enumerate(args)
                    if index is Ellipsis]
        if len(ellipses) > 1:
            # numpy rejects multiple Ellipses the same way.
            raise IndexError(
                "an index can only have a single Ellipsis (`...`)")
        elif len(ellipses) == 1:
            # newaxis entries do not consume a dimension of self, so they
            # must not count against self.ndim when expanding the Ellipsis.
            new_axes = sum(1
                           for index in args
                           if index is numpy.newaxis)  # numpy.newaxis is None
            ellipsis_at = ellipses[0]
            args = list(args)
            # Replace the Ellipsis in place with enough slice(None) entries
            # so the index tuple covers every remaining dimension of self.
            args[ellipsis_at: ellipsis_at + 1] = (
                [slice(None)] * (self.ndim - (len(args) - 1 - new_axes)))
# Convert python literals to theano constants # Convert python literals to theano constants
args = theano.tensor.subtensor.make_constant(args) args = theano.tensor.subtensor.make_constant(args)
# Determine if advanced indexing is needed or not # Determine if advanced indexing is needed or not
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论