提交 a233573d，作者：Ian Goodfellow

started work on fixing gradient issues

上级 2317ffbb
...@@ -36,7 +36,14 @@ def grad_sources_inputs(sources, graph_inputs, warn_type=True): ...@@ -36,7 +36,14 @@ def grad_sources_inputs(sources, graph_inputs, warn_type=True):
them) them)
:rtype: dictionary whose keys and values are of type `Variable` :rtype: dictionary whose keys and values are of type `Variable`
:return: mapping from each Variable encountered in the backward traversal to its gradient. :return: mapping from each Variable encountered in the backward traversal to the gradient with respect to that Variable.
It is assumed that all members of sources share a single objective J,
so that for each Variable v, the gradient on v is the gradient of J with respect to v.
""" """
gmap = {} gmap = {}
for (r, g_r) in sources: for (r, g_r) in sources:
...@@ -78,7 +85,7 @@ def grad_sources_inputs(sources, graph_inputs, warn_type=True): ...@@ -78,7 +85,7 @@ def grad_sources_inputs(sources, graph_inputs, warn_type=True):
else: else:
new_input_arg.append(input) new_input_arg.append(input)
input_arg = new_input_arg input_arg = new_input_arg
#note that this function is not in a try-except block #note that this function is not in a try-except block
# the rationale: # the rationale:
# If the op implements grad, then any exception should be passed to the # If the op implements grad, then any exception should be passed to the
...@@ -93,8 +100,8 @@ def grad_sources_inputs(sources, graph_inputs, warn_type=True): ...@@ -93,8 +100,8 @@ def grad_sources_inputs(sources, graph_inputs, warn_type=True):
g_inputs = op_grad g_inputs = op_grad
assert isinstance(g_inputs, (list, tuple)) assert isinstance(g_inputs, (list, tuple))
if len(g_inputs) != len(node.inputs): if len(g_inputs) != len(node.inputs):
raise ValueError(_msg_badlen, raise ValueError(_msg_badlen,
node.op, node.op,
len(g_inputs), len(g_inputs),
len(node.inputs)) len(node.inputs))
for ii, (r, g_r) in enumerate(zip(node.inputs, g_inputs)): for ii, (r, g_r) in enumerate(zip(node.inputs, g_inputs)):
...@@ -106,7 +113,7 @@ def grad_sources_inputs(sources, graph_inputs, warn_type=True): ...@@ -106,7 +113,7 @@ def grad_sources_inputs(sources, graph_inputs, warn_type=True):
node.op, g_r_type, ii, r_type)) node.op, g_r_type, ii, r_type))
if g_r and len(sources) == 1 and sources[0][0].name and r.name: if g_r and len(sources) == 1 and sources[0][0].name and r.name:
g_r.name = "(d%s/d%s)" % (sources[0][0].name, r.name) g_r.name = "(d%s/d%s)" % (sources[0][0].name, r.name)
if g_r is not None: if g_r is not None:
assert r is not None assert r is not None
if r in gmap: if r in gmap:
gmap[r] = gmap[r] + g_r gmap[r] = gmap[r] + g_r
...@@ -125,3 +132,14 @@ def unimplemented_grad(op, x_pos, x): ...@@ -125,3 +132,14 @@ def unimplemented_grad(op, x_pos, x):
""" """
msg = '%s.grad not implemented for input %i'%(op, x_pos) msg = '%s.grad not implemented for input %i'%(op, x_pos)
return Raise(msg=msg)(x) return Raise(msg=msg)(x)
class GradientUndefined(Exception):
    """Exception subclass named for undefined-gradient errors.

    It adds no attributes or behavior beyond ``Exception``.

    NOTE(review): nothing visible in this file raises this class;
    ``undefined_grad`` hands ``RuntimeError`` to ``Raise`` instead --
    confirm the intended use.
    """
def undefined_grad(op, x_pos, x):
    """Apply a ``Raise`` op carrying an "undefined gradient" message to `x`.

    :param op: accepted but not used by this function
    :param x_pos: accepted but not used by this function
    :param x: the object the constructed ``Raise`` instance is called on

    :return: whatever ``Raise(msg=..., exc=RuntimeError)(x)`` returns

    NOTE(review): presumably `op` and `x_pos` are the op and the index of
    the input whose gradient is undefined (as in ``unimplemented_grad``),
    and the result is presumably a variable that raises when evaluated --
    confirm against the definition of ``Raise``.
    """
    return Raise(
        msg="Undefined gradient - do not use in computations",
        exc=RuntimeError,
    )(x)
def grad(self, inputs, out_storage):
return [g_x0, undefined_grad(self, 1, inputs[1])]
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论