提交 94d2c024 authored 作者: Sina Honari's avatar Sina Honari

fixing errors related to int8 type

上级 58eecf0e
...@@ -1121,8 +1121,7 @@ Theano indexing with a "mask" (incorrect approach): ...@@ -1121,8 +1121,7 @@ Theano indexing with a "mask" (incorrect approach):
.. doctest:: indexing .. doctest:: indexing
>>> t = theano.tensor.arange(9).reshape((3,3)) >>> t = theano.tensor.arange(9).reshape((3,3))
>>> t[t > 4].eval() # an array with shape (3, 3, 3) >>> t[t > 4].eval() # an array with shape (3, 3, 3) # doctest: +ELLIPSIS
Traceback (most recent call last): Traceback (most recent call last):
... ...
TypeError: TensorType does not support boolean mask for indexing such as tensor[x==0]. If you are indexing on purpose with an int8, please cast it to int32 TypeError: TensorType does not support boolean mask for indexing such as tensor[x==0]. If you are indexing on purpose with an int8, please cast it to int32
......
...@@ -2053,7 +2053,7 @@ def local_useless_elemwise(node): ...@@ -2053,7 +2053,7 @@ def local_useless_elemwise(node):
if const_val == 0: if const_val == 0:
return zeros_like(node, 1) return zeros_like(node, 1)
else: else:
return [node.inputs[1]] return [node.inputs[1].astype(node.outputs[0].dtype)]
if isinstance(node.inputs[1], T.TensorConstant): if isinstance(node.inputs[1], T.TensorConstant):
const_val = T.extract_constant(node.inputs[1], only_process_constants=True) const_val = T.extract_constant(node.inputs[1], only_process_constants=True)
...@@ -2061,7 +2061,7 @@ def local_useless_elemwise(node): ...@@ -2061,7 +2061,7 @@ def local_useless_elemwise(node):
if const_val == 0: if const_val == 0:
return zeros_like(node, 0) return zeros_like(node, 0)
else: else:
return [node.inputs[0]] return [node.inputs[0].astype(node.outputs[0].dtype)]
elif (isinstance(node.op.scalar_op, scalar.OR) and elif (isinstance(node.op.scalar_op, scalar.OR) and
len(node.inputs) == 2): len(node.inputs) == 2):
...@@ -2070,7 +2070,7 @@ def local_useless_elemwise(node): ...@@ -2070,7 +2070,7 @@ def local_useless_elemwise(node):
const_val = T.extract_constant(node.inputs[0], only_process_constants=True) const_val = T.extract_constant(node.inputs[0], only_process_constants=True)
if not isinstance(const_val, Variable): if not isinstance(const_val, Variable):
if const_val == 0: if const_val == 0:
return [node.inputs[1]] return [node.inputs[1].astype(node.outputs[0].dtype)]
else: else:
return ones_like(node, 1) return ones_like(node, 1)
...@@ -2078,7 +2078,7 @@ def local_useless_elemwise(node): ...@@ -2078,7 +2078,7 @@ def local_useless_elemwise(node):
const_val = T.extract_constant(node.inputs[1], only_process_constants=True) const_val = T.extract_constant(node.inputs[1], only_process_constants=True)
if not isinstance(const_val, Variable): if not isinstance(const_val, Variable):
if const_val == 0: if const_val == 0:
return [node.inputs[0]] return [node.inputs[0].astype(node.outputs[0].dtype)]
else: else:
return ones_like(node, 0) return ones_like(node, 0)
......
...@@ -4978,7 +4978,7 @@ class T_scalarfromtensor(unittest.TestCase): ...@@ -4978,7 +4978,7 @@ class T_scalarfromtensor(unittest.TestCase):
self.assertTrue(v == 56, v) self.assertTrue(v == 56, v)
if config.cast_policy == 'custom': if config.cast_policy == 'custom':
self.assertTrue(isinstance(v, numpy.int8)) self.assertTrue(isinstance(v, numpy.int16))
elif config.cast_policy in ('numpy', 'numpy+floatX'): elif config.cast_policy in ('numpy', 'numpy+floatX'):
self.assertTrue(isinstance( self.assertTrue(isinstance(
v, getattr(numpy, str(numpy.asarray(56).dtype)))) v, getattr(numpy, str(numpy.asarray(56).dtype))))
......
...@@ -3610,33 +3610,39 @@ class Test_local_useless_elemwise_comparison(unittest.TestCase): ...@@ -3610,33 +3610,39 @@ class Test_local_useless_elemwise_comparison(unittest.TestCase):
x = T.scalar('x', dtype='int8') x = T.scalar('x', dtype='int8')
f = theano.function([x], T.and_(x, 0), mode=mode) for zero, one in [(numpy.int8(0), numpy.int8(1)), (0, 1)]:
self.assert_eqs_const(f, 0) f = theano.function([x], T.and_(x, zero), mode=mode)
self.assert_eqs_const(f, 0)
f = theano.function([x], T.and_(0, x), mode=mode) f = theano.function([x], T.and_(zero, x), mode=mode)
self.assert_eqs_const(f, 0) self.assert_eqs_const(f, 0)
f = theano.function([x], T.and_(x, 1), mode=mode) f = theano.function([x], T.and_(x, one), mode=mode)
self.assert_identity(f) if f.outputs[0].variable.dtype == x.dtype:
self.assert_identity(f)
f = theano.function([x], T.and_(1, x), mode=mode) f = theano.function([x], T.and_(one, x), mode=mode)
self.assert_identity(f) if f.outputs[0].variable.dtype == x.dtype:
self.assert_identity(f)
def test_or(self): def test_or(self):
mode = theano.compile.get_default_mode().including('canonicalize') mode = theano.compile.get_default_mode().including('canonicalize')
x = T.scalar('x', dtype='int8') x = T.scalar('x', dtype='int8')
f = theano.function([x], T.or_(x, 1), mode=mode) for zero, one in [(numpy.int8(0), numpy.int8(1)), (0, 1)]:
self.assert_eqs_const(f, 1) f = theano.function([x], T.or_(x, one), mode=mode)
self.assert_eqs_const(f, 1)
f = theano.function([x], T.or_(1, x), mode=mode) f = theano.function([x], T.or_(one, x), mode=mode)
self.assert_eqs_const(f, 1) self.assert_eqs_const(f, 1)
f = theano.function([x], T.or_(x, 0), mode=mode) f = theano.function([x], T.or_(x, zero), mode=mode)
self.assert_identity(f) if f.outputs[0].variable.dtype == x.dtype:
self.assert_identity(f)
f = theano.function([x], T.or_(0, x), mode=mode) f = theano.function([x], T.or_(zero, x), mode=mode)
self.assert_identity(f) if f.outputs[0].variable.dtype == x.dtype:
self.assert_identity(f)
def test_xor(self): def test_xor(self):
mode = theano.compile.get_default_mode().including('canonicalize') mode = theano.compile.get_default_mode().including('canonicalize')
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论