Commit cf2d9d17 authored by James Bergstra

gradient - added some asserts and comments

Parent ad06351a
@@ -32,6 +32,8 @@ def grad_sources_inputs(sources, graph_inputs, warn_type=True):
     """
     gmap = {}
     for (r, g_r) in sources:
+        if not hasattr(r, 'type'):
+            raise TypeError('sources must be Variables', r)
         if g_r is not None:
             if r in gmap:
                 gmap[r] = gmap[r] + g_r
@@ -52,6 +54,10 @@ def grad_sources_inputs(sources, graph_inputs, warn_type=True):
         output_arg = g_outputs
         input_arg = node.inputs
+        # Each Op's grad function requires inputs and output_grads.
+        # If the Op destroys any input, but the grad expression uses it, then chances are the
+        # resulting graph will have a dependency cycle. We avoid this cycle by passing
+        # (symbolic) copies of each destroyed input.
         try:
             dinputs = [node.inputs[x[0]] for x in node.op.destroy_map.values()]
         except AttributeError:
@@ -93,6 +99,7 @@ def grad_sources_inputs(sources, graph_inputs, warn_type=True):
         if g_r and len(sources) == 1 and sources[0][0].name and r.name:
             g_r.name = "(d%s/d%s)" % (sources[0][0].name, r.name)
         if g_r is not None:
+            assert r is not None
             if r in gmap:
                 gmap[r] = gmap[r] + g_r
             else:
...
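
The comment added in the second hunk carries the substantive idea of this commit: when an Op destroys one of its inputs in place, the gradient expression must not reference the destroyed Variable directly, otherwise the resulting graph can contain a dependency cycle, so the grad function is handed symbolic copies of every destroyed input instead. A minimal sketch of that idea follows; it is not Theano's actual implementation, and the names inputs_for_grad and copy_fn are hypothetical, while destroy_map and node.inputs are the attributes used in the diff above.

# Minimal sketch (assumed names, not Theano's actual code) of the idea described in the
# added comment: before calling an Op's grad function, replace every input that the Op
# destroys in place with a symbolic copy, so the gradient graph does not reference a
# value the forward graph has already overwritten.
def inputs_for_grad(node, copy_fn):
    # node.op.destroy_map maps an output index to the list of input indices the Op
    # overwrites; Ops without in-place behaviour simply lack the attribute.
    destroy_map = getattr(node.op, 'destroy_map', {})
    destroyed = set()
    for input_indices in destroy_map.values():
        destroyed.update(input_indices)
    # copy_fn is a hypothetical helper that builds a symbolic copy of a Variable.
    return [copy_fn(inp) if i in destroyed else inp
            for i, inp in enumerate(node.inputs)]

The (partially copied) input list plays the role of input_arg in the hunk above, which is passed to the Op's grad function together with the output gradients.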