提交 3575d39a authored 作者: Ian Goodfellow's avatar Ian Goodfellow

added ability to put a comment on a bad grad exception

added a new kind of bad grad, undefined grad
上级 7b39a307
......@@ -194,10 +194,10 @@ def grad_sources_inputs(sources, graph_inputs, warn_type=True):
class GradNotImplementedOp(gof.op.UncomputableOp):
""" A BadGradOp representing a gradient that hasn't been implemented yet.
""" An UncomputableOp representing a gradient that hasn't been implemented yet.
"""
def __init__(self, op, x_pos):
def __init__(self, op, x_pos, comment = ""):
"""
op: A theano op whose grad is not implemented for some input
x_pos: An int, giving the index in the op's input list of
......@@ -205,6 +205,8 @@ class GradNotImplementedOp(gof.op.UncomputableOp):
(if op has unimplemented gradients for several inputs,
it must still return a separate UnimplementedGradOp for
each)
comment: An optional comment explaining why the gradient isn't
implemented.
"""
assert isinstance(op, gof.Op)
......@@ -212,11 +214,11 @@ class GradNotImplementedOp(gof.op.UncomputableOp):
assert x_pos >= 0
super(GradNotImplementedOp,self).__init__(NotImplementedError,
"%s does not implement its gradient with respect to input %d" \
% (str(type(op)), x_pos))
"%s does not implement its gradient with respect to input %d. %s" \
% (str(type(op)), x_pos, comment))
def grad_not_implemented(op, x_pos, x):
def grad_not_implemented(op, x_pos, x, comment = ""):
"""
Return an un-computable symbolic variable of type `x.type`.
......@@ -225,9 +227,61 @@ def grad_not_implemented(op, x_pos, x):
raised indicating that the gradient on the
`x_pos`'th input of `op` has not been implemented. Likewise if
any call to theano.function involves this variable.
Optionally adds a comment to the exception explaining why this
gradient is not implemented.
"""
return GradNotImplementedOp(op, x_pos, comment)(x)
class GradUndefinedError(Exception):
""" An exception raised upon attempts to use an undefined gradient.
"""
class GradUndefinedOp(gof.op.UncomputableOp):
""" An UncomputableOp representing a gradient that is mathematically
undefined.
"""
def __init__(self, op, x_pos, comment = ""):
"""
op: A theano op whose grad is mathematically undefined for
some input
x_pos: An int, giving the index in the op's input list of
a variable for which the gradient is undefined
(if op has undefined gradients for several inputs,
it must still return a separate GradUndefinedOp for
each)
comment: An optional comment explaining why the gradient isn't
defined.
"""
assert isinstance(op, gof.Op)
assert isinstance(x_pos, int)
assert x_pos >= 0
super(GradUndefinedOp,self).__init__(GradUndefinedError,
"%s does not implement its gradient with respect to input %d. %s" \
% (str(type(op)), x_pos, comment))
def grad_undefined(op, x_pos, x, comment = ""):
"""
Return an un-computable symbolic variable of type `x.type`.
If any call to tensor.grad results in an expression containing this
un-computable variable, an exception (GradUndefinedError) will be
raised indicating that the gradient on the
`x_pos`'th input of `op` is mathematically undefined. Likewise if
any call to theano.function involves this variable.
Optionally adds a comment to the exception explaining why this
gradient is not defined.
"""
return GradNotImplementedOp(op, x_pos)(x)
return GradUndefinedOp(op, x_pos, comment)(x)
########################
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论