提交 7a03e5a3 authored 作者: Frederic Bastien's avatar Frederic Bastien

Fix test in DebugMode, GemmOptimizer can remove nan/inf.

上级 84906dbb
......@@ -152,6 +152,7 @@ from theano.tensor import basic as T
from theano.tensor.blas_headers import blas_header_text
from theano.tensor.blas_headers import blas_header_version
from theano.tensor.opt import in2out, local_dimshuffle_lift
from theano.tensor.type import values_eq_approx_remove_inf_nan
_logger = logging.getLogger('theano.tensor.blas')
......@@ -1465,6 +1466,7 @@ class GemmOptimizer(Optimizer):
if new_outputs:
new_outputs, old_dot22 = new_outputs
assert len(new_outputs) == len(node.outputs)
new_outputs[0].tag.values_eq_approx = values_eq_approx_remove_inf_nan
try:
fgraph.replace_all_validate_remove(
list(zip(node.outputs, new_outputs)),
......
......@@ -132,9 +132,11 @@ class TestCGemv(TestCase, TestOptimizationMixin):
self.a = tensor.tensor(dtype=dtype, broadcastable=())
def test_nan_beta_0(self):
mode = self.mode.including()
mode.check_isfinite = False
f = theano.function([self.A, self.x, self.y, self.a],
self.a*self.y + theano.dot(self.A, self.x),
mode=self.mode)
mode=mode)
Aval = numpy.ones((3, 1), dtype=self.dtype)
xval = numpy.ones((1,), dtype=self.dtype)
yval = float('NaN') * numpy.ones((3,), dtype=self.dtype)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论