提交 0bdb4850 authored 作者: Ian Goodfellow's avatar Ian Goodfellow

added some type checking

上级 be799058
...@@ -578,6 +578,11 @@ def grad(cost, wrt, g_cost = None, consider_constant = None, warn_type = 'ignore ...@@ -578,6 +578,11 @@ def grad(cost, wrt, g_cost = None, consider_constant = None, warn_type = 'ignore
idx = node_to_idx[node] idx = node_to_idx[node]
term = access_term_cache(node)[idx] term = access_term_cache(node)[idx]
if not isinstance(term, gof.Variable):
raise TypeError("%s.grad returned %s, expected"
" Variable instance." % (str(node.op),
type(term)))
if isinstance(term.type,NaNType): if isinstance(term.type,NaNType):
raise TypeError("tensor.grad encountered a NaN. "+\ raise TypeError("tensor.grad encountered a NaN. "+\
term.type.why_nan) term.type.why_nan)
...@@ -702,6 +707,11 @@ def grad_sources_inputs(sources, graph_inputs, warn_type = 'ignored'): ...@@ -702,6 +707,11 @@ def grad_sources_inputs(sources, graph_inputs, warn_type = 'ignored'):
idx = node_to_idx[node] idx = node_to_idx[node]
term = access_term_cache(node)[idx] term = access_term_cache(node)[idx]
if not isinstance(term, gof.Variable):
raise TypeError("%s.grad returned %s, expected"
" Variable instance." % (str(node.op),
type(term)))
if isinstance(term.type,NaNType): if isinstance(term.type,NaNType):
raise TypeError("tensor.grad encountered a NaN. "+\ raise TypeError("tensor.grad encountered a NaN. "+\
term.type.why_nan) term.type.why_nan)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论