提交 98603703 authored 作者: James Bergstra's avatar James Bergstra

gradient function checks that returned types match

上级 b2f48f9c
...@@ -89,7 +89,9 @@ def grad_sources_inputs(sources, graph_inputs): ...@@ -89,7 +89,9 @@ def grad_sources_inputs(sources, graph_inputs):
node.op, node.op,
len(g_inputs), len(g_inputs),
len(node.inputs)) len(node.inputs))
for r, g_r in zip(node.inputs, g_inputs): for ii, (r, g_r) in enumerate(zip(node.inputs, g_inputs)):
if g_r and (r.type != g_r.type):
print 'WARNING: %s.grad returned a different type for input %i: %s vs. %s'%(node.op, ii, r.type, g_r.type)
if g_r and len(sources) == 1 and sources[0][0].name and r.name: if g_r and len(sources) == 1 and sources[0][0].name and r.name:
g_r.name = "(d%s/d%s)" % (sources[0][0].name, r.name) g_r.name = "(d%s/d%s)" % (sources[0][0].name, r.name)
if g_r is not None: if g_r is not None:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论