提交 6137cd66 authored 作者: Benjamin Scellier's avatar Benjamin Scellier

file theano/compile/tests/test_nanguardmode.py

上级 75a6ec49
......@@ -6,7 +6,7 @@ from __future__ import absolute_import, print_function, division
import logging
from nose.tools import assert_raises
import numpy
import numpy as np
from theano.compile.nanguardmode import NanGuardMode
import theano
......@@ -18,20 +18,20 @@ def test_NanGuardMode():
# intentionally. A working implementation should be able to capture all
# the abnormalties.
x = T.matrix()
w = theano.shared(numpy.random.randn(5, 7).astype(theano.config.floatX))
w = theano.shared(np.random.randn(5, 7).astype(theano.config.floatX))
y = T.dot(x, w)
fun = theano.function(
[x], y,
mode=NanGuardMode(nan_is_error=True, inf_is_error=True)
)
a = numpy.random.randn(3, 5).astype(theano.config.floatX)
infa = numpy.tile(
(numpy.asarray(100.) ** 1000000).astype(theano.config.floatX), (3, 5))
nana = numpy.tile(
numpy.asarray(numpy.nan).astype(theano.config.floatX), (3, 5))
biga = numpy.tile(
numpy.asarray(1e20).astype(theano.config.floatX), (3, 5))
a = np.random.randn(3, 5).astype(theano.config.floatX)
infa = np.tile(
(np.asarray(100.) ** 1000000).astype(theano.config.floatX), (3, 5))
nana = np.tile(
np.asarray(np.nan).astype(theano.config.floatX), (3, 5))
biga = np.tile(
np.asarray(1e20).astype(theano.config.floatX), (3, 5))
fun(a) # normal values
......@@ -46,14 +46,14 @@ def test_NanGuardMode():
_logger.propagate = True
# slices
a = numpy.random.randn(3, 4, 5).astype(theano.config.floatX)
infa = numpy.tile(
(numpy.asarray(100.) ** 1000000).astype(theano.config.floatX),
a = np.random.randn(3, 4, 5).astype(theano.config.floatX)
infa = np.tile(
(np.asarray(100.) ** 1000000).astype(theano.config.floatX),
(3, 4, 5))
nana = numpy.tile(
numpy.asarray(numpy.nan).astype(theano.config.floatX), (3, 4, 5))
biga = numpy.tile(
numpy.asarray(1e20).astype(theano.config.floatX), (3, 4, 5))
nana = np.tile(
np.asarray(np.nan).astype(theano.config.floatX), (3, 4, 5))
biga = np.tile(
np.asarray(1e20).astype(theano.config.floatX), (3, 4, 5))
x = T.tensor3()
y = x[:, T.arange(2), T.arange(2)]
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论