提交 e28dc04e authored 作者: Gijs van Tulder's avatar Gijs van Tulder

Check for indexing with Python bools (unsupported).

上级 08d90db3
...@@ -399,6 +399,14 @@ class T_subtensor(unittest.TestCase, utt.TestOptimizationMixin): ...@@ -399,6 +399,14 @@ class T_subtensor(unittest.TestCase, utt.TestOptimizationMixin):
assert_array_equal(numpy_n4[numpy_n > 2, ...], n4[n > 2, ...].eval()) assert_array_equal(numpy_n4[numpy_n > 2, ...], n4[n > 2, ...].eval())
assert_array_equal(numpy_n4[numpy_n > 2, ..., 1], n4[n > 2, ..., 1].eval()) assert_array_equal(numpy_n4[numpy_n > 2, ..., 1], n4[n > 2, ..., 1].eval())
# special cases: Python bools and bools nested in Python arrays are not supported
self.assertRaises(TypeError, n.__getitem__, (True,))
self.assertRaises(TypeError, n.__getitem__, (False,))
self.assertRaises(TypeError, n.__getitem__, (True, False))
self.assertRaises(TypeError, n.__getitem__, ([True, False]))
self.assertRaises(TypeError, n.__getitem__, ([0, 1], [0, False]))
self.assertRaises(TypeError, n.__getitem__, ([0, 1], [0, theano.shared(True)]))
def test_newaxis(self): def test_newaxis(self):
""" """
newaxis support comes from logic in the __getitem__ of TensorType newaxis support comes from logic in the __getitem__ of TensorType
......
from __future__ import absolute_import, print_function, division from __future__ import absolute_import, print_function, division
import collections
import copy import copy
import traceback as tb import traceback as tb
import warnings import warnings
...@@ -458,6 +459,18 @@ class _tensor_py_operators(object): ...@@ -458,6 +459,18 @@ class _tensor_py_operators(object):
# SLICING/INDEXING # SLICING/INDEXING
def __getitem__(self, args): def __getitem__(self, args):
def includes_bool(args_el):
if (isinstance(args_el, (np.bool_, bool)) or
(hasattr(args_el, 'dtype') and args_el.dtype == 'bool')):
return True
if (not isinstance(args_el, theano.tensor.Variable) and
isinstance(args_el, collections.Iterable)):
for el in args_el:
if includes_bool(el):
return True
return False
if (isinstance(args, list) and if (isinstance(args, list) and
any([isinstance(a, slice) for a in args])): any([isinstance(a, slice) for a in args])):
pass pass
...@@ -467,11 +480,25 @@ class _tensor_py_operators(object): ...@@ -467,11 +480,25 @@ class _tensor_py_operators(object):
# Convert boolean arrays to calls to mask.nonzero() # Convert boolean arrays to calls to mask.nonzero()
tmp_args = [] tmp_args = []
for arg in args: for arg in args:
# NumPy arrays or tensors of type bool can be converted to
# normal integer indices.
if (isinstance(arg, (np.ndarray, theano.tensor.Variable)) and if (isinstance(arg, (np.ndarray, theano.tensor.Variable)) and
hasattr(arg, 'dtype') and hasattr(arg, 'nonzero') and hasattr(arg, 'dtype') and hasattr(arg, 'nonzero') and
arg.dtype == 'bool'): arg.dtype == 'bool'):
tmp_args += arg.nonzero() tmp_args += arg.nonzero()
else: else:
# Python arrays can contain a mixture of bools and integers,
# which requires complex rules to handle all special cases.
# These rules differ slightly between NumPy versions.
# Since earlier versions of Theano did not support any boolean
# indexing, it is safe to throw an error if we encounter
# any of these difficult cases.
if includes_bool(arg):
raise TypeError('TensorType does not support Python bools '
'for indexing, such as tensor[[True, False]]. '
'To use a boolean mask, convert the mask to '
'a NumPy array first, e.g., '
'tensor[numpy.array([True, False])].')
tmp_args.append(arg) tmp_args.append(arg)
args = tuple(tmp_args) args = tuple(tmp_args)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论