提交 a01940c5 authored 作者: Gijs van Tulder's avatar Gijs van Tulder

Add support for boolean indexing.

上级 54c49ef0
...@@ -362,6 +362,31 @@ class T_subtensor(unittest.TestCase, utt.TestOptimizationMixin): ...@@ -362,6 +362,31 @@ class T_subtensor(unittest.TestCase, utt.TestOptimizationMixin):
assert_equal(tval.shape, numpy_tval.shape) assert_equal(tval.shape, numpy_tval.shape)
assert_array_equal(tval, numpy_tval) assert_array_equal(tval, numpy_tval)
def test_boolean(self):
numpy_n = np.arange(6, dtype=self.dtype).reshape((2, 3))
n = self.shared(numpy_n)
assert_array_equal(numpy_n[numpy_n > 2], n[n > 2].eval())
assert_array_equal(numpy_n[[0], numpy_n[0] > 2], n[[0], n[0] > 2].eval())
assert_array_equal(numpy_n[[1], numpy_n[0] > 2], n[[1], n[0] > 2].eval())
mask = np.array([True, False])
assert_array_equal(numpy_n[mask], n[mask].eval())
mask = np.array([True, False, True])
assert_array_equal(numpy_n[0, mask], n[0, mask].eval())
assert_array_equal(numpy_n[:, mask], n[:, mask].eval())
assert_array_equal(numpy_n[:, mask], n[:, self.shared(mask)].eval())
mask = np.array([[True, False, True], [False, False, True]])
assert_array_equal(numpy_n[mask], n[mask].eval())
assert_array_equal(numpy_n[mask], n[self.shared(mask)].eval())
numpy_n4 = np.arange(48, dtype=self.dtype).reshape((2, 3, 4, 2))
n4 = self.shared(numpy_n4)
# with ellipsis
assert_array_equal(numpy_n4[numpy_n > 2, ...], n4[n > 2, ...].eval())
assert_array_equal(numpy_n4[numpy_n > 2, ..., 1], n4[n > 2, ..., 1].eval())
def test_newaxis(self): def test_newaxis(self):
""" """
newaxis support comes from logic in the __getitem__ of TensorType newaxis support comes from logic in the __getitem__ of TensorType
......
from __future__ import absolute_import, print_function, division from __future__ import absolute_import, print_function, division
import collections
import copy import copy
import traceback as tb import traceback as tb
import warnings import warnings
...@@ -459,31 +458,23 @@ class _tensor_py_operators(object): ...@@ -459,31 +458,23 @@ class _tensor_py_operators(object):
# SLICING/INDEXING # SLICING/INDEXING
def __getitem__(self, args): def __getitem__(self, args):
def check_bool(args_el):
try:
if (isinstance(args_el, (np.bool_, bool)) or
args_el.dtype == 'bool'):
raise TypeError('TensorType does not support boolean '
'mask for indexing such as tensor[x==0]. '
'Instead you can use non_zeros() such as '
'tensor[(x == 0).nonzeros()]. ')
except AttributeError:
pass
if (not isinstance(args_el, theano.tensor.Variable) and
isinstance(args_el, collections.Iterable)):
for el in args_el:
check_bool(el)
check_bool(args)
if (isinstance(args, list) and if (isinstance(args, list) and
any([isinstance(a, slice) for a in args])): any([isinstance(a, slice) for a in args])):
pass pass
elif not isinstance(args, tuple): elif not isinstance(args, tuple):
args = args, args = args,
# Convert boolean arrays to calls to mask.nonzero()
tmp_args = []
for arg in args:
if (isinstance(arg, (np.ndarray, TensorVariable, TensorConstant,
theano.tensor.sharedvar.TensorSharedVariable)) and
arg.dtype == 'bool'):
tmp_args += arg.nonzero()
else:
tmp_args.append(arg)
args = tuple(tmp_args)
# Convert an Ellipsis if provided into an appropriate number of # Convert an Ellipsis if provided into an appropriate number of
# slice(None). # slice(None).
ellipses = [i ellipses = [i
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论