提交 4cfc68f2 作者: nouiz

Merge pull request #415 from delallea/infinite_loop_in_canonizer

Fixed infinite canonizer loop with NaN constants
...@@ -2761,19 +2761,29 @@ class Canonizer(gof.LocalOptimizer): ...@@ -2761,19 +2761,29 @@ class Canonizer(gof.LocalOptimizer):
# Wrapping ct in a Constant with the right dtype # Wrapping ct in a Constant with the right dtype
ct = [T.constant(c, dtype=out_type.dtype) for c in ct] ct = [T.constant(c, dtype=out_type.dtype) for c in ct]
if orig_num and len(numct) == 1 and len(denumct) == 0 and ct and\ if orig_num and len(numct) == 1 and len(denumct) == 0 and ct:
N.all([c.data for c in ct] == self.get_constant(orig_num[0])): # In that case we should only have one constant in `ct`.
# this is an important trick :( if it so happens that: assert len(ct) == 1
# * there's exactly one constant on the numerator and none on first_num_ct = self.get_constant(orig_num[0])
# the denominator if first_num_ct is not None and ct[0].type.values_eq(ct[0].data,
# * it's not the neutral element (ct is an empty list in that case) first_num_ct):
# * the constant is the same as the first argument in the numerator # This is an important trick :( if it so happens that:
# Then we return very exactly the original num/denum # * there's exactly one constant on the numerator and none on
# If we don't do that the optimizer will just loop # the denominator
# infinitely because it will not catch on that there are # * it's not the neutral element (ct is an empty list in that
# no changes to be made and everytime it will want to # case)
# replace something by the same thing... # * the constant is the same as the first argument in the
return orig_num, orig_denum # numerator (we only check the first argument because the
# canonizer puts the computed constants first)
# -> then we return very exactly the original num/denum.
# If we don't do that the optimizer will just loop
# infinitely because it will not catch on that there are
# no changes to be made and everytime it will want to
# replace something by the same thing...
# Note that it is important to use `values_eq` instead of
# the == operator, to handle NaN values correctly.
return orig_num, orig_denum
return ct + num, denum return ct + num, denum
def transform(self, node): def transform(self, node):
......
## PENDING REWRITE OF tensor_opt.py ## PENDING REWRITE OF tensor_opt.py
import copy import copy
import logging
import StringIO
import time import time
import unittest import unittest
...@@ -77,10 +79,10 @@ def optimize(g, level='fast_run'): ...@@ -77,10 +79,10 @@ def optimize(g, level='fast_run'):
return g return g
def inputs(xbc = (0, 0), ybc = (0, 0), zbc = (0, 0)): def inputs(xbc=(0, 0), ybc=(0, 0), zbc=(0, 0)):
x = TensorType(broadcastable = xbc, dtype = 'float64')('x') x = TensorType(broadcastable=xbc, dtype='float64')('x')
y = TensorType(broadcastable = ybc, dtype = 'float64')('y') y = TensorType(broadcastable=ybc, dtype='float64')('y')
z = TensorType(broadcastable = zbc, dtype = 'float64')('z') z = TensorType(broadcastable=zbc, dtype='float64')('z')
return x, y, z return x, y, z
...@@ -97,7 +99,9 @@ class test_dimshuffle_lift(unittest.TestCase): ...@@ -97,7 +99,9 @@ class test_dimshuffle_lift(unittest.TestCase):
x, y, z = inputs() x, y, z = inputs()
e = ds(ds(x, (1, 'x', 0)), (2, 0, 'x', 1)) e = ds(ds(x, (1, 'x', 0)), (2, 0, 'x', 1))
g = Env([x], [e]) g = Env([x], [e])
self.assertTrue(str(g) == "[DimShuffle{2,0,x,1}(DimShuffle{1,x,0}(x))]", str(g)) self.assertTrue(
str(g) == "[DimShuffle{2,0,x,1}(DimShuffle{1,x,0}(x))]",
str(g))
dimshuffle_lift.optimize(g) dimshuffle_lift.optimize(g)
self.assertTrue(str(g) == "[DimShuffle{0,1,x,x}(x)]", str(g)) self.assertTrue(str(g) == "[DimShuffle{0,1,x,x}(x)]", str(g))
...@@ -105,12 +109,15 @@ class test_dimshuffle_lift(unittest.TestCase): ...@@ -105,12 +109,15 @@ class test_dimshuffle_lift(unittest.TestCase):
x, y, z = inputs() x, y, z = inputs()
e = ds(ds(ds(x, (0, 'x', 1)), (2, 0, 'x', 1)), (1, 0)) e = ds(ds(ds(x, (0, 'x', 1)), (2, 0, 'x', 1)), (1, 0))
g = Env([x], [e]) g = Env([x], [e])
self.assertTrue(str(g) == "[DimShuffle{1,0}(DimShuffle{2,0,x,1}(DimShuffle{0,x,1}(x)))]", str(g)) self.assertTrue(
str(g) == "[DimShuffle{1,0}(DimShuffle{2,0,x,1}"
"(DimShuffle{0,x,1}(x)))]",
str(g))
dimshuffle_lift.optimize(g) dimshuffle_lift.optimize(g)
self.assertTrue(str(g) == "[x]", str(g)) self.assertTrue(str(g) == "[x]", str(g))
def test_lift(self): def test_lift(self):
x, y, z = inputs([False]*1, [False]*2, [False]*3) x, y, z = inputs([False] * 1, [False] * 2, [False] * 3)
e = x + y + z e = x + y + z
g = Env([x, y, z], [e]) g = Env([x, y, z], [e])
self.assertTrue(str(g) == ("[Elemwise{add,no_inplace}(" self.assertTrue(str(g) == ("[Elemwise{add,no_inplace}("
...@@ -672,25 +679,48 @@ class test_canonize(unittest.TestCase): ...@@ -672,25 +679,48 @@ class test_canonize(unittest.TestCase):
""" """
raise SkipTest("Not implemented") raise SkipTest("Not implemented")
def test_canonicalize_nan(self):
    """
    Check that compiling `x + NaN` logs no error from the optimizer.

    Regression test for a bug in the canonicalization of NaN values:
    the canonizer looped forever, the equilibrium optimizer caught the
    loop, and an error message was written to the log.
    """
    # Capture everything of level ERROR (or above) emitted on the
    # optimizer's logger while the function is being compiled.
    captured = StringIO.StringIO()
    error_handler = logging.StreamHandler(captured)
    error_handler.setLevel(logging.ERROR)
    opt_logger = logging.getLogger('theano.gof.opt')
    opt_logger.addHandler(error_handler)
    try:
        x = vector()
        # Only the compilation matters here; the compiled function
        # itself is never called.
        theano.function([x], x + numpy.nan)
    finally:
        # Always detach the handler so that other tests are unaffected.
        opt_logger.removeHandler(error_handler)
    # Ideally only the "maxed out equilibrium optimizer" message would
    # be checked for, but its wording may change in the future, so the
    # test requires that no error was logged at all.
    logged = captured.getvalue()
    assert not logged
def test_local_merge_abs(): def test_local_merge_abs():
x,y,z = T.matrices('xyz') x, y, z = T.matrices('xyz')
x_val = numpy.random.rand(5,5).astype(config.floatX) x_val = numpy.random.rand(5, 5).astype(config.floatX)
y_val = numpy.random.rand(5,5).astype(config.floatX) y_val = numpy.random.rand(5, 5).astype(config.floatX)
z_val = numpy.random.rand(5,5).astype(config.floatX) z_val = numpy.random.rand(5, 5).astype(config.floatX)
mode = theano.config.mode mode = theano.config.mode
if mode == "FAST_COMPILE": if mode == "FAST_COMPILE":
mode = "FAST_RUN" mode = "FAST_RUN"
mode = theano.compile.mode.get_mode(mode).excluding("local_elemwise_fusion") mode = theano.compile.mode.get_mode(mode).excluding(
"local_elemwise_fusion")
f = theano.function([x,y,z],(abs(y*z*-2)), mode=mode) f = theano.function([x, y, z], (abs(y * z * -2)), mode=mode)
f(x_val,y_val,z_val) f(x_val, y_val, z_val)
theano.printing.debugprint(f) theano.printing.debugprint(f)
assert isinstance(f.maker.env.toposort()[1].op.scalar_op, scal.Abs) assert isinstance(f.maker.env.toposort()[1].op.scalar_op, scal.Abs)
assert len(f.maker.env.toposort())==2 assert len(f.maker.env.toposort()) == 2
f = theano.function([x,y,z],abs(x/y), mode=mode) f = theano.function([x, y, z],abs(x / y), mode=mode)
f(x_val,y_val,z_val) f(x_val, y_val, z_val)
theano.printing.debugprint(f) theano.printing.debugprint(f)
assert isinstance(f.maker.env.toposort()[1].op.scalar_op, scal.Abs) assert isinstance(f.maker.env.toposort()[1].op.scalar_op, scal.Abs)
assert len(f.maker.env.toposort())==2 assert len(f.maker.env.toposort())==2
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论