提交 eb484245 authored 作者: Ian Goodfellow's avatar Ian Goodfellow

made FunctionGraph check for NaNs

上级 270d35fa
...@@ -11,6 +11,7 @@ import toolbox ...@@ -11,6 +11,7 @@ import toolbox
from python25 import all from python25 import all
from theano import config from theano import config
import warnings import warnings
NaNType = None
class InconsistencyError(Exception): class InconsistencyError(Exception):
...@@ -211,6 +212,9 @@ class FunctionGraph(utils.object2): ...@@ -211,6 +212,9 @@ class FunctionGraph(utils.object2):
### import ### ### import ###
def __import_r__(self, variables): def __import_r__(self, variables):
global NaNType
if NaNType is None:
from nan_type import NaNType
# Imports the owners of the variables # Imports the owners of the variables
r_owner_done = set(self.nodes) r_owner_done = set(self.nodes)
for node in [r.owner for r in variables if r.owner is not None]: for node in [r.owner for r in variables if r.owner is not None]:
...@@ -219,6 +223,8 @@ class FunctionGraph(utils.object2): ...@@ -219,6 +223,8 @@ class FunctionGraph(utils.object2):
self.__import__(node) self.__import__(node)
for r in variables: for r in variables:
if r.owner is None and not isinstance(r, graph.Constant) and r not in self.inputs: if r.owner is None and not isinstance(r, graph.Constant) and r not in self.inputs:
if isinstance(r.type,NaNType):
raise TypeError("Computation graph contains a NaN. "+r.type.why_nan)
raise MissingInputError("Undeclared input", r) raise MissingInputError("Undeclared input", r)
if not getattr(r, 'fgraph', None) is self: if not getattr(r, 'fgraph', None) is self:
self.__setup_r__(r) self.__setup_r__(r)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论