提交 34bec6f3 authored 作者: Frederic Bastien's avatar Frederic Bastien

some pep8

上级 08e5dab2
......@@ -503,7 +503,6 @@ def grad(cost, wrt, consider_constant=None,
grad_dict[var] = g_var
def handle_disconnected(var):
message = ("grad method was asked to compute the gradient "
"with respect to a variable that is not part of "
......@@ -520,7 +519,6 @@ def grad(cost, wrt, consider_constant=None,
"'disconnected_inputs', valid values are "
"'ignore', 'warn' and 'raise'.")
# variables that do not influence the cost have zero gradient.
# if wrt is such a variable, populate the grad_dict with this info
# so that wrt not being in var_to_app_to_idx won't cause an error below
......@@ -705,12 +703,12 @@ def subgraph_grad(wrt, end, start=None, cost=None, details=False):
wrt_grads = list(pgrads[k] for k in wrt)
end_grads = list(pgrads[k] for k in end)
if details:
return wrt_grads, end_grads, start_grads, cost_grads
return wrt_grads, end_grads
def _node_to_pattern(node):
""" given an apply node, obtain its connection pattern
this is just a wrapper around Op.connection_pattern
......@@ -722,30 +720,31 @@ def _node_to_pattern(node):
connection_pattern = node.op.connection_pattern(node)
if not isinstance(connection_pattern, list):
raise TypeError("Op.connection_pattern should return " + \
("list of list of bool, but for Op=%s" % node.op) +\
raise TypeError(
"Op.connection_pattern should return " +
("list of list of bool, but for Op=%s" % node.op) +
"got %s with type %s." % (connection_pattern,
type(connection_pattern)))
if len(connection_pattern) != len(node.inputs):
raise ValueError('%s.connection_pattern should have %d' %
raise ValueError(
'%s.connection_pattern should have %d' %
(node.op, len(node.inputs)) + ' rows but has %d.' %
len(connection_pattern))
for ii, output_pattern in enumerate(connection_pattern):
if not isinstance(output_pattern, list):
raise TypeError('%s.connection_pattern should return' %
node.op + ' a list of lists, but element %d' % ii\
raise TypeError(
'%s.connection_pattern should return' %
node.op + ' a list of lists, but element %d' % ii
+ 'is %s of type %s.' % (output_pattern,
type(output_pattern)))
else:
connection_pattern = \
[[True for output in node.outputs]
connection_pattern = [[True for output in node.outputs]
for ipt in node.inputs]
assert isinstance(connection_pattern, list)
assert len(connection_pattern) == len(node.inputs)
for ii in xrange(len(node.inputs)):
assert isinstance(connection_pattern[ii], list)
assert len(connection_pattern[ii]) == \
len(node.outputs)
assert len(connection_pattern[ii]) == len(node.outputs)
return connection_pattern
......@@ -975,7 +974,8 @@ def _populate_grad_dict(var_to_app_to_idx,
for output in output_grads]
# List of bools indicating if each input only has NullType outputs
only_connected_to_nan = [(True not in
only_connected_to_nan = [
(True not in
[in_to_out and out_to_cost and not out_nan
for in_to_out, out_to_cost, out_nan in
zip(in_to_outs, outputs_connected, ograd_is_nan)])
......@@ -1021,8 +1021,6 @@ def _populate_grad_dict(var_to_app_to_idx,
inputs = [try_to_copy_if_needed(ipt) for ipt in inputs]
# Build a list of output gradients with the same dtype as
# the corresponding output variable.
# If an output is of a float dtype, we want to cast the
......@@ -1116,7 +1114,8 @@ def _populate_grad_dict(var_to_app_to_idx,
# Do type checking on the result
# List of bools indicating if each input only has integer outputs
only_connected_to_int = [(True not in
only_connected_to_int = [
(True not in
[in_to_out and out_to_cost and not out_int
for in_to_out, out_to_cost, out_int in
zip(in_to_outs, outputs_connected, output_is_int)])
......@@ -1130,7 +1129,8 @@ def _populate_grad_dict(var_to_app_to_idx,
# used to mean undefined, zero, or disconnected.
# We therefore don't allow it because its usage has become
# so muddied.
raise TypeError(('%s.grad returned None for' +
raise TypeError(
('%s.grad returned None for' +
' a gradient term, '
'this is prohibited. Instead of None,'
'return zeros_like(input), disconnected_type(),'
......@@ -1145,18 +1145,18 @@ def _populate_grad_dict(var_to_app_to_idx,
i_shape = orig_ipt_v.shape
t_shape = term_v.shape
if i_shape != t_shape:
raise ValueError("%s.grad returned object of "
raise ValueError(
"%s.grad returned object of "
"shape %s as gradient term on input %d "
"of shape %s" % (node.op, t_shape, i,
i_shape))
"of shape %s" % (node.op, t_shape, i, i_shape))
if not isinstance(term.type,
(NullType, DisconnectedType)):
if term.type.dtype not in theano.tensor.float_dtypes:
raise TypeError(str(node.op) + '.grad illegally '
' returned an integer-valued variable.'
' (Input index %d, dtype %s)' % (i,
term.type.dtype))
' (Input index %d, dtype %s)' % (
i, term.type.dtype))
if only_connected_to_nan[i]:
assert isinstance(term.type, NullType)
......@@ -1241,7 +1241,8 @@ def _populate_grad_dict(var_to_app_to_idx,
term = access_term_cache(node)[idx]
if not isinstance(term, gof.Variable):
raise TypeError("%s.grad returned %s, expected"
raise TypeError(
"%s.grad returned %s, expected"
" Variable instance." % (str(node.op),
type(term)))
......@@ -1255,7 +1256,8 @@ def _populate_grad_dict(var_to_app_to_idx,
continue
if hasattr(var, 'ndim') and term.ndim != var.ndim:
raise ValueError(("%s.grad returned a term with"
raise ValueError(
("%s.grad returned a term with"
" %d dimensions, but %d are required.") % (
str(node.op), term.ndim, var.ndim))
......@@ -1569,7 +1571,8 @@ def verify_grad(fun, pt, n_tests=2, rng=None, eps=None,
for i, p in enumerate(pt):
if p.dtype not in ('float32', 'float64'):
raise TypeError(('verify_grad can work only with floating point '
raise TypeError(
('verify_grad can work only with floating point '
'inputs, but input %i has dtype "%s".') % (i, p.dtype))
_type_tol = dict( # relative error tolerances for different types
......@@ -1601,7 +1604,8 @@ def verify_grad(fun, pt, n_tests=2, rng=None, eps=None,
on_unused_input='ignore')
return f
tensor_pt = [TensorType(
tensor_pt = [
TensorType(
as_tensor_variable(p).dtype,
as_tensor_variable(p).broadcastable)(name='input %i' % i)
for i, p in enumerate(pt)]
......@@ -1620,7 +1624,8 @@ def verify_grad(fun, pt, n_tests=2, rng=None, eps=None,
o_fn_out = o_fn(*[p.copy() for p in pt])
if isinstance(o_fn_out, tuple) or isinstance(o_fn_out, list):
raise TypeError('It seems like you are trying to use verify_grad '
raise TypeError(
'It seems like you are trying to use verify_grad '
'on an op or a function which outputs a list: there should'
' be a single (array-like) output instead')
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论