提交 0babc678 authored 作者: Frédéric Bastien's avatar Frédéric Bastien

Merge pull request #4168 from abergeron/fix_cgemv_nan

Fix nan handling in output buffer for CGemv
差异被折叠。
...@@ -294,6 +294,11 @@ class TestCpuConv2d(BaseTestConv2d): ...@@ -294,6 +294,11 @@ class TestCpuConv2d(BaseTestConv2d):
def setUp(self): def setUp(self):
super(TestCpuConv2d, self).setUp() super(TestCpuConv2d, self).setUp()
self.mode = theano.compile.mode.get_default_mode().excluding('conv_gemm') self.mode = theano.compile.mode.get_default_mode().excluding('conv_gemm')
self.opt_err = theano.config.on_opt_error
theano.config.on_opt_error = 'ignore'
def tearDown(self):
theano.config.on_opt_error = self.opt_err
def tcase(self, i, f, s, b, flip, provide_shape): def tcase(self, i, f, s, b, flip, provide_shape):
mode = self.mode mode = self.mode
......
...@@ -130,6 +130,16 @@ class TestCGemv(TestCase, TestOptimizationMixin): ...@@ -130,6 +130,16 @@ class TestCGemv(TestCase, TestOptimizationMixin):
# scalar # scalar
self.a = tensor.tensor(dtype=dtype, broadcastable=()) self.a = tensor.tensor(dtype=dtype, broadcastable=())
def test_nan_beta_0(self):
f = theano.function([self.A, self.x, self.y, self.a],
self.a*self.y + theano.dot(self.A, self.x),
mode=self.mode)
Aval = numpy.ones((3, 1), dtype=self.dtype)
xval = numpy.ones((1,), dtype=self.dtype)
yval = float('NaN') * numpy.ones((3,), dtype=self.dtype)
zval = f(Aval, xval, yval, 0)
assert not numpy.isnan(zval).any()
def test_optimizations_vm(self): def test_optimizations_vm(self):
''' Test vector dot matrix ''' ''' Test vector dot matrix '''
f = theano.function([self.x, self.A], f = theano.function([self.x, self.A],
...@@ -140,7 +150,7 @@ class TestCGemv(TestCase, TestOptimizationMixin): ...@@ -140,7 +150,7 @@ class TestCGemv(TestCase, TestOptimizationMixin):
self.assertFunctionContains0(f, tensor.dot) self.assertFunctionContains0(f, tensor.dot)
self.assertFunctionContains1( self.assertFunctionContains1(
f, f,
CGemv(inplace=True, force_init_beta=True) CGemv(inplace=True)
) )
# Assert they produce the same output # Assert they produce the same output
...@@ -161,7 +171,7 @@ class TestCGemv(TestCase, TestOptimizationMixin): ...@@ -161,7 +171,7 @@ class TestCGemv(TestCase, TestOptimizationMixin):
self.assertFunctionContains0(f, tensor.dot) self.assertFunctionContains0(f, tensor.dot)
self.assertFunctionContains1( self.assertFunctionContains1(
f, f,
CGemv(inplace=True, force_init_beta=True) CGemv(inplace=True)
) )
# Assert they produce the same output # Assert they produce the same output
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论