提交 1bde7f38 authored 作者: Olivier Delalleau's avatar Olivier Delalleau

Merge pull request #500 from lamblin/better_exceptions

Better exceptions
...@@ -21,6 +21,12 @@ import logging ...@@ -21,6 +21,12 @@ import logging
_logger = logging.getLogger('theano.compile.function_module') _logger = logging.getLogger('theano.compile.function_module')
class UnusedInputError(Exception):
"""
A symbolic input passed to function is not needed
"""
pass
def alias_root(v): def alias_root(v):
"""Return the variable to which v is aliased by view_maps and destroy_maps""" """Return the variable to which v is aliased by view_maps and destroy_maps"""
if v.owner is None: return v if v.owner is None: return v
...@@ -1110,7 +1116,7 @@ class FunctionMaker(object): ...@@ -1110,7 +1116,7 @@ class FunctionMaker(object):
if on_unused_input == 'warn': if on_unused_input == 'warn':
warnings.warn(msg % (i.variable, warn_msg), stacklevel=5) warnings.warn(msg % (i.variable, warn_msg), stacklevel=5)
elif on_unused_input == 'raise': elif on_unused_input == 'raise':
raise ValueError(msg % (i.variable, err_msg)) raise UnusedInputError(msg % (i.variable, err_msg))
else: else:
raise ValueError(("Invalid value for keyword " raise ValueError(("Invalid value for keyword "
"on_unused_input of theano.function: '%s'. " "on_unused_input of theano.function: '%s'. "
......
...@@ -7,6 +7,7 @@ from profiling import ProfileStats ...@@ -7,6 +7,7 @@ from profiling import ProfileStats
from theano import config from theano import config
from theano.compile import orig_function, In, Out from theano.compile import orig_function, In, Out
from theano.compile import UnusedInputError
from theano.compile.sharedvalue import SharedVariable, shared from theano.compile.sharedvalue import SharedVariable, shared
from theano.gof import Container, Variable, generic, graph, Constant, Value from theano.gof import Container, Variable, generic, graph, Constant, Value
from theano.gof.python25 import any from theano.gof.python25 import any
...@@ -432,7 +433,7 @@ def pfunc(params, outputs=None, mode=None, updates=[], givens=[], ...@@ -432,7 +433,7 @@ def pfunc(params, outputs=None, mode=None, updates=[], givens=[],
for i, v in enumerate(in_variables): for i, v in enumerate(in_variables):
if v in in_variables[(i + 1):]: if v in in_variables[(i + 1):]:
dup_v_i = in_variables.index(v, (i + 1)) dup_v_i = in_variables.index(v, (i + 1))
raise ValueError( raise UnusedInputError(
("Variable %s is used twice in inputs to theano.function, " ("Variable %s is used twice in inputs to theano.function, "
"at indices %i and %i. This would result in values " "at indices %i and %i. This would result in values "
"provided for it being ignored. Please do not duplicate " "provided for it being ignored. Please do not duplicate "
......
...@@ -4,10 +4,11 @@ import numpy ...@@ -4,10 +4,11 @@ import numpy
import unittest import unittest
from theano import gof,config from theano import config, gof
from theano.scalar import mul
from theano.compile.io import In, Out from theano.compile.io import In, Out
from theano.compile import function from theano.compile import function
from theano.compile import UnusedInputError
from theano.gof import MissingInputError
from theano.gof.python25 import all, any from theano.gof.python25 import all, any
from theano import tensor from theano import tensor
...@@ -53,56 +54,53 @@ class T_function(unittest.TestCase): ...@@ -53,56 +54,53 @@ class T_function(unittest.TestCase):
def test_missing_inputs(self): def test_missing_inputs(self):
MissingInputException = TypeError
UnusedInputException = ValueError
def fn(): def fn():
x,s = T.scalars('xs') x,s = T.scalars('xs')
fn = function([], [x]) fn = function([], [x])
checkfor(self, fn, MissingInputException) checkfor(self, fn, MissingInputError)
def fn(): def fn():
x,s = T.scalars('xs') x,s = T.scalars('xs')
# Ignore unused input s, as it hides the other error # Ignore unused input s, as it hides the other error
fn = function([s], [x], on_unused_input='ignore') fn = function([s], [x], on_unused_input='ignore')
checkfor(self, fn, MissingInputException) checkfor(self, fn, MissingInputError)
def fn(): def fn():
x,s = T.scalars('xs') x,s = T.scalars('xs')
fn = function([s], [x]) fn = function([s], [x])
checkfor(self, fn, UnusedInputException) checkfor(self, fn, UnusedInputError)
def fn(): def fn():
x,s = T.scalars('xs') x,s = T.scalars('xs')
# Ignore unused input s, as it hides the other error # Ignore unused input s, as it hides the other error
fn = function([s], x, on_unused_input='ignore') fn = function([s], x, on_unused_input='ignore')
checkfor(self, fn, MissingInputException) checkfor(self, fn, MissingInputError)
def fn(): def fn():
x,s = T.scalars('xs') x,s = T.scalars('xs')
fn = function([s], x) fn = function([s], x)
checkfor(self, fn, UnusedInputException) checkfor(self, fn, UnusedInputError)
def fn(): def fn():
x,s = T.scalars('xs') x,s = T.scalars('xs')
# Ignore unused input s, as it hides the other error # Ignore unused input s, as it hides the other error
fn = function([s], Out(x), on_unused_input='ignore') fn = function([s], Out(x), on_unused_input='ignore')
checkfor(self, fn, MissingInputException) checkfor(self, fn, MissingInputError)
def fn(): def fn():
x,s = T.scalars('xs') x,s = T.scalars('xs')
fn = function([s], Out(x)) fn = function([s], Out(x))
checkfor(self, fn, UnusedInputException) checkfor(self, fn, UnusedInputError)
def fn(): def fn():
x,s = T.scalars('xs') x,s = T.scalars('xs')
fn = function([In(x, update=s+x)], x) fn = function([In(x, update=s+x)], x)
checkfor(self, fn, MissingInputException) checkfor(self, fn, MissingInputError)
def fn(): def fn():
x,s = T.scalars('xs') x,s = T.scalars('xs')
fn = function([In(x, update=mul(s,s)+x)], x) fn = function([In(x, update=((s * s) + x))], x)
checkfor(self, fn, MissingInputException) checkfor(self, fn, MissingInputError)
def test_input_anon_singleton(self): def test_input_anon_singleton(self):
x,s = T.scalars('xs') x,s = T.scalars('xs')
...@@ -378,14 +376,14 @@ class T_function(unittest.TestCase): ...@@ -378,14 +376,14 @@ class T_function(unittest.TestCase):
def test_disconnected_input(self): def test_disconnected_input(self):
a = T.scalar('a') a = T.scalar('a')
v = T.vector('v') v = T.vector('v')
self.assertRaises(ValueError, function, [a, v], v*2) self.assertRaises(UnusedInputError, function, [a, v], v*2)
f = function([a, v], v*2, on_unused_input='ignore') f = function([a, v], v*2, on_unused_input='ignore')
def test_masked_input(self): def test_masked_input(self):
m = T.matrix('m') m = T.matrix('m')
mt = m.T mt = m.T
mt.name = 'm.T' mt.name = 'm.T'
self.assertRaises(ValueError, function, [m, mt], mt*2) self.assertRaises(UnusedInputError, function, [m, mt], mt*2)
f = function([m, mt], mt*2, on_unused_input='ignore') f = function([m, mt], mt*2, on_unused_input='ignore')
......
...@@ -6,7 +6,7 @@ from cc import \ ...@@ -6,7 +6,7 @@ from cc import \
import compiledir # adds config vars import compiledir # adds config vars
from env import \ from env import \
InconsistencyError, Env InconsistencyError, MissingInputError, Env
from destroyhandler import \ from destroyhandler import \
DestroyHandler DestroyHandler
......
...@@ -15,6 +15,12 @@ class InconsistencyError(Exception): ...@@ -15,6 +15,12 @@ class InconsistencyError(Exception):
""" """
pass pass
class MissingInputError(Exception):
"""
A symbolic input needed to compute the outputs is missing.
"""
pass
class Env(utils.object2): class Env(utils.object2):
...@@ -213,7 +219,7 @@ class Env(utils.object2): ...@@ -213,7 +219,7 @@ class Env(utils.object2):
self.__import__(node) self.__import__(node)
for r in variables: for r in variables:
if r.owner is None and not isinstance(r, graph.Value) and r not in self.inputs: if r.owner is None and not isinstance(r, graph.Value) and r not in self.inputs:
raise TypeError("Undeclared input", r) raise MissingInputError("Undeclared input", r)
if not getattr(r, 'env', None) is self: if not getattr(r, 'env', None) is self:
self.__setup_r__(r) self.__setup_r__(r)
self.variables.add(r) self.variables.add(r)
...@@ -285,12 +291,19 @@ class Env(utils.object2): ...@@ -285,12 +291,19 @@ class Env(utils.object2):
#handler code in the first place #handler code in the first place
assert path is not None assert path is not None
raise TypeError('A variable that is an input to the graph was neither provided as an ' raise MissingInputError((
'input to the function nor given a value. A chain of variables leading from ' 'A variable that is an input to the graph was '
'this input to an output is '+str(path)+'. This chain may not be unique') 'neither provided as an input to the function '
'nor given a value. A chain of variables '
'leading from this input to an output is %s. '
'This chain may not be unique' % str(path)))
#Standard error message #Standard error message
raise TypeError("An input of the graph, used to compute "+str(node)+", was not provided and not given a value", r) raise MissingInputError((
"An input of the graph, used to compute %s, "
"was not provided and not given a value"
% str(node)),
r)
for node in new_nodes: for node in new_nodes:
assert node not in self.nodes assert node not in self.nodes
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论