Commit 074c39ef authored by Ian Goodfellow

NaNType -> NullType following code review. Made NullType raise ValueErrors.
Parent aa0024bb
...@@ -11,8 +11,7 @@ import toolbox ...@@ -11,8 +11,7 @@ import toolbox
from python25 import all from python25 import all
from theano import config from theano import config
import warnings import warnings
NaNType = None NullType = None
class InconsistencyError(Exception): class InconsistencyError(Exception):
""" """
...@@ -212,9 +211,9 @@ class FunctionGraph(utils.object2): ...@@ -212,9 +211,9 @@ class FunctionGraph(utils.object2):
### import ### ### import ###
def __import_r__(self, variables): def __import_r__(self, variables):
global NaNType global NullType
if NaNType is None: if NullType is None:
from nan_type import NaNType from null_type import NullType
# Imports the owners of the variables # Imports the owners of the variables
r_owner_done = set(self.nodes) r_owner_done = set(self.nodes)
for node in [r.owner for r in variables if r.owner is not None]: for node in [r.owner for r in variables if r.owner is not None]:
...@@ -223,8 +222,8 @@ class FunctionGraph(utils.object2): ...@@ -223,8 +222,8 @@ class FunctionGraph(utils.object2):
self.__import__(node) self.__import__(node)
for r in variables: for r in variables:
if r.owner is None and not isinstance(r, graph.Constant) and r not in self.inputs: if r.owner is None and not isinstance(r, graph.Constant) and r not in self.inputs:
if isinstance(r.type,NaNType): if isinstance(r.type,NullType):
raise TypeError("Computation graph contains a NaN. "+r.type.why_nan) raise TypeError("Computation graph contains a NaN. "+r.type.why_null)
raise MissingInputError("Undeclared input", r) raise MissingInputError("Undeclared input", r)
if not getattr(r, 'fgraph', None) is self: if not getattr(r, 'fgraph', None) is self:
self.__setup_r__(r) self.__setup_r__(r)
......
from type import Type from type import Type
from graph import Variable from graph import Variable
class NullType(Type):
    """
    A Type for variables that cannot take on any value.

    Used to mark results that are undefined or not implemented (e.g. the
    return of grad_not_implemented / grad_undefined elsewhere in this
    commit); any attempt to assign or compare values of this type raises
    ValueError.
    """

    def __init__(self, why_null='(no explanation given)'):
        """
        why_null: A string explaining why this variable
        can't take on any values.
        """
        self.why_null = why_null

    def filter(self, data, strict=False, allow_downcast=None):
        # No data can ever satisfy this type.
        raise ValueError("No values may be assigned to a NullType")

    def filter_variable(self, other):
        # No variable can be converted to this type.
        raise ValueError("No values may be assigned to a NullType")

    # Declared static: the original `def may_share_memory(a, b)` took no
    # `self`, so calling it through an instance bound the instance to `a`
    # (wrong argument count / wrong semantics). @staticmethod makes both
    # instance-level and class-level calls behave as intended.
    @staticmethod
    def may_share_memory(a, b):
        return False

    @staticmethod
    def values_eq(a, b, force_same_dtype=True):
        # There are no values of this type, so comparison is meaningless.
        raise ValueError("NullType has no values to compare")
class NullVariable(Variable):
    """Variable subclass paired with NullType; carries no usable value."""
    pass
...@@ -20,7 +20,7 @@ from theano import gof ...@@ -20,7 +20,7 @@ from theano import gof
from theano.gof import Variable from theano.gof import Variable
from theano.gof.python25 import all from theano.gof.python25 import all
import theano.gof.utils import theano.gof.utils
from theano.gof.nan_type import NaNType from theano.gof.null_type import NullType
from theano.printing import min_informative_str from theano.printing import min_informative_str
tensor = None tensor = None
...@@ -68,7 +68,7 @@ def grad_not_implemented(op, x_pos, x, comment = ""): ...@@ -68,7 +68,7 @@ def grad_not_implemented(op, x_pos, x, comment = ""):
gradient is not implemented. gradient is not implemented.
""" """
return NaNType("This variable is NaN because the grad method for " + \ return NullType("This variable is NaN because the grad method for " + \
"input "+str(x_pos)+" ("+str(x)+") of the "+str(op)+" op is" + \ "input "+str(x_pos)+" ("+str(x)+") of the "+str(op)+" op is" + \
" not implemented."+comment)() " not implemented."+comment)()
...@@ -86,7 +86,7 @@ def grad_undefined(op, x_pos, x, comment = ""): ...@@ -86,7 +86,7 @@ def grad_undefined(op, x_pos, x, comment = ""):
gradient is not defined. gradient is not defined.
""" """
return NaNType("This variable is NaN because the gradient for " + \ return NullType("This variable is NaN because the gradient for " + \
"input "+str(x_pos)+" ("+str(x)+") of the "+str(op)+" op is" + \ "input "+str(x_pos)+" ("+str(x)+") of the "+str(op)+" op is" + \
" mathematically undefined."+comment)() " mathematically undefined."+comment)()
...@@ -375,9 +375,9 @@ def grad(cost, wrt, g_cost = None, consider_constant = None, warn_type = False, ...@@ -375,9 +375,9 @@ def grad(cost, wrt, g_cost = None, consider_constant = None, warn_type = False,
raise TypeError("cost must be a scalar.") raise TypeError("cost must be a scalar.")
if isinstance(cost.type, NaNType): if isinstance(cost.type, NullType):
raise ValueError("Can't differentiate a NaN cost. cost is NaN because "+\ raise ValueError("Can't differentiate a NaN cost. cost is NaN because "+\
cost.type.why_nan) cost.type.why_null)
if consider_constant is None: if consider_constant is None:
consider_constant = [] consider_constant = []
...@@ -609,9 +609,9 @@ def _populate_grad_dict(var_to_node_to_idx,\ ...@@ -609,9 +609,9 @@ def _populate_grad_dict(var_to_node_to_idx,\
" Variable instance." % (str(node.op), " Variable instance." % (str(node.op),
type(term))) type(term)))
if isinstance(term.type,NaNType): if isinstance(term.type,NullType):
raise TypeError("tensor.grad encountered a NaN. "+\ raise TypeError("tensor.grad encountered a NaN. "+\
term.type.why_nan) term.type.why_null)
terms.append( term) terms.append( term)
grad_dict[var] = nonempty_sum(terms) grad_dict[var] = nonempty_sum(terms)
......
Markdown is supported
0%
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Register or sign in to comment