提交 bef883eb authored 作者: Pascal Lamblin's avatar Pascal Lamblin 提交者: GitHub

Merge pull request #4750 from SinaHonari/issue4595

Raising error for Indexing subtensor with a Boolean mask
...@@ -1131,18 +1131,12 @@ Theano indexing with a "mask" (incorrect approach): ...@@ -1131,18 +1131,12 @@ Theano indexing with a "mask" (incorrect approach):
.. doctest:: indexing .. doctest:: indexing
>>> t = theano.tensor.arange(9).reshape((3,3)) >>> t = theano.tensor.arange(9).reshape((3,3))
>>> t[t > 4].eval() # an array with shape (3, 3, 3) >>> t[t > 4].eval() # an array with shape (3, 3, 3) # doctest: +ELLIPSIS
array([[[0, 1, 2], Traceback (most recent call last):
[0, 1, 2], ...
[0, 1, 2]], TypeError: TensorType does not support boolean mask for indexing such as tensor[x==0].
<BLANKLINE> Instead you can use non_zeros() such as tensor[(x == 0).nonzeros()].
[[0, 1, 2], If you are indexing on purpose with an int8, please cast it to int16.
[0, 1, 2],
[3, 4, 5]],
<BLANKLINE>
[[3, 4, 5],
[3, 4, 5],
[3, 4, 5]]])
Getting a Theano result like NumPy: Getting a Theano result like NumPy:
......
...@@ -302,7 +302,7 @@ class NumpyAutocaster(object): ...@@ -302,7 +302,7 @@ class NumpyAutocaster(object):
# returns either an exact x_==x, or the last cast x_ # returns either an exact x_==x, or the last cast x_
return x_ return x_
autocast_int = NumpyAutocaster(('int8', 'int16', 'int32', 'int64')) autocast_int = NumpyAutocaster(('int16', 'int32', 'int64'))
autocast_float = NumpyAutocaster(('float16', 'float32', 'float64')) autocast_float = NumpyAutocaster(('float16', 'float32', 'float64'))
......
...@@ -696,12 +696,14 @@ def bincount(x, weights=None, minlength=None, assert_nonneg=False): ...@@ -696,12 +696,14 @@ def bincount(x, weights=None, minlength=None, assert_nonneg=False):
if minlength is not None: if minlength is not None:
max_value = theano.tensor.maximum(max_value, minlength) max_value = theano.tensor.maximum(max_value, minlength)
# Note: we do not use inc_subtensor(out[x], ...) in the following lines,
# since out[x] raises an exception if the indices (x) are int8.
if weights is None: if weights is None:
out = theano.tensor.zeros([max_value], dtype=x.dtype) out = theano.tensor.zeros([max_value], dtype=x.dtype)
out = theano.tensor.inc_subtensor(out[x], 1) out = theano.tensor.advanced_inc_subtensor1(out, 1, x)
else: else:
out = theano.tensor.zeros([max_value], dtype=weights.dtype) out = theano.tensor.zeros([max_value], dtype=weights.dtype)
out = theano.tensor.inc_subtensor(out[x], weights) out = theano.tensor.advanced_inc_subtensor1(out, weights, x)
return out return out
......
...@@ -2069,7 +2069,7 @@ def local_useless_elemwise(node): ...@@ -2069,7 +2069,7 @@ def local_useless_elemwise(node):
return [T.zeros_like(node.inputs[1], dtype=dtype, return [T.zeros_like(node.inputs[1], dtype=dtype,
opt=True)] opt=True)]
else: else:
return [node.inputs[1]] return [node.inputs[1].astype(node.outputs[0].dtype)]
if isinstance(node.inputs[1], T.TensorConstant): if isinstance(node.inputs[1], T.TensorConstant):
const_val = T.extract_constant(node.inputs[1], only_process_constants=True) const_val = T.extract_constant(node.inputs[1], only_process_constants=True)
...@@ -2078,7 +2078,7 @@ def local_useless_elemwise(node): ...@@ -2078,7 +2078,7 @@ def local_useless_elemwise(node):
return [T.zeros_like(node.inputs[0], dtype=dtype, return [T.zeros_like(node.inputs[0], dtype=dtype,
opt=True)] opt=True)]
else: else:
return [node.inputs[0]] return [node.inputs[0].astype(node.outputs[0].dtype)]
elif (isinstance(node.op.scalar_op, scalar.OR) and elif (isinstance(node.op.scalar_op, scalar.OR) and
len(node.inputs) == 2): len(node.inputs) == 2):
...@@ -2087,7 +2087,7 @@ def local_useless_elemwise(node): ...@@ -2087,7 +2087,7 @@ def local_useless_elemwise(node):
const_val = T.extract_constant(node.inputs[0], only_process_constants=True) const_val = T.extract_constant(node.inputs[0], only_process_constants=True)
if not isinstance(const_val, Variable): if not isinstance(const_val, Variable):
if const_val == 0: if const_val == 0:
return [node.inputs[1]] return [node.inputs[1].astype(node.outputs[0].dtype)]
else: else:
return [T.ones_like(node.inputs[1], dtype=dtype, return [T.ones_like(node.inputs[1], dtype=dtype,
opt=True)] opt=True)]
...@@ -2096,7 +2096,7 @@ def local_useless_elemwise(node): ...@@ -2096,7 +2096,7 @@ def local_useless_elemwise(node):
const_val = T.extract_constant(node.inputs[1], only_process_constants=True) const_val = T.extract_constant(node.inputs[1], only_process_constants=True)
if not isinstance(const_val, Variable): if not isinstance(const_val, Variable):
if const_val == 0: if const_val == 0:
return [node.inputs[0]] return [node.inputs[0].astype(node.outputs[0].dtype)]
else: else:
return [T.ones_like(node.inputs[0], dtype=dtype, return [T.ones_like(node.inputs[0], dtype=dtype,
opt=True)] opt=True)]
......
...@@ -4996,7 +4996,7 @@ class T_scalarfromtensor(unittest.TestCase): ...@@ -4996,7 +4996,7 @@ class T_scalarfromtensor(unittest.TestCase):
self.assertTrue(v == 56, v) self.assertTrue(v == 56, v)
if config.cast_policy == 'custom': if config.cast_policy == 'custom':
self.assertTrue(isinstance(v, numpy.int8)) self.assertTrue(isinstance(v, numpy.int16))
elif config.cast_policy in ('numpy', 'numpy+floatX'): elif config.cast_policy in ('numpy', 'numpy+floatX'):
self.assertTrue(isinstance( self.assertTrue(isinstance(
v, getattr(numpy, str(numpy.asarray(56).dtype)))) v, getattr(numpy, str(numpy.asarray(56).dtype))))
...@@ -7047,7 +7047,7 @@ class T_get_scalar_constant_value(unittest.TestCase): ...@@ -7047,7 +7047,7 @@ class T_get_scalar_constant_value(unittest.TestCase):
assert get_scalar_constant_value(mv[0]) == 1 assert get_scalar_constant_value(mv[0]) == 1
assert get_scalar_constant_value(mv[1]) == 2 assert get_scalar_constant_value(mv[1]) == 2
assert get_scalar_constant_value(mv[2]) == 3 assert get_scalar_constant_value(mv[2]) == 3
assert get_scalar_constant_value(mv[numpy.int8(0)]) == 1 assert get_scalar_constant_value(mv[numpy.int32(0)]) == 1
assert get_scalar_constant_value(mv[numpy.int64(1)]) == 2 assert get_scalar_constant_value(mv[numpy.int64(1)]) == 2
assert get_scalar_constant_value(mv[numpy.uint(2)]) == 3 assert get_scalar_constant_value(mv[numpy.uint(2)]) == 3
t = theano.scalar.Scalar('int64') t = theano.scalar.Scalar('int64')
......
...@@ -3634,33 +3634,39 @@ class Test_local_useless_elemwise_comparison(unittest.TestCase): ...@@ -3634,33 +3634,39 @@ class Test_local_useless_elemwise_comparison(unittest.TestCase):
x = T.scalar('x', dtype='int8') x = T.scalar('x', dtype='int8')
f = theano.function([x], T.and_(x, 0), mode=mode) for zero, one in [(numpy.int8(0), numpy.int8(1)), (0, 1)]:
self.assert_eqs_const(f, 0) f = theano.function([x], T.and_(x, zero), mode=mode)
self.assert_eqs_const(f, 0)
f = theano.function([x], T.and_(0, x), mode=mode) f = theano.function([x], T.and_(zero, x), mode=mode)
self.assert_eqs_const(f, 0) self.assert_eqs_const(f, 0)
f = theano.function([x], T.and_(x, 1), mode=mode) f = theano.function([x], T.and_(x, one), mode=mode)
self.assert_identity(f) if f.outputs[0].variable.dtype == x.dtype:
self.assert_identity(f)
f = theano.function([x], T.and_(1, x), mode=mode) f = theano.function([x], T.and_(one, x), mode=mode)
self.assert_identity(f) if f.outputs[0].variable.dtype == x.dtype:
self.assert_identity(f)
def test_or(self): def test_or(self):
mode = theano.compile.get_default_mode().including('canonicalize') mode = theano.compile.get_default_mode().including('canonicalize')
x = T.scalar('x', dtype='int8') x = T.scalar('x', dtype='int8')
f = theano.function([x], T.or_(x, 1), mode=mode) for zero, one in [(numpy.int8(0), numpy.int8(1)), (0, 1)]:
self.assert_eqs_const(f, 1) f = theano.function([x], T.or_(x, one), mode=mode)
self.assert_eqs_const(f, 1)
f = theano.function([x], T.or_(1, x), mode=mode) f = theano.function([x], T.or_(one, x), mode=mode)
self.assert_eqs_const(f, 1) self.assert_eqs_const(f, 1)
f = theano.function([x], T.or_(x, 0), mode=mode) f = theano.function([x], T.or_(x, zero), mode=mode)
self.assert_identity(f) if f.outputs[0].variable.dtype == x.dtype:
self.assert_identity(f)
f = theano.function([x], T.or_(0, x), mode=mode) f = theano.function([x], T.or_(zero, x), mode=mode)
self.assert_identity(f) if f.outputs[0].variable.dtype == x.dtype:
self.assert_identity(f)
def test_xor(self): def test_xor(self):
mode = theano.compile.get_default_mode().including('canonicalize') mode = theano.compile.get_default_mode().including('canonicalize')
......
...@@ -418,7 +418,7 @@ class T_subtensor(unittest.TestCase, utt.TestOptimizationMixin): ...@@ -418,7 +418,7 @@ class T_subtensor(unittest.TestCase, utt.TestOptimizationMixin):
subi = 0 subi = 0
data = numpy.asarray(rand(2, 3), dtype=self.dtype) data = numpy.asarray(rand(2, 3), dtype=self.dtype)
n = self.shared(data) n = self.shared(data)
z = scal.constant(subi) z = scal.constant(subi).astype('int32')
t = n[z:, z] t = n[z:, z]
gn = theano.tensor.grad(theano.tensor.sum(theano.tensor.exp(t)), n) gn = theano.tensor.grad(theano.tensor.sum(theano.tensor.exp(t)), n)
......
from __future__ import absolute_import, print_function, division from __future__ import absolute_import, print_function, division
import collections
import copy import copy
import traceback as tb import traceback as tb
import warnings import warnings
...@@ -466,6 +467,27 @@ class _tensor_py_operators(object): ...@@ -466,6 +467,27 @@ class _tensor_py_operators(object):
# SLICING/INDEXING # SLICING/INDEXING
def __getitem__(self, args): def __getitem__(self, args):
def check_bool(args_el):
try:
if isinstance(args_el, (numpy.bool_, bool)) or \
args_el.dtype == 'int8' or args_el.dtype == 'uint8':
raise TypeError(('TensorType does not support boolean '
'mask for indexing such as tensor[x==0]. '
'Instead you can use non_zeros() such as '
'tensor[(x == 0).nonzeros()]. '
'If you are indexing on purpose with an '
'int8, please cast it to int16.'))
except AttributeError:
pass
if not isinstance(args_el, theano.tensor.Variable) and \
isinstance(args_el, collections.Iterable):
for el in args_el:
check_bool(el)
check_bool(args)
if (isinstance(args, list) and if (isinstance(args, list) and
any([isinstance(a, slice) for a in args])): any([isinstance(a, slice) for a in args])):
pass pass
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论