Commit b5d176d0 authored by Ian Goodfellow

improved error messages for invalid values

updated documentation related to verify_grad and TensorType
Parent 0801a94e
...@@ -45,6 +45,7 @@ you should check the strides and alignment. ...@@ -45,6 +45,7 @@ you should check the strides and alignment.
return theano.Apply(self, return theano.Apply(self,
inputs=[x_], inputs=[x_],
outputs=[x_.type()]) outputs=[x_.type()])
# using x_.type() is dangerous, it copies x's broadcasting behaviour
def perform(self, node, inputs, output_storage): def perform(self, node, inputs, output_storage):
x, = inputs x, = inputs
......
...@@ -401,7 +401,8 @@ Here is the prototype for the verify_grad function. ...@@ -401,7 +401,8 @@ Here is the prototype for the verify_grad function.
>>> def verify_grad(fun, pt, n_tests=2, rng=None, eps=1.0e-7, abs_tol=0.0001, rel_tol=0.0001): >>> def verify_grad(fun, pt, n_tests=2, rng=None, eps=1.0e-7, abs_tol=0.0001, rel_tol=0.0001):
``verify_grad`` raises an Exception if the difference between the analytic gradient and ``verify_grad`` raises an Exception if the difference between the analytic gradient and
numerical gradient (computed through the Finite Difference Method) exceeds numerical gradient (computed through the Finite Difference Method) of a random
projection of the fun's output to a scalar exceeds
both the given absolute and relative tolerances. both the given absolute and relative tolerances.
The parameters are as follows: The parameters are as follows:
...@@ -417,7 +418,7 @@ The parameters are as follows: ...@@ -417,7 +418,7 @@ The parameters are as follows:
* ``n_tests``: number of times to run the test * ``n_tests``: number of times to run the test
* ``rng``: random number generator used to generate a random vector u, * ``rng``: random number generator used to generate a random vector u,
we check the gradient of dot(u,fn) at pt we check the gradient of sum(u*fn) at pt
* ``eps``: stepsize used in the Finite Difference Method * ``eps``: stepsize used in the Finite Difference Method
......
...@@ -330,13 +330,14 @@ class StochasticOrder(DebugModeError): ...@@ -330,13 +330,14 @@ class StochasticOrder(DebugModeError):
class InvalidValueError(DebugModeError): class InvalidValueError(DebugModeError):
"""Exception: some Op produced an output value that is inconsistent with the Type of that output""" """Exception: some Op produced an output value that is inconsistent with the Type of that output"""
def __init__(self, r, v, client_node=None, hint='none'): def __init__(self, r, v, client_node=None, hint='none', specific_hint='none'):
#super(InvalidValueError, self).__init__() #super(InvalidValueError, self).__init__()
DebugModeError.__init__(self)#to be compatible with python2.4 DebugModeError.__init__(self)#to be compatible with python2.4
self.r = r self.r = r
self.v = v self.v = v
self.client_node = client_node self.client_node = client_node
self.hint=hint self.hint=hint
self.specific_hint=specific_hint
def __str__(self): def __str__(self):
r, v = self.r, self.v r, v = self.r, self.v
...@@ -358,6 +359,7 @@ class InvalidValueError(DebugModeError): ...@@ -358,6 +359,7 @@ class InvalidValueError(DebugModeError):
pass pass
client_node = self.client_node client_node = self.client_node
hint = self.hint hint = self.hint
specific_hint = self.specific_hint
return """InvalidValueError return """InvalidValueError
type(variable) = %(type_r)s type(variable) = %(type_r)s
variable = %(r)s variable = %(r)s
...@@ -370,6 +372,7 @@ class InvalidValueError(DebugModeError): ...@@ -370,6 +372,7 @@ class InvalidValueError(DebugModeError):
isfinite = %(v_isfinite)s isfinite = %(v_isfinite)s
client_node = %(client_node)s client_node = %(client_node)s
hint = %(hint)s hint = %(hint)s
specific_hint = %(specific_hint)s
""" % locals() """ % locals()
######################## ########################
...@@ -1070,7 +1073,7 @@ class _Linker(gof.link.LocalLinker): ...@@ -1070,7 +1073,7 @@ class _Linker(gof.link.LocalLinker):
# check output values for type-correctness # check output values for type-correctness
for r in node.outputs: for r in node.outputs:
if not r.type.is_valid_value(storage_map[r][0]): if not r.type.is_valid_value(storage_map[r][0]):
raise InvalidValueError(r, storage_map[r][0], hint='perform output') raise InvalidValueError(r, storage_map[r][0], hint='perform output', specific_hint = r.type.value_validity_msg(storage_map[r][0]))
#if r in r_vals: #if r in r_vals:
_check_inputs(node, storage_map, r_vals, dr_vals, active_order_set, _check_inputs(node, storage_map, r_vals, dr_vals, active_order_set,
......
...@@ -227,6 +227,10 @@ class PureType(object): ...@@ -227,6 +227,10 @@ class PureType(object):
except (TypeError, ValueError): except (TypeError, ValueError):
return False return False
def value_validity_msg(self, a):
    """Optional: explain the result of ``is_valid_value(a)`` for this Type.

    Subclasses may override this to produce a more informative
    diagnostic string; this base implementation has nothing specific
    to report and returns the placeholder string "none".
    """
    return "none"
def make_variable(self, name = None): def make_variable(self, name = None):
"""Return a new `Variable` instance of Type `self`. """Return a new `Variable` instance of Type `self`.
......
...@@ -19,7 +19,6 @@ from theano import gradient ...@@ -19,7 +19,6 @@ from theano import gradient
import elemwise import elemwise
from theano import scalar as scal from theano import scalar as scal
from theano.gof.python25 import partial, any, all from theano.gof.python25 import partial, any, all
from theano import compile, printing from theano import compile, printing
from theano.printing import pprint, Print from theano.printing import pprint, Print
...@@ -430,6 +429,14 @@ class TensorType(Type): ...@@ -430,6 +429,14 @@ class TensorType(Type):
raise ValueError("non-finite elements not allowed") raise ValueError("non-finite elements not allowed")
return data return data
def value_validity_msg(self, a):
    """Return a human-readable explanation of ``is_valid_value(a)``.

    Runs ``self.filter(a, True)`` (strict mode); if filtering raises,
    the exception text is returned as the diagnostic, otherwise the
    value is reported as valid.
    """
    try:
        self.filter(a, True)
    # Broad catch is deliberate: any filter failure becomes the message.
    # NOTE: ``except ... as e`` requires Python >= 2.6 (the old
    # ``except Exception, e`` form is a SyntaxError on Python 3).
    except Exception as e:
        return str(e)
    return "value is valid"
def dtype_specs(self): def dtype_specs(self):
"""Return a tuple (python type, c type, numpy typenum) that corresponds to """Return a tuple (python type, c type, numpy typenum) that corresponds to
self.dtype. self.dtype.
...@@ -4046,7 +4053,8 @@ def verify_grad(fun, pt, n_tests=2, rng=None, eps=None, abs_tol=None, rel_tol=No ...@@ -4046,7 +4053,8 @@ def verify_grad(fun, pt, n_tests=2, rng=None, eps=None, abs_tol=None, rel_tol=No
rng=numpy.random) rng=numpy.random)
Raises an Exception if the difference between the analytic gradient and Raises an Exception if the difference between the analytic gradient and
numerical gradient (computed through the Finite Difference Method) exceeds numerical gradient (computed through the Finite Difference Method) of a random
projection of the fun's output to a scalar exceeds
the given tolerance. the given tolerance.
:param fun: a Python function that takes Theano variables as inputs, :param fun: a Python function that takes Theano variables as inputs,
...@@ -4055,7 +4063,7 @@ def verify_grad(fun, pt, n_tests=2, rng=None, eps=None, abs_tol=None, rel_tol=No ...@@ -4055,7 +4063,7 @@ def verify_grad(fun, pt, n_tests=2, rng=None, eps=None, abs_tol=None, rel_tol=No
:param pt: the list of numpy.ndarrays to use as input values. :param pt: the list of numpy.ndarrays to use as input values.
These arrays must be either float32 or float64 arrays. These arrays must be either float32 or float64 arrays.
:param n_tests: number of times to run the test :param n_tests: number of times to run the test
:param rng: random number generator used to sample u, we test gradient of dot(u,fun) at pt :param rng: random number generator used to sample u, we test gradient of sum(u * fun) at pt
:param eps: stepsize used in the Finite Difference Method (Default None is type-dependent) :param eps: stepsize used in the Finite Difference Method (Default None is type-dependent)
:param abs_tol: absolute tolerance used as threshold for gradient comparison :param abs_tol: absolute tolerance used as threshold for gradient comparison
:param rel_tol: relative tolerance used as threshold for gradient comparison :param rel_tol: relative tolerance used as threshold for gradient comparison
......
Markdown is supported
0%
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Register or sign in to comment