提交 f08a0528 authored 作者: Frederic Bastien's avatar Frederic Bastien

Add test that we can reraise and extend the error msg of BadOptimization

上级 92b60da5
from __future__ import absolute_import, print_function, division from __future__ import absolute_import, print_function, division
from nose.plugins.skip import SkipTest import sys
import unittest import unittest
from nose.plugins.skip import SkipTest
from nose.tools import assert_raises
import numpy as np import numpy as np
from six import reraise
from theano import config from theano import config
from theano import gof from theano import gof
from theano.gof.link import raise_with_op
import theano import theano
import theano.tensor
from theano.compat import exc_message from theano.compat import exc_message
from theano.compile import debugmode from theano.compile import debugmode
import theano.compile import theano.tensor
from theano.tests import unittest_tools as utt from theano.tests import unittest_tools as utt
...@@ -255,22 +258,53 @@ def test_badoptimization_opt_err(): ...@@ -255,22 +258,53 @@ def test_badoptimization_opt_err():
inputs[-1])) inputs[-1]))
return [node.op(*inputs)] return [node.op(*inputs)]
return False return False
@gof.local_optimizer([theano.tensor.add])
def insert_bad_dtype(node):
if node.op == theano.tensor.add:
inputs = list(node.inputs)
if inputs[-1].owner is None:
return [node.outputs[0].astype('float32')]
return False
edb = gof.EquilibriumDB() edb = gof.EquilibriumDB()
edb.register('insert_bigger_b_add', insert_bigger_b_add, 'all') edb.register('insert_bigger_b_add', insert_bigger_b_add, 'all')
opt = edb.query('+all') opt = edb.query('+all')
edb2 = gof.EquilibriumDB()
edb2.register('insert_bad_dtype', insert_bad_dtype, 'all')
opt2 = edb2.query('+all')
a = theano.tensor.dvector() a = theano.tensor.dvector()
b = theano.tensor.dvector() b = theano.tensor.dvector()
f = theano.function([a, b], a + b, f = theano.function([a, b], a + b,
mode=debugmode.DebugMode(optimizer=opt)) mode=debugmode.DebugMode(optimizer=opt))
try: try:
f([1.0, 2.0, 3.0], [2, 3, 4],) f([1.0, 2.0, 3.0], [2, 3, 4],)
except Exception as e: except ValueError as e:
assert 'insert_bigger_b_add' in exc_message(e) assert 'insert_bigger_b_add' in exc_message(e)
return # TEST PASS else:
assert False
# Test that opt that do an illegal change still get the error from gof.
try:
with theano.configparser.change_flags(on_opt_error='raise'):
f2 = theano.function([a, b], a + b,
mode=debugmode.DebugMode(optimizer=opt2,
stability_patience=1))
f2([1.0, 2.0, 3.0], [2, 3, 4],)
except theano.gof.toolbox.BadOptimization as e:
assert 'insert_bad_dtype' in str(e)
# Test that we can reraise the error with an extended message
try:
new_e = e.__class__("TTT"+str(e))
exc_type, exc_value, exc_trace = sys.exc_info()
exc_value = new_e
reraise(e.__class__, exc_value, exc_trace)
except theano.gof.toolbox.BadOptimization as e:
pass
else:
assert False
else:
assert False assert False
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论