提交 2e4eeefc authored 作者: Frederic's avatar Frederic

Make some expm grad test always be in float64 for numerical stability of the test.

上级 0381ea66
...@@ -251,10 +251,11 @@ def test_expm_grad_1(): ...@@ -251,10 +251,11 @@ def test_expm_grad_1():
if not imported_scipy: if not imported_scipy:
raise SkipTest("Scipy needed for the expm op.") raise SkipTest("Scipy needed for the expm op.")
rng = numpy.random.RandomState(utt.fetch_seed()) rng = numpy.random.RandomState(utt.fetch_seed())
A = rng.randn(5, 5).astype(config.floatX) # Always test in float64 for better numerical stability.
A = rng.randn(5, 5)
A = A + A.T A = A + A.T
tensor.verify_grad(expm, [A,], rng=rng) tensor.verify_grad(expm, [A], rng=rng)
def test_expm_grad_2(): def test_expm_grad_2():
...@@ -262,12 +263,13 @@ def test_expm_grad_2(): ...@@ -262,12 +263,13 @@ def test_expm_grad_2():
if not imported_scipy: if not imported_scipy:
raise SkipTest("Scipy needed for the expm op.") raise SkipTest("Scipy needed for the expm op.")
rng = numpy.random.RandomState(utt.fetch_seed()) rng = numpy.random.RandomState(utt.fetch_seed())
A = rng.randn(5, 5).astype(config.floatX) # Always test in float64 for better numerical stability.
w = (rng.randn(5).astype(config.floatX))**2 A = rng.randn(5, 5)
w = rng.randn(5)**2
A = (numpy.diag(w**0.5)).dot(A + A.T).dot(numpy.diag(w**(-0.5))) A = (numpy.diag(w**0.5)).dot(A + A.T).dot(numpy.diag(w**(-0.5)))
assert not numpy.allclose(A, A.T) assert not numpy.allclose(A, A.T)
tensor.verify_grad(expm, [A,], rng=rng) tensor.verify_grad(expm, [A], rng=rng)
def test_expm_grad_3(): def test_expm_grad_3():
...@@ -275,6 +277,7 @@ def test_expm_grad_3(): ...@@ -275,6 +277,7 @@ def test_expm_grad_3():
if not imported_scipy: if not imported_scipy:
raise SkipTest("Scipy needed for the expm op.") raise SkipTest("Scipy needed for the expm op.")
rng = numpy.random.RandomState(utt.fetch_seed()) rng = numpy.random.RandomState(utt.fetch_seed())
A = rng.randn(5, 5).astype(config.floatX) # Always test in float64 for better numerical stability.
A = rng.randn(5, 5)
tensor.verify_grad(expm, [A,], rng=rng) tensor.verify_grad(expm, [A], rng=rng)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论