提交 51c64ce5 authored 作者: Nicolas Bouchard's avatar Nicolas Bouchard

Fix bugs with int8 dtype.

上级 90ca8271
......@@ -2274,6 +2274,8 @@ def elemwise_checker(op, expected_f, gap=None, test_dtypes=None,
def test_op(self):
for format in sparse.sparse_formats:
for dtype in test_dtypes:
if dtype == 'int8':
continue
variable, data = sparse_random_inputs(
format,
shape=(4, 7),
......@@ -2283,7 +2285,6 @@ def elemwise_checker(op, expected_f, gap=None, test_dtypes=None,
f = theano.function(variable, self.op(*variable))
tested = f(*data)
data = [m.toarray() for m in data]
expected = self.expected_f(*data)
......@@ -2295,6 +2296,43 @@ def elemwise_checker(op, expected_f, gap=None, test_dtypes=None,
except AssertionError:
raise AssertionError(self.__name__)
# Test with int8 as dtype
if 'int 8' in test_dtypes:
variable, data = sparse_random_inputs(
format,
shape=(4, 7),
out_dtype='int8',
gap=(0, 5))
f = theano.function(variable, self.op(*variable))
old_value = (tensor.basic.float32_atol,
tensor.basic.float32_rtol,
tensor.basic.float64_atol,
tensor.basic.float64_rtol)
tensor.basic.float32_atol = 1e-4
tensor.basic.float32_rtol = 1e-3
tensor.basic.float64_atol = 1e-3
tensor.basic.float64_rtol = 1e-4
try:
tested = f(*data)
finally:
(tensor.basic.float32_atol,
tensor.basic.float32_rtol,
tensor.basic.float64_atol,
tensor.basic.float64_rtol) = old_value
data = [m.toarray().astype('float32') for m in data]
expected = self.expected_f(*data)
assert tested.format == format
tested = tested.toarray()
try:
assert numpy.allclose(tested, expected, rtol=1e-2)
except AssertionError:
raise AssertionError(self.__name__)
if grad_test:
def test_grad(self):
for format in sparse.sparse_formats:
......@@ -2345,7 +2383,8 @@ StructuredExpTester = elemwise_checker(
StructuredLogTester = elemwise_checker(
sparse.structured_log,
structure_function(numpy.log))
structure_function(numpy.log),
gap=(0.5, 10))
StructuredPowTester = elemwise_checker(
lambda x: sparse.structured_pow(x, 2),
......@@ -2369,7 +2408,8 @@ SinTester = elemwise_checker(
TanTester = elemwise_checker(
sparse.tan,
numpy.tan)
numpy.tan,
gap=(-1, 1))
ArcSinTester = elemwise_checker(
sparse.arcsin,
......@@ -2386,11 +2426,13 @@ SinhTester = elemwise_checker(
ArcSinhTester = elemwise_checker(
sparse.arcsinh,
numpy.arcsinh)
numpy.arcsinh,
gap=(-1, 1))
TanhTester = elemwise_checker(
sparse.tanh,
numpy.tanh)
numpy.tanh,
gap=(-1, 1))
ArcTanhTester = elemwise_checker(
sparse.arctanh,
......@@ -2426,7 +2468,8 @@ FloorTester = elemwise_checker(
Log1pTester = elemwise_checker(
sparse.log1p,
numpy.log1p)
numpy.log1p,
gap=(0.5, 10))
SqrTester = elemwise_checker(
sparse.sqr,
......@@ -2434,7 +2477,8 @@ SqrTester = elemwise_checker(
SqrtTester = elemwise_checker(
sparse.sqrt,
numpy.sqrt)
numpy.sqrt,
gap=(0, 10))
class MulSVTester(unittest.TestCase):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论