提交 3ea17d76 authored 作者: Ian Goodfellow's avatar Ian Goodfellow

added support for warn_type

上级 854e32c6
......@@ -444,7 +444,7 @@ def Lop(f, wrt, eval_points, consider_constant=None, warn_type=False,
# Gradient
#########################
def grad(cost, wrt, g_cost = None, consider_constant = None, warn_type = 'ignored',
def grad(cost, wrt, g_cost = None, consider_constant = None, warn_type = False,
disconnected_inputs = 'raise'):
global tensor
if tensor is None:
......@@ -558,8 +558,18 @@ def grad(cost, wrt, g_cost = None, consider_constant = None, warn_type = 'ignore
term_dict[node] = list(input_grads)
for i in xrange(len(term_dict[node])):
if term_dict[node][i] is None:
term_dict[node][i] = node.inputs[i].zeros_like()
if warn_type:
g_r_type = term_dict[node][i].type
r_type = inputs[i].type
if g_r_type != r_type:
_logger.warning('%s.grad returned a different type (%s) '
'for input %i of type (%s)',
node.op, g_r_type, i, r_type)
return term_dict[node]
......@@ -612,7 +622,7 @@ def grad(cost, wrt, g_cost = None, consider_constant = None, warn_type = 'ignore
def grad_sources_inputs(sources, graph_inputs, warn_type = 'ignored'):
def grad_sources_inputs(sources, graph_inputs, warn_type = True):
outputs, output_grads = zip(*sources)
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论