提交 3575d39a，作者：Ian Goodfellow

added ability to put a comment on a bad grad exception

added a new kind of bad grad, undefined grad
上级提交：7b39a307
...@@ -194,10 +194,10 @@ def grad_sources_inputs(sources, graph_inputs, warn_type=True): ...@@ -194,10 +194,10 @@ def grad_sources_inputs(sources, graph_inputs, warn_type=True):
class GradNotImplementedOp(gof.op.UncomputableOp):
    """An UncomputableOp representing a gradient that hasn't been
    implemented yet.
    """

    def __init__(self, op, x_pos, comment=""):
        """
        op: A theano op whose grad is not implemented for some input
        x_pos: An int, giving the index in the op's input list of
            a variable for which the gradient is not implemented
            (if op has unimplemented gradients for several inputs,
            it must still return a separate GradNotImplementedOp for
            each)
        comment: An optional comment explaining why the gradient isn't
            implemented.
        """
        # Validate eagerly so a bad call site fails here rather than
        # when the deferred exception is finally raised.
        assert isinstance(op, gof.Op)
        assert isinstance(x_pos, int)
        assert x_pos >= 0
        super(GradNotImplementedOp, self).__init__(
            NotImplementedError,
            "%s does not implement its gradient with respect to input %d. %s"
            % (str(type(op)), x_pos, comment))
def grad_not_implemented(op, x_pos, x, comment=""):
    """
    Return an un-computable symbolic variable of type `x.type`.

    If any call to tensor.grad results in an expression containing this
    un-computable variable, an exception (NotImplementedError) will be
    raised indicating that the gradient on the `x_pos`'th input of `op`
    has not been implemented. Likewise if any call to theano.function
    involves this variable.

    Optionally adds a comment to the exception explaining why this
    gradient is not implemented.
    """
    return GradNotImplementedOp(op, x_pos, comment)(x)
class GradUndefinedError(Exception):
    """Raised when a computation attempts to use a gradient that is
    mathematically undefined."""
class GradUndefinedOp(gof.op.UncomputableOp):
    """An UncomputableOp representing a gradient that is mathematically
    undefined.
    """

    def __init__(self, op, x_pos, comment=""):
        """
        op: A theano op whose grad is mathematically undefined for
            some input
        x_pos: An int, giving the index in the op's input list of
            a variable for which the gradient is undefined
            (if op has undefined gradients for several inputs,
            it must still return a separate GradUndefinedOp for
            each)
        comment: An optional comment explaining why the gradient isn't
            defined.
        """
        assert isinstance(op, gof.Op)
        assert isinstance(x_pos, int)
        assert x_pos >= 0
        # NOTE: the previous message said the gradient "does not
        # implement" (copy-pasted from GradNotImplementedOp); this op
        # represents a gradient that is mathematically undefined, so
        # report that instead.
        super(GradUndefinedOp, self).__init__(
            GradUndefinedError,
            "%s's gradient with respect to input %d is mathematically"
            " undefined. %s" % (str(type(op)), x_pos, comment))
def grad_undefined(op, x_pos, x, comment=""):
    """
    Return an un-computable symbolic variable of type `x.type`.

    If any call to tensor.grad results in an expression containing this
    un-computable variable, an exception (GradUndefinedError) will be
    raised indicating that the gradient on the `x_pos`'th input of `op`
    is mathematically undefined. Likewise if any call to theano.function
    involves this variable.

    Optionally adds a comment to the exception explaining why this
    gradient is not defined.
    """
    return GradUndefinedOp(op, x_pos, comment)(x)
######################## ########################
......
Markdown 格式
0%
您将添加 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论