提交 68162534 authored 作者: Gijs van Tulder's avatar Gijs van Tulder

Little fixes to improve boolean mask indexing.

上级 1fbb77b6
......@@ -336,6 +336,11 @@ class Subtensor(Op):
theano.tensor.wscalar, theano.tensor.bscalar]
invalid_tensor_types = [theano.tensor.fscalar, theano.tensor.dscalar,
theano.tensor.cscalar, theano.tensor.zscalar]
if (isinstance(entry, (np.ndarray, theano.tensor.Variable)) and
hasattr(entry, 'dtype') and entry.dtype == 'bool'):
raise AdvancedIndexingError(Subtensor.e_indextype, entry)
if (isinstance(entry, gof.Variable) and
(entry.type in invalid_scal_types or
entry.type in invalid_tensor_types)):
......@@ -2049,8 +2054,8 @@ def as_index_variable(idx):
if isinstance(idx, gof.Variable) and isinstance(idx.type, NoneTypeT):
return idx
idx = theano.tensor.as_tensor_variable(idx)
if idx.type.dtype not in theano.tensor.integer_dtypes and idx.type.dtype != 'bool':
raise TypeError('index must be integers')
if idx.type.dtype not in theano.tensor.discrete_dtypes:
raise TypeError('index must be integers or a boolean mask')
return idx
......
......@@ -385,7 +385,7 @@ class T_subtensor(unittest.TestCase, utt.TestOptimizationMixin):
assert_array_equal(numpy_n[np.newaxis, 1:, mask], n[np.newaxis, 1:, mask].eval())
assert_array_equal(numpy_inc_subtensor(numpy_n, [0, mask], 1),
inc_subtensor(n[0, mask], 1).eval())
assert_array_equal(numpy_inc_subtensor(numpy_n, [Ellipsis, mask], 1),
assert_array_equal(numpy_inc_subtensor(numpy_n, [slice(None), mask], 1),
inc_subtensor(n[:, mask], 1).eval())
# indexing with a boolean ndarray
......
......@@ -534,12 +534,7 @@ class _tensor_py_operators(object):
axis = None
for i, arg in enumerate(args):
try:
if (isinstance(arg, (np.ndarray, theano.tensor.Variable)) and
hasattr(arg, 'dtype') and arg.dtype == 'bool'):
advanced = True
axis = None
break
elif arg is not np.newaxis:
if arg is not np.newaxis:
theano.tensor.subtensor.Subtensor.convert(arg)
except theano.tensor.subtensor.AdvancedIndexingError:
if advanced:
......@@ -555,6 +550,7 @@ class _tensor_py_operators(object):
equal_slices(a, slice(None)) for a in args[:axis]) and
all(isinstance(a, slice) and
equal_slices(a, slice(None)) for a in args[axis + 1:]) and
(not hasattr(args[axis], 'dtype') or args[axis].dtype != 'bool') and
isinstance(args[axis],
(np.ndarray, list,
TensorVariable, TensorConstant,
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论