提交 e9ee91e0 authored 作者: Brandon T. Willard's avatar Brandon T. Willard

Fix typo in NanGuardMode test

上级 cac3e5b1
...@@ -26,7 +26,7 @@ def test_NanGuardMode(): ...@@ -26,7 +26,7 @@ def test_NanGuardMode():
a = np.random.randn(3, 5).astype(theano.config.floatX) a = np.random.randn(3, 5).astype(theano.config.floatX)
infa = np.tile((np.asarray(100.0) ** 1000000).astype(theano.config.floatX), (3, 5)) infa = np.tile((np.asarray(100.0) ** 1000000).astype(theano.config.floatX), (3, 5))
nana = np.tile(np.asarray(np.nan).astype(theano.config.floatX), (3, 5)) nana = np.tile(np.asarray(np.nan).astype(theano.config.floatX), (3, 5))
# biga = np.tile(np.asarray(1e20).astype(theano.config.floatX), (3, 5)) biga = np.tile(np.asarray(1e20).astype(theano.config.floatX), (3, 5))
fun(a) # normal values fun(a) # normal values
...@@ -39,7 +39,7 @@ def test_NanGuardMode(): ...@@ -39,7 +39,7 @@ def test_NanGuardMode():
with pytest.raises(AssertionError): with pytest.raises(AssertionError):
fun(nana) # NANs fun(nana) # NANs
with pytest.raises(AssertionError): with pytest.raises(AssertionError):
fun(infa) # big values fun(biga) # big values
finally: finally:
_logger.propagate = True _logger.propagate = True
...@@ -49,7 +49,7 @@ def test_NanGuardMode(): ...@@ -49,7 +49,7 @@ def test_NanGuardMode():
(np.asarray(100.0) ** 1000000).astype(theano.config.floatX), (3, 4, 5) (np.asarray(100.0) ** 1000000).astype(theano.config.floatX), (3, 4, 5)
) )
nana = np.tile(np.asarray(np.nan).astype(theano.config.floatX), (3, 4, 5)) nana = np.tile(np.asarray(np.nan).astype(theano.config.floatX), (3, 4, 5))
# biga = np.tile(np.asarray(1e20).astype(theano.config.floatX), (3, 4, 5)) biga = np.tile(np.asarray(1e20).astype(theano.config.floatX), (3, 4, 5))
x = T.tensor3() x = T.tensor3()
y = x[:, T.arange(2), T.arange(2), None] y = x[:, T.arange(2), T.arange(2), None]
...@@ -64,6 +64,6 @@ def test_NanGuardMode(): ...@@ -64,6 +64,6 @@ def test_NanGuardMode():
with pytest.raises(AssertionError): with pytest.raises(AssertionError):
fun(nana) # NANs fun(nana) # NANs
with pytest.raises(AssertionError): with pytest.raises(AssertionError):
fun(infa) # big values fun(biga) # big values
finally: finally:
_logger.propagate = True _logger.propagate = True
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论