提交 1cb68ffc authored 作者: James Bergstra's avatar James Bergstra 提交者: Ian Goodfellow

ENH: pep8 in gradient.py

上级 c53a8e84
......@@ -26,10 +26,11 @@ tensor = None
_msg_retType = 'op.grad(...) returned a non-list'
def format_as(use_list, use_tuple, outputs):
"""
Formats the outputs according to the flags `use_list` and `use_tuple`.
If `use_list` is True, `outputs` is returned as a list (if `outputs`
" If `use_list` is True, `outputs` is returned as a list (if `outputs`
is not a list or a tuple then it is converted in a one element list).
If `use_tuple` is True, `outputs` is returned as a tuple (if `outputs`
is not a list or a tuple then it is converted into a one element tuple).
......@@ -54,7 +55,8 @@ def format_as(use_list, use_tuple, outputs):
else:
return outputs
def grad_not_implemented(op, x_pos, x, comment = ""):
def grad_not_implemented(op, x_pos, x, comment=""):
"""
Return an un-computable symbolic variable of type `x.type`.
......@@ -68,11 +70,14 @@ def grad_not_implemented(op, x_pos, x, comment = ""):
gradient is not implemented.
"""
return NullType("This variable is NaN because the grad method for " + \
"input "+str(x_pos)+" ("+str(x)+") of the "+str(op)+" op is" + \
" not implemented."+comment)()
return NullType(
(
"This variable is Null because the grad method for "
"input %s (%s) of the %s op is not implemented. %s"
) % (x_pos, x, op, comment))
def grad_undefined(op, x_pos, x, comment = ""):
def grad_undefined(op, x_pos, x, comment=""):
"""
Return an un-computable symbolic variable of type `x.type`.
......@@ -86,9 +91,12 @@ def grad_undefined(op, x_pos, x, comment = ""):
gradient is not defined.
"""
return NullType("This variable is NaN because the gradient for " + \
"input "+str(x_pos)+" ("+str(x)+") of the "+str(op)+" op is" + \
" mathematically undefined."+comment)()
return NullType(
(
"This variable is Null because the grad method for "
"input %s (%s) of the %s op is mathematically undefined. %s"
) % (x_pos, x, op, comment))
class DisconnectedType(theano.gof.type.Type):
......@@ -105,9 +113,12 @@ class DisconnectedType(theano.gof.type.Type):
whose type doesn't support zeros_like has 0 gradient. """
def filter(self, data, strict=False, allow_downcast=None):
raise AssertionError("If you're assigning to a DisconnectedType you're"
" doing something wrong. It should only be used as "
"symbolic placeholder.")
raise AssertionError(
(
"If you're assigning to a DisconnectedType you're"
" doing something wrong. It should only be used as"
" a symbolic placeholder."
))
def fiter_variable(self, other):
raise
......@@ -253,7 +264,7 @@ def Rop(f, wrt, eval_points):
def Lop(f, wrt, eval_points, consider_constant=None, warn_type=False,
disconnected_inputs='raise'):
disconnected_inputs='raise'):
"""
Computes the L operation on `f` wrt to `wrt` evaluated at points given
in `eval_points`. Mathematically this stands for the jacobian of `f` wrt
......@@ -288,8 +299,8 @@ def Lop(f, wrt, eval_points, consider_constant=None, warn_type=False,
if not isinstance(f, (list, tuple)):
f = [f]
f = [ elem for elem in f ]
grads = [ elem for elem in eval_points ]
f = [elem for elem in f]
grads = [elem for elem in eval_points]
for elem in consider_constant:
assert elem not in f
......@@ -319,10 +330,11 @@ def Lop(f, wrt, eval_points, consider_constant=None, warn_type=False,
if p in gmap:
ret.append(gmap[p])
else:
message = ("Lop method was asked to compute the gradient "
"with respect to a variable that is not part of "
"the computational graph of the cost, or is used "
"only by a non-differentiable operator: %s" % p)
message = (
"Lop method was asked to compute the gradient "
"with respect to a variable that is not part of "
"the computational graph of the cost, or is used "
"only by a non-differentiable operator: %s" % p)
if disconnected_inputs == 'ignore':
pass
elif disconnected_inputs == 'warn':
......@@ -330,9 +342,10 @@ def Lop(f, wrt, eval_points, consider_constant=None, warn_type=False,
elif disconnected_inputs == 'raise':
raise ValueError(message)
else:
raise ValueError("Invalid value for keyword "
"'disconnected_inputs', valid values are "
"'ignore', 'warn' and 'raise'.")
raise ValueError(
"Invalid value for keyword "
"'disconnected_inputs', valid values are "
"'ignore', 'warn' and 'raise'.")
ret.append(p.zeros_like())
return format_as(using_list, using_tuple, ret)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论