提交 ea1bd6b9 authored 作者: Arnaud Bergeron's avatar Arnaud Bergeron

Add a test for the nan case and flake8

上级 9452e144
import numpy
from theano import config from theano import config
from theano.tensor.opt import in2out from theano.tensor.opt import in2out
...@@ -8,7 +6,6 @@ from theano.tensor.blas import blas_optdb, optdb, local_optimizer ...@@ -8,7 +6,6 @@ from theano.tensor.blas import blas_optdb, optdb, local_optimizer
from theano.tensor.blas import Ger, ger, ger_destructive from theano.tensor.blas import Ger, ger, ger_destructive
from theano.tensor.blas import Gemv, gemv_inplace, gemv_no_inplace from theano.tensor.blas import Gemv, gemv_inplace, gemv_no_inplace
from theano.tensor import basic as T from theano.tensor import basic as T
import theano.compile
class BaseBLAS(object): class BaseBLAS(object):
......
...@@ -130,6 +130,16 @@ class TestCGemv(TestCase, TestOptimizationMixin): ...@@ -130,6 +130,16 @@ class TestCGemv(TestCase, TestOptimizationMixin):
# scalar # scalar
self.a = tensor.tensor(dtype=dtype, broadcastable=()) self.a = tensor.tensor(dtype=dtype, broadcastable=())
def test_nan_beta_0(self):
f = theano.function([self.A, self.x, self.y, self.a],
self.a*self.y + theano.dot(self.A, self.x),
mode=self.mode)
Aval = numpy.ones((3, 1), dtype=self.dtype)
xval = numpy.ones((1,), dtype=self.dtype)
yval = float('NaN') * numpy.ones((3,), dtype=self.dtype)
zval = f(Aval, xval, yval, 0)
assert not numpy.isnan(zval).any()
def test_optimizations_vm(self): def test_optimizations_vm(self):
''' Test vector dot matrix ''' ''' Test vector dot matrix '''
f = theano.function([self.x, self.A], f = theano.function([self.x, self.A],
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论