提交 a233573d authored 作者: Ian Goodfellow's avatar Ian Goodfellow

started work on fixing gradient issues

上级 2317ffbb
......@@ -36,7 +36,14 @@ def grad_sources_inputs(sources, graph_inputs, warn_type=True):
them)
:rtype: dictionary whose keys and values are of type `Variable`
:return: mapping from each Variable encountered in the backward traversal to its gradient.
:return: mapping from each Variable encountered in the backward traversal to the gradient with respect to that Variable.
It is assumed that there is some objective J shared between all members of
sources, so that for each v, gradient-on-v is the gradient of J with respect to v
"""
gmap = {}
for (r, g_r) in sources:
......@@ -125,3 +132,14 @@ def unimplemented_grad(op, x_pos, x):
"""
msg = '%s.grad not implemented for input %i'%(op, x_pos)
return Raise(msg=msg)(x)
class GradientUndefined(Exception): pass
def undefined_grad(op, x_pos, x):
msg = "Undefined gradient - do not use in computations"
exc = RuntimeError
return Raise(msg=msg, exc=exc)(x)
def grad(self, inputs, out_storage):
return [g_x0, undefined_grad(self, 1, inputs[1])]
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论