提交 725a4be0 authored 作者: Olivier Delalleau's avatar Olivier Delalleau

Merge pull request #1354 from nouiz/more_stable_test

make sparse Arctanh gradient test more stable.
...@@ -2208,7 +2208,7 @@ class AddSSDataTester(utt.InferShapeTester): ...@@ -2208,7 +2208,7 @@ class AddSSDataTester(utt.InferShapeTester):
def elemwise_checker(op, expected_f, gap=None, test_dtypes=None, def elemwise_checker(op, expected_f, gap=None, test_dtypes=None,
grad_test=True, name=None): grad_test=True, name=None, gap_grad=None):
"""Return the appropriate test class for the elemwise on sparse. """Return the appropriate test class for the elemwise on sparse.
:param op: Op to test. :param op: Op to test.
...@@ -2226,6 +2226,8 @@ def elemwise_checker(op, expected_f, gap=None, test_dtypes=None, ...@@ -2226,6 +2226,8 @@ def elemwise_checker(op, expected_f, gap=None, test_dtypes=None,
dtypes. dtypes.
:param grad_test: True for testing the grad. False will :param grad_test: True for testing the grad. False will
skip this test. skip this test.
:param gap_grad: If None, we reuse gap. Otherwise it is the same as gap
but for testing the gradiant of the op.
:return: The class that perform the tests, not an instance :return: The class that perform the tests, not an instance
of the class. of the class.
...@@ -2241,6 +2243,10 @@ def elemwise_checker(op, expected_f, gap=None, test_dtypes=None, ...@@ -2241,6 +2243,10 @@ def elemwise_checker(op, expected_f, gap=None, test_dtypes=None,
self.op = op self.op = op
self.expected_f = expected_f self.expected_f = expected_f
self.gap = gap self.gap = gap
if gap_grad is not None:
self.gap_grad = gap_grad
else:
self.gap_grad = gap
# Ensure the test's name is correct. # Ensure the test's name is correct.
assert eval(self.__class__.__name__) is self.__class__ assert eval(self.__class__.__name__) is self.__class__
...@@ -2350,7 +2356,8 @@ def elemwise_checker(op, expected_f, gap=None, test_dtypes=None, ...@@ -2350,7 +2356,8 @@ def elemwise_checker(op, expected_f, gap=None, test_dtypes=None,
variable, data = sparse_random_inputs( variable, data = sparse_random_inputs(
format, format,
shape=(4, 7), shape=(4, 7),
out_dtype=dtype) out_dtype=dtype,
gap=self.gap_grad)
verify_grad_sparse(self.op, verify_grad_sparse(self.op,
data, data,
...@@ -2465,7 +2472,8 @@ TanhTester = elemwise_checker( ...@@ -2465,7 +2472,8 @@ TanhTester = elemwise_checker(
ArctanhTester = elemwise_checker( ArctanhTester = elemwise_checker(
sparse.arctanh, sparse.arctanh,
numpy.arctanh, numpy.arctanh,
gap=(-0.9, 1)) gap=(-0.9, 1),
gap_grad=(-0.9, 0.95))
RintTester = elemwise_checker( RintTester = elemwise_checker(
sparse.rint, sparse.rint,
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论