提交 51c64ce5 authored 作者: Nicolas Bouchard's avatar Nicolas Bouchard

Fix bugs with int8 dtype.

上级 90ca8271
...@@ -2274,6 +2274,8 @@ def elemwise_checker(op, expected_f, gap=None, test_dtypes=None, ...@@ -2274,6 +2274,8 @@ def elemwise_checker(op, expected_f, gap=None, test_dtypes=None,
def test_op(self): def test_op(self):
for format in sparse.sparse_formats: for format in sparse.sparse_formats:
for dtype in test_dtypes: for dtype in test_dtypes:
if dtype == 'int8':
continue
variable, data = sparse_random_inputs( variable, data = sparse_random_inputs(
format, format,
shape=(4, 7), shape=(4, 7),
...@@ -2283,7 +2285,6 @@ def elemwise_checker(op, expected_f, gap=None, test_dtypes=None, ...@@ -2283,7 +2285,6 @@ def elemwise_checker(op, expected_f, gap=None, test_dtypes=None,
f = theano.function(variable, self.op(*variable)) f = theano.function(variable, self.op(*variable))
tested = f(*data) tested = f(*data)
data = [m.toarray() for m in data] data = [m.toarray() for m in data]
expected = self.expected_f(*data) expected = self.expected_f(*data)
...@@ -2295,6 +2296,43 @@ def elemwise_checker(op, expected_f, gap=None, test_dtypes=None, ...@@ -2295,6 +2296,43 @@ def elemwise_checker(op, expected_f, gap=None, test_dtypes=None,
except AssertionError: except AssertionError:
raise AssertionError(self.__name__) raise AssertionError(self.__name__)
# Test with int8 as dtype
if 'int 8' in test_dtypes:
variable, data = sparse_random_inputs(
format,
shape=(4, 7),
out_dtype='int8',
gap=(0, 5))
f = theano.function(variable, self.op(*variable))
old_value = (tensor.basic.float32_atol,
tensor.basic.float32_rtol,
tensor.basic.float64_atol,
tensor.basic.float64_rtol)
tensor.basic.float32_atol = 1e-4
tensor.basic.float32_rtol = 1e-3
tensor.basic.float64_atol = 1e-3
tensor.basic.float64_rtol = 1e-4
try:
tested = f(*data)
finally:
(tensor.basic.float32_atol,
tensor.basic.float32_rtol,
tensor.basic.float64_atol,
tensor.basic.float64_rtol) = old_value
data = [m.toarray().astype('float32') for m in data]
expected = self.expected_f(*data)
assert tested.format == format
tested = tested.toarray()
try:
assert numpy.allclose(tested, expected, rtol=1e-2)
except AssertionError:
raise AssertionError(self.__name__)
if grad_test: if grad_test:
def test_grad(self): def test_grad(self):
for format in sparse.sparse_formats: for format in sparse.sparse_formats:
...@@ -2345,7 +2383,8 @@ StructuredExpTester = elemwise_checker( ...@@ -2345,7 +2383,8 @@ StructuredExpTester = elemwise_checker(
StructuredLogTester = elemwise_checker( StructuredLogTester = elemwise_checker(
sparse.structured_log, sparse.structured_log,
structure_function(numpy.log)) structure_function(numpy.log),
gap=(0.5, 10))
StructuredPowTester = elemwise_checker( StructuredPowTester = elemwise_checker(
lambda x: sparse.structured_pow(x, 2), lambda x: sparse.structured_pow(x, 2),
...@@ -2369,7 +2408,8 @@ SinTester = elemwise_checker( ...@@ -2369,7 +2408,8 @@ SinTester = elemwise_checker(
TanTester = elemwise_checker( TanTester = elemwise_checker(
sparse.tan, sparse.tan,
numpy.tan) numpy.tan,
gap=(-1, 1))
ArcSinTester = elemwise_checker( ArcSinTester = elemwise_checker(
sparse.arcsin, sparse.arcsin,
...@@ -2386,11 +2426,13 @@ SinhTester = elemwise_checker( ...@@ -2386,11 +2426,13 @@ SinhTester = elemwise_checker(
ArcSinhTester = elemwise_checker( ArcSinhTester = elemwise_checker(
sparse.arcsinh, sparse.arcsinh,
numpy.arcsinh) numpy.arcsinh,
gap=(-1, 1))
TanhTester = elemwise_checker( TanhTester = elemwise_checker(
sparse.tanh, sparse.tanh,
numpy.tanh) numpy.tanh,
gap=(-1, 1))
ArcTanhTester = elemwise_checker( ArcTanhTester = elemwise_checker(
sparse.arctanh, sparse.arctanh,
...@@ -2426,7 +2468,8 @@ FloorTester = elemwise_checker( ...@@ -2426,7 +2468,8 @@ FloorTester = elemwise_checker(
Log1pTester = elemwise_checker( Log1pTester = elemwise_checker(
sparse.log1p, sparse.log1p,
numpy.log1p) numpy.log1p,
gap=(0.5, 10))
SqrTester = elemwise_checker( SqrTester = elemwise_checker(
sparse.sqr, sparse.sqr,
...@@ -2434,7 +2477,8 @@ SqrTester = elemwise_checker( ...@@ -2434,7 +2477,8 @@ SqrTester = elemwise_checker(
SqrtTester = elemwise_checker( SqrtTester = elemwise_checker(
sparse.sqrt, sparse.sqrt,
numpy.sqrt) numpy.sqrt,
gap=(0, 10))
class MulSVTester(unittest.TestCase): class MulSVTester(unittest.TestCase):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论