提交 074c39ef authored 作者: Ian Goodfellow's avatar Ian Goodfellow

NaNType -> NullType following code review. Made NullType raise

ValueErrors
上级 aa0024bb
......@@ -11,8 +11,7 @@ import toolbox
from python25 import all
from theano import config
import warnings
NaNType = None
NullType = None
class InconsistencyError(Exception):
"""
......@@ -212,9 +211,9 @@ class FunctionGraph(utils.object2):
### import ###
def __import_r__(self, variables):
global NaNType
if NaNType is None:
from nan_type import NaNType
global NullType
if NullType is None:
from null_type import NullType
# Imports the owners of the variables
r_owner_done = set(self.nodes)
for node in [r.owner for r in variables if r.owner is not None]:
......@@ -223,8 +222,8 @@ class FunctionGraph(utils.object2):
self.__import__(node)
for r in variables:
if r.owner is None and not isinstance(r, graph.Constant) and r not in self.inputs:
if isinstance(r.type,NaNType):
raise TypeError("Computation graph contains a NaN. "+r.type.why_nan)
if isinstance(r.type,NullType):
raise TypeError("Computation graph contains a NaN. "+r.type.why_null)
raise MissingInputError("Undeclared input", r)
if not getattr(r, 'fgraph', None) is self:
self.__setup_r__(r)
......
from type import Type
from graph import Variable
class NaNType(Type):
def __init__(self, why_nan = '(no explanation given)'):
class NullType(Type):
def __init__(self, why_null = '(no explanation given)'):
"""
why_nan: A string explaining why this variable is NaN
why_null: A string explaining why this variable
can't take on any values
"""
self.why_nan = why_nan
self.why_null = why_null
def filter(self, data, strict=False, allow_downcast=None):
raise
raise ValueError("No values may be assigned to a NullType")
def filter_variable(self, other):
raise
raise ValueError("No values may be assigned to a NullType")
def may_share_memory(a, b):
return False
def values_eq(a, b, force_same_dtype=True):
raise
raise ValueError("NullType has no values to compare")
class NaNVariable(Variable):
class NullVariable(Variable):
pass
......@@ -20,7 +20,7 @@ from theano import gof
from theano.gof import Variable
from theano.gof.python25 import all
import theano.gof.utils
from theano.gof.nan_type import NaNType
from theano.gof.null_type import NullType
from theano.printing import min_informative_str
tensor = None
......@@ -68,7 +68,7 @@ def grad_not_implemented(op, x_pos, x, comment = ""):
gradient is not implemented.
"""
return NaNType("This variable is NaN because the grad method for " + \
return NullType("This variable is NaN because the grad method for " + \
"input "+str(x_pos)+" ("+str(x)+") of the "+str(op)+" op is" + \
" not implemented."+comment)()
......@@ -86,7 +86,7 @@ def grad_undefined(op, x_pos, x, comment = ""):
gradient is not defined.
"""
return NaNType("This variable is NaN because the gradient for " + \
return NullType("This variable is NaN because the gradient for " + \
"input "+str(x_pos)+" ("+str(x)+") of the "+str(op)+" op is" + \
" mathematically undefined."+comment)()
......@@ -375,9 +375,9 @@ def grad(cost, wrt, g_cost = None, consider_constant = None, warn_type = False,
raise TypeError("cost must be a scalar.")
if isinstance(cost.type, NaNType):
if isinstance(cost.type, NullType):
raise ValueError("Can't differentiate a NaN cost. cost is NaN because "+\
cost.type.why_nan)
cost.type.why_null)
if consider_constant is None:
consider_constant = []
......@@ -609,9 +609,9 @@ def _populate_grad_dict(var_to_node_to_idx,\
" Variable instance." % (str(node.op),
type(term)))
if isinstance(term.type,NaNType):
if isinstance(term.type,NullType):
raise TypeError("tensor.grad encountered a NaN. "+\
term.type.why_nan)
term.type.why_null)
terms.append( term)
grad_dict[var] = nonempty_sum(terms)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论