提交 94d2c024 authored 作者: Sina Honari's avatar Sina Honari

fixing errors related to int8 type

上级 58eecf0e
......@@ -1121,8 +1121,7 @@ Theano indexing with a "mask" (incorrect approach):
.. doctest:: indexing
>>> t = theano.tensor.arange(9).reshape((3,3))
>>> t[t > 4].eval() # an array with shape (3, 3, 3)
>>> t[t > 4].eval() # an array with shape (3, 3, 3) # doctest: +ELLIPSIS
Traceback (most recent call last):
...
TypeError: TensorType does not support boolean mask for indexing such as tensor[x==0]. If you are indexing on purpose with an int8, please cast it to int32
......
......@@ -2053,7 +2053,7 @@ def local_useless_elemwise(node):
if const_val == 0:
return zeros_like(node, 1)
else:
return [node.inputs[1]]
return [node.inputs[1].astype(node.outputs[0].dtype)]
if isinstance(node.inputs[1], T.TensorConstant):
const_val = T.extract_constant(node.inputs[1], only_process_constants=True)
......@@ -2061,7 +2061,7 @@ def local_useless_elemwise(node):
if const_val == 0:
return zeros_like(node, 0)
else:
return [node.inputs[0]]
return [node.inputs[0].astype(node.outputs[0].dtype)]
elif (isinstance(node.op.scalar_op, scalar.OR) and
len(node.inputs) == 2):
......@@ -2070,7 +2070,7 @@ def local_useless_elemwise(node):
const_val = T.extract_constant(node.inputs[0], only_process_constants=True)
if not isinstance(const_val, Variable):
if const_val == 0:
return [node.inputs[1]]
return [node.inputs[1].astype(node.outputs[0].dtype)]
else:
return ones_like(node, 1)
......@@ -2078,7 +2078,7 @@ def local_useless_elemwise(node):
const_val = T.extract_constant(node.inputs[1], only_process_constants=True)
if not isinstance(const_val, Variable):
if const_val == 0:
return [node.inputs[0]]
return [node.inputs[0].astype(node.outputs[0].dtype)]
else:
return ones_like(node, 0)
......
......@@ -4978,7 +4978,7 @@ class T_scalarfromtensor(unittest.TestCase):
self.assertTrue(v == 56, v)
if config.cast_policy == 'custom':
self.assertTrue(isinstance(v, numpy.int8))
self.assertTrue(isinstance(v, numpy.int16))
elif config.cast_policy in ('numpy', 'numpy+floatX'):
self.assertTrue(isinstance(
v, getattr(numpy, str(numpy.asarray(56).dtype))))
......
......@@ -3610,33 +3610,39 @@ class Test_local_useless_elemwise_comparison(unittest.TestCase):
x = T.scalar('x', dtype='int8')
f = theano.function([x], T.and_(x, 0), mode=mode)
self.assert_eqs_const(f, 0)
for zero, one in [(numpy.int8(0), numpy.int8(1)), (0, 1)]:
f = theano.function([x], T.and_(x, zero), mode=mode)
self.assert_eqs_const(f, 0)
f = theano.function([x], T.and_(0, x), mode=mode)
self.assert_eqs_const(f, 0)
f = theano.function([x], T.and_(zero, x), mode=mode)
self.assert_eqs_const(f, 0)
f = theano.function([x], T.and_(x, 1), mode=mode)
self.assert_identity(f)
f = theano.function([x], T.and_(x, one), mode=mode)
if f.outputs[0].variable.dtype == x.dtype:
self.assert_identity(f)
f = theano.function([x], T.and_(1, x), mode=mode)
self.assert_identity(f)
f = theano.function([x], T.and_(one, x), mode=mode)
if f.outputs[0].variable.dtype == x.dtype:
self.assert_identity(f)
def test_or(self):
mode = theano.compile.get_default_mode().including('canonicalize')
x = T.scalar('x', dtype='int8')
f = theano.function([x], T.or_(x, 1), mode=mode)
self.assert_eqs_const(f, 1)
for zero, one in [(numpy.int8(0), numpy.int8(1)), (0, 1)]:
f = theano.function([x], T.or_(x, one), mode=mode)
self.assert_eqs_const(f, 1)
f = theano.function([x], T.or_(1, x), mode=mode)
self.assert_eqs_const(f, 1)
f = theano.function([x], T.or_(one, x), mode=mode)
self.assert_eqs_const(f, 1)
f = theano.function([x], T.or_(x, 0), mode=mode)
self.assert_identity(f)
f = theano.function([x], T.or_(x, zero), mode=mode)
if f.outputs[0].variable.dtype == x.dtype:
self.assert_identity(f)
f = theano.function([x], T.or_(0, x), mode=mode)
self.assert_identity(f)
f = theano.function([x], T.or_(zero, x), mode=mode)
if f.outputs[0].variable.dtype == x.dtype:
self.assert_identity(f)
def test_xor(self):
mode = theano.compile.get_default_mode().including('canonicalize')
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论