提交 4cfc68f2，作者: nouiz

Merge pull request #415 from delallea/infinite_loop_in_canonizer

Fixed infinite canonizer loop with NaN constants
......@@ -2761,19 +2761,29 @@ class Canonizer(gof.LocalOptimizer):
# Wrapping ct in a Constant with the right dtype
ct = [T.constant(c, dtype=out_type.dtype) for c in ct]
if orig_num and len(numct) == 1 and len(denumct) == 0 and ct and\
N.all([c.data for c in ct] == self.get_constant(orig_num[0])):
# this is an important trick :( if it so happens that:
# * there's exactly one constant on the numerator and none on
# the denominator
# * it's not the neutral element (ct is an empty list in that case)
# * the constant is the same as the first argument in the numerator
# Then we return very exactly the original num/denum
# If we don't do that the optimizer will just loop
# infinitely because it will not catch on that there are
# no changes to be made and everytime it will want to
# replace something by the same thing...
return orig_num, orig_denum
if orig_num and len(numct) == 1 and len(denumct) == 0 and ct:
# In that case we should only have one constant in `ct`.
assert len(ct) == 1
first_num_ct = self.get_constant(orig_num[0])
if first_num_ct is not None and ct[0].type.values_eq(ct[0].data,
first_num_ct):
# This is an important trick :( if it so happens that:
# * there's exactly one constant on the numerator and none on
# the denominator
# * it's not the neutral element (ct is an empty list in that
# case)
# * the constant is the same as the first argument in the
# numerator (we only check the first argument because the
# canonizer puts the computed constants first)
# -> then we return very exactly the original num/denum.
# If we don't do that the optimizer will just loop
# infinitely because it will not catch on that there are
# no changes to be made and everytime it will want to
# replace something by the same thing...
# Note that it is important to use `values_eq` instead of
# the == operator, to handle NaN values correctly.
return orig_num, orig_denum
return ct + num, denum
def transform(self, node):
......
## PENDING REWRITE OF tensor_opt.py
import copy
import logging
import StringIO
import time
import unittest
......@@ -77,10 +79,10 @@ def optimize(g, level='fast_run'):
return g
def inputs(xbc = (0, 0), ybc = (0, 0), zbc = (0, 0)):
x = TensorType(broadcastable = xbc, dtype = 'float64')('x')
y = TensorType(broadcastable = ybc, dtype = 'float64')('y')
z = TensorType(broadcastable = zbc, dtype = 'float64')('z')
def inputs(xbc=(0, 0), ybc=(0, 0), zbc=(0, 0)):
    """Build the three test inputs named 'x', 'y' and 'z'.

    Each one is obtained by calling a float64 ``TensorType`` (created with
    the matching broadcastable pattern) on its name.

    :param xbc: broadcastable pattern used for 'x'.
    :param ybc: broadcastable pattern used for 'y'.
    :param zbc: broadcastable pattern used for 'z'.
    :return: the tuple ``(x, y, z)``.
    """
    # NOTE(review): the returned objects are presumably symbolic tensor
    # variables (callers add them and feed them to Env) -- confirm against
    # TensorType.__call__.
    patterns = (('x', xbc), ('y', ybc), ('z', zbc))
    return tuple(TensorType(broadcastable=pattern, dtype='float64')(name)
                 for name, pattern in patterns)
......@@ -97,7 +99,9 @@ class test_dimshuffle_lift(unittest.TestCase):
x, y, z = inputs()
e = ds(ds(x, (1, 'x', 0)), (2, 0, 'x', 1))
g = Env([x], [e])
self.assertTrue(str(g) == "[DimShuffle{2,0,x,1}(DimShuffle{1,x,0}(x))]", str(g))
self.assertTrue(
str(g) == "[DimShuffle{2,0,x,1}(DimShuffle{1,x,0}(x))]",
str(g))
dimshuffle_lift.optimize(g)
self.assertTrue(str(g) == "[DimShuffle{0,1,x,x}(x)]", str(g))
......@@ -105,12 +109,15 @@ class test_dimshuffle_lift(unittest.TestCase):
x, y, z = inputs()
e = ds(ds(ds(x, (0, 'x', 1)), (2, 0, 'x', 1)), (1, 0))
g = Env([x], [e])
self.assertTrue(str(g) == "[DimShuffle{1,0}(DimShuffle{2,0,x,1}(DimShuffle{0,x,1}(x)))]", str(g))
self.assertTrue(
str(g) == "[DimShuffle{1,0}(DimShuffle{2,0,x,1}"
"(DimShuffle{0,x,1}(x)))]",
str(g))
dimshuffle_lift.optimize(g)
self.assertTrue(str(g) == "[x]", str(g))
def test_lift(self):
x, y, z = inputs([False]*1, [False]*2, [False]*3)
x, y, z = inputs([False] * 1, [False] * 2, [False] * 3)
e = x + y + z
g = Env([x, y, z], [e])
self.assertTrue(str(g) == ("[Elemwise{add,no_inplace}("
......@@ -672,25 +679,48 @@ class test_canonize(unittest.TestCase):
"""
raise SkipTest("Not implemented")
def test_canonicalize_nan(self):
    """
    Regression test for a bug in the canonicalization of NaN values.

    The bug made the optimization loop forever; the equilibrium optimizer
    caught that and reported it through an error log message, so this test
    checks that nothing is logged at ERROR level.
    """
    opt_logger = logging.getLogger('theano.gof.opt')
    captured = StringIO.StringIO()
    error_handler = logging.StreamHandler(captured)
    # Only records at ERROR level or above end up in `captured`.
    error_handler.setLevel(logging.ERROR)
    opt_logger.addHandler(error_handler)
    try:
        x = vector()
        # Only building the function matters here; it is never called.
        f = theano.function([x], x + numpy.nan)
    finally:
        # Always detach the handler so later tests are not affected.
        opt_logger.removeHandler(error_handler)
    # Ideally only the "maxed out equilibrium optimizer" message would be
    # looked for, but its wording may change in the future, so to be safe
    # we require that no error was logged at all.
    assert not captured.getvalue()
def test_local_merge_abs():
x,y,z = T.matrices('xyz')
x_val = numpy.random.rand(5,5).astype(config.floatX)
y_val = numpy.random.rand(5,5).astype(config.floatX)
z_val = numpy.random.rand(5,5).astype(config.floatX)
x, y, z = T.matrices('xyz')
x_val = numpy.random.rand(5, 5).astype(config.floatX)
y_val = numpy.random.rand(5, 5).astype(config.floatX)
z_val = numpy.random.rand(5, 5).astype(config.floatX)
mode = theano.config.mode
if mode == "FAST_COMPILE":
mode = "FAST_RUN"
mode = theano.compile.mode.get_mode(mode).excluding("local_elemwise_fusion")
mode = theano.compile.mode.get_mode(mode).excluding(
"local_elemwise_fusion")
f = theano.function([x,y,z],(abs(y*z*-2)), mode=mode)
f(x_val,y_val,z_val)
f = theano.function([x, y, z], (abs(y * z * -2)), mode=mode)
f(x_val, y_val, z_val)
theano.printing.debugprint(f)
assert isinstance(f.maker.env.toposort()[1].op.scalar_op, scal.Abs)
assert len(f.maker.env.toposort())==2
assert len(f.maker.env.toposort()) == 2
f = theano.function([x,y,z],abs(x/y), mode=mode)
f(x_val,y_val,z_val)
f = theano.function([x, y, z],abs(x / y), mode=mode)
f(x_val, y_val, z_val)
theano.printing.debugprint(f)
assert isinstance(f.maker.env.toposort()[1].op.scalar_op, scal.Abs)
assert len(f.maker.env.toposort())==2
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
请 注册 或者 登录 后发表评论