提交 1cb68ffc authored 作者: James Bergstra's avatar James Bergstra 提交者: Ian Goodfellow

ENH: pep8 in gradient.py

上级 c53a8e84
...@@ -26,10 +26,11 @@ tensor = None ...@@ -26,10 +26,11 @@ tensor = None
_msg_retType = 'op.grad(...) returned a non-list' _msg_retType = 'op.grad(...) returned a non-list'
def format_as(use_list, use_tuple, outputs): def format_as(use_list, use_tuple, outputs):
""" """
Formats the outputs according to the flags `use_list` and `use_tuple`. Formats the outputs according to the flags `use_list` and `use_tuple`.
If `use_list` is True, `outputs` is returned as a list (if `outputs` " If `use_list` is True, `outputs` is returned as a list (if `outputs`
is not a list or a tuple then it is converted in a one element list). is not a list or a tuple then it is converted in a one element list).
If `use_tuple` is True, `outputs` is returned as a tuple (if `outputs` If `use_tuple` is True, `outputs` is returned as a tuple (if `outputs`
is not a list or a tuple then it is converted into a one element tuple). is not a list or a tuple then it is converted into a one element tuple).
...@@ -54,7 +55,8 @@ def format_as(use_list, use_tuple, outputs): ...@@ -54,7 +55,8 @@ def format_as(use_list, use_tuple, outputs):
else: else:
return outputs return outputs
def grad_not_implemented(op, x_pos, x, comment = ""):
def grad_not_implemented(op, x_pos, x, comment=""):
""" """
Return an un-computable symbolic variable of type `x.type`. Return an un-computable symbolic variable of type `x.type`.
...@@ -68,11 +70,14 @@ def grad_not_implemented(op, x_pos, x, comment = ""): ...@@ -68,11 +70,14 @@ def grad_not_implemented(op, x_pos, x, comment = ""):
gradient is not implemented. gradient is not implemented.
""" """
return NullType("This variable is NaN because the grad method for " + \ return NullType(
"input "+str(x_pos)+" ("+str(x)+") of the "+str(op)+" op is" + \ (
" not implemented."+comment)() "This variable is Null because the grad method for "
"input %s (%s) of the %s op is not implemented. %s"
) % (x_pos, x, op, comment))
def grad_undefined(op, x_pos, x, comment = ""): def grad_undefined(op, x_pos, x, comment=""):
""" """
Return an un-computable symbolic variable of type `x.type`. Return an un-computable symbolic variable of type `x.type`.
...@@ -86,9 +91,12 @@ def grad_undefined(op, x_pos, x, comment = ""): ...@@ -86,9 +91,12 @@ def grad_undefined(op, x_pos, x, comment = ""):
gradient is not defined. gradient is not defined.
""" """
return NullType("This variable is NaN because the gradient for " + \ return NullType(
"input "+str(x_pos)+" ("+str(x)+") of the "+str(op)+" op is" + \ (
" mathematically undefined."+comment)() "This variable is Null because the grad method for "
"input %s (%s) of the %s op is mathematically undefined. %s"
) % (x_pos, x, op, comment))
class DisconnectedType(theano.gof.type.Type): class DisconnectedType(theano.gof.type.Type):
...@@ -105,9 +113,12 @@ class DisconnectedType(theano.gof.type.Type): ...@@ -105,9 +113,12 @@ class DisconnectedType(theano.gof.type.Type):
whose type doesn't support zeros_like has 0 gradient. """ whose type doesn't support zeros_like has 0 gradient. """
def filter(self, data, strict=False, allow_downcast=None): def filter(self, data, strict=False, allow_downcast=None):
raise AssertionError("If you're assigning to a DisconnectedType you're" raise AssertionError(
" doing something wrong. It should only be used as " (
"symbolic placeholder.") "If you're assigning to a DisconnectedType you're"
" doing something wrong. It should only be used as"
" a symbolic placeholder."
))
def fiter_variable(self, other): def fiter_variable(self, other):
raise raise
...@@ -253,7 +264,7 @@ def Rop(f, wrt, eval_points): ...@@ -253,7 +264,7 @@ def Rop(f, wrt, eval_points):
def Lop(f, wrt, eval_points, consider_constant=None, warn_type=False, def Lop(f, wrt, eval_points, consider_constant=None, warn_type=False,
disconnected_inputs='raise'): disconnected_inputs='raise'):
""" """
Computes the L operation on `f` wrt to `wrt` evaluated at points given Computes the L operation on `f` wrt to `wrt` evaluated at points given
in `eval_points`. Mathematically this stands for the jacobian of `f` wrt in `eval_points`. Mathematically this stands for the jacobian of `f` wrt
...@@ -288,8 +299,8 @@ def Lop(f, wrt, eval_points, consider_constant=None, warn_type=False, ...@@ -288,8 +299,8 @@ def Lop(f, wrt, eval_points, consider_constant=None, warn_type=False,
if not isinstance(f, (list, tuple)): if not isinstance(f, (list, tuple)):
f = [f] f = [f]
f = [ elem for elem in f ] f = [elem for elem in f]
grads = [ elem for elem in eval_points ] grads = [elem for elem in eval_points]
for elem in consider_constant: for elem in consider_constant:
assert elem not in f assert elem not in f
...@@ -319,10 +330,11 @@ def Lop(f, wrt, eval_points, consider_constant=None, warn_type=False, ...@@ -319,10 +330,11 @@ def Lop(f, wrt, eval_points, consider_constant=None, warn_type=False,
if p in gmap: if p in gmap:
ret.append(gmap[p]) ret.append(gmap[p])
else: else:
message = ("Lop method was asked to compute the gradient " message = (
"with respect to a variable that is not part of " "Lop method was asked to compute the gradient "
"the computational graph of the cost, or is used " "with respect to a variable that is not part of "
"only by a non-differentiable operator: %s" % p) "the computational graph of the cost, or is used "
"only by a non-differentiable operator: %s" % p)
if disconnected_inputs == 'ignore': if disconnected_inputs == 'ignore':
pass pass
elif disconnected_inputs == 'warn': elif disconnected_inputs == 'warn':
...@@ -330,9 +342,10 @@ def Lop(f, wrt, eval_points, consider_constant=None, warn_type=False, ...@@ -330,9 +342,10 @@ def Lop(f, wrt, eval_points, consider_constant=None, warn_type=False,
elif disconnected_inputs == 'raise': elif disconnected_inputs == 'raise':
raise ValueError(message) raise ValueError(message)
else: else:
raise ValueError("Invalid value for keyword " raise ValueError(
"'disconnected_inputs', valid values are " "Invalid value for keyword "
"'ignore', 'warn' and 'raise'.") "'disconnected_inputs', valid values are "
"'ignore', 'warn' and 'raise'.")
ret.append(p.zeros_like()) ret.append(p.zeros_like())
return format_as(using_list, using_tuple, ret) return format_as(using_list, using_tuple, ret)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论