提交 f08a0528 authored 作者: Frederic Bastien's avatar Frederic Bastien

Add test that we can reraise and extend the error msg of BadOptimization

上级 92b60da5
from __future__ import absolute_import, print_function, division
from nose.plugins.skip import SkipTest
import sys
import unittest
from nose.plugins.skip import SkipTest
from nose.tools import assert_raises
import numpy as np
from six import reraise
from theano import config
from theano import gof
from theano.gof.link import raise_with_op
import theano
import theano.tensor
from theano.compat import exc_message
from theano.compile import debugmode
import theano.compile
import theano.tensor
from theano.tests import unittest_tools as utt
......@@ -255,23 +258,54 @@ def test_badoptimization_opt_err():
inputs[-1]))
return [node.op(*inputs)]
return False
@gof.local_optimizer([theano.tensor.add])
def insert_bad_dtype(node):
if node.op == theano.tensor.add:
inputs = list(node.inputs)
if inputs[-1].owner is None:
return [node.outputs[0].astype('float32')]
return False
edb = gof.EquilibriumDB()
edb.register('insert_bigger_b_add', insert_bigger_b_add, 'all')
opt = edb.query('+all')
edb2 = gof.EquilibriumDB()
edb2.register('insert_bad_dtype', insert_bad_dtype, 'all')
opt2 = edb2.query('+all')
a = theano.tensor.dvector()
b = theano.tensor.dvector()
f = theano.function([a, b], a + b,
mode=debugmode.DebugMode(optimizer=opt))
try:
f([1.0, 2.0, 3.0], [2, 3, 4],)
except Exception as e:
except ValueError as e:
assert 'insert_bigger_b_add' in exc_message(e)
return # TEST PASS
else:
assert False
assert False
# Test that opt that do an illegal change still get the error from gof.
try:
with theano.configparser.change_flags(on_opt_error='raise'):
f2 = theano.function([a, b], a + b,
mode=debugmode.DebugMode(optimizer=opt2,
stability_patience=1))
f2([1.0, 2.0, 3.0], [2, 3, 4],)
except theano.gof.toolbox.BadOptimization as e:
assert 'insert_bad_dtype' in str(e)
# Test that we can reraise the error with an extended message
try:
new_e = e.__class__("TTT"+str(e))
exc_type, exc_value, exc_trace = sys.exc_info()
exc_value = new_e
reraise(e.__class__, exc_value, exc_trace)
except theano.gof.toolbox.BadOptimization as e:
pass
else:
assert False
else:
assert False
def test_stochasticoptimization():
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论