提交 34bec6f3 authored 作者: Frederic Bastien's avatar Frederic Bastien

some pep8

上级 08e5dab2
...@@ -503,7 +503,6 @@ def grad(cost, wrt, consider_constant=None, ...@@ -503,7 +503,6 @@ def grad(cost, wrt, consider_constant=None,
grad_dict[var] = g_var grad_dict[var] = g_var
def handle_disconnected(var): def handle_disconnected(var):
message = ("grad method was asked to compute the gradient " message = ("grad method was asked to compute the gradient "
"with respect to a variable that is not part of " "with respect to a variable that is not part of "
...@@ -520,7 +519,6 @@ def grad(cost, wrt, consider_constant=None, ...@@ -520,7 +519,6 @@ def grad(cost, wrt, consider_constant=None,
"'disconnected_inputs', valid values are " "'disconnected_inputs', valid values are "
"'ignore', 'warn' and 'raise'.") "'ignore', 'warn' and 'raise'.")
# variables that do not influence the cost have zero gradient. # variables that do not influence the cost have zero gradient.
# if wrt is such a variable, populate the grad_dict with this info # if wrt is such a variable, populate the grad_dict with this info
# so that wrt not being in var_to_app_to_idx won't cause an error below # so that wrt not being in var_to_app_to_idx won't cause an error below
...@@ -705,12 +703,12 @@ def subgraph_grad(wrt, end, start=None, cost=None, details=False): ...@@ -705,12 +703,12 @@ def subgraph_grad(wrt, end, start=None, cost=None, details=False):
wrt_grads = list(pgrads[k] for k in wrt) wrt_grads = list(pgrads[k] for k in wrt)
end_grads = list(pgrads[k] for k in end) end_grads = list(pgrads[k] for k in end)
if details: if details:
return wrt_grads, end_grads, start_grads, cost_grads return wrt_grads, end_grads, start_grads, cost_grads
return wrt_grads, end_grads return wrt_grads, end_grads
def _node_to_pattern(node): def _node_to_pattern(node):
""" given an apply node, obtain its connection pattern """ given an apply node, obtain its connection pattern
this is just a wrapper around Op.connection_pattern this is just a wrapper around Op.connection_pattern
...@@ -722,30 +720,31 @@ def _node_to_pattern(node): ...@@ -722,30 +720,31 @@ def _node_to_pattern(node):
connection_pattern = node.op.connection_pattern(node) connection_pattern = node.op.connection_pattern(node)
if not isinstance(connection_pattern, list): if not isinstance(connection_pattern, list):
raise TypeError("Op.connection_pattern should return " + \ raise TypeError(
("list of list of bool, but for Op=%s" % node.op) +\ "Op.connection_pattern should return " +
("list of list of bool, but for Op=%s" % node.op) +
"got %s with type %s." % (connection_pattern, "got %s with type %s." % (connection_pattern,
type(connection_pattern))) type(connection_pattern)))
if len(connection_pattern) != len(node.inputs): if len(connection_pattern) != len(node.inputs):
raise ValueError('%s.connection_pattern should have %d' % raise ValueError(
'%s.connection_pattern should have %d' %
(node.op, len(node.inputs)) + ' rows but has %d.' % (node.op, len(node.inputs)) + ' rows but has %d.' %
len(connection_pattern)) len(connection_pattern))
for ii, output_pattern in enumerate(connection_pattern): for ii, output_pattern in enumerate(connection_pattern):
if not isinstance(output_pattern, list): if not isinstance(output_pattern, list):
raise TypeError('%s.connection_pattern should return' % raise TypeError(
node.op + ' a list of lists, but element %d' % ii\ '%s.connection_pattern should return' %
node.op + ' a list of lists, but element %d' % ii
+ 'is %s of type %s.' % (output_pattern, + 'is %s of type %s.' % (output_pattern,
type(output_pattern))) type(output_pattern)))
else: else:
connection_pattern = \ connection_pattern = [[True for output in node.outputs]
[[True for output in node.outputs]
for ipt in node.inputs] for ipt in node.inputs]
assert isinstance(connection_pattern, list) assert isinstance(connection_pattern, list)
assert len(connection_pattern) == len(node.inputs) assert len(connection_pattern) == len(node.inputs)
for ii in xrange(len(node.inputs)): for ii in xrange(len(node.inputs)):
assert isinstance(connection_pattern[ii], list) assert isinstance(connection_pattern[ii], list)
assert len(connection_pattern[ii]) == \ assert len(connection_pattern[ii]) == len(node.outputs)
len(node.outputs)
return connection_pattern return connection_pattern
...@@ -975,7 +974,8 @@ def _populate_grad_dict(var_to_app_to_idx, ...@@ -975,7 +974,8 @@ def _populate_grad_dict(var_to_app_to_idx,
for output in output_grads] for output in output_grads]
# List of bools indicating if each input only has NullType outputs # List of bools indicating if each input only has NullType outputs
only_connected_to_nan = [(True not in only_connected_to_nan = [
(True not in
[in_to_out and out_to_cost and not out_nan [in_to_out and out_to_cost and not out_nan
for in_to_out, out_to_cost, out_nan in for in_to_out, out_to_cost, out_nan in
zip(in_to_outs, outputs_connected, ograd_is_nan)]) zip(in_to_outs, outputs_connected, ograd_is_nan)])
...@@ -1021,8 +1021,6 @@ def _populate_grad_dict(var_to_app_to_idx, ...@@ -1021,8 +1021,6 @@ def _populate_grad_dict(var_to_app_to_idx,
inputs = [try_to_copy_if_needed(ipt) for ipt in inputs] inputs = [try_to_copy_if_needed(ipt) for ipt in inputs]
# Build a list of output gradients with the same dtype as # Build a list of output gradients with the same dtype as
# the corresponding output variable. # the corresponding output variable.
# If an output is of a float dtype, we want to cast the # If an output is of a float dtype, we want to cast the
...@@ -1116,7 +1114,8 @@ def _populate_grad_dict(var_to_app_to_idx, ...@@ -1116,7 +1114,8 @@ def _populate_grad_dict(var_to_app_to_idx,
# Do type checking on the result # Do type checking on the result
# List of bools indicating if each input only has integer outputs # List of bools indicating if each input only has integer outputs
only_connected_to_int = [(True not in only_connected_to_int = [
(True not in
[in_to_out and out_to_cost and not out_int [in_to_out and out_to_cost and not out_int
for in_to_out, out_to_cost, out_int in for in_to_out, out_to_cost, out_int in
zip(in_to_outs, outputs_connected, output_is_int)]) zip(in_to_outs, outputs_connected, output_is_int)])
...@@ -1130,7 +1129,8 @@ def _populate_grad_dict(var_to_app_to_idx, ...@@ -1130,7 +1129,8 @@ def _populate_grad_dict(var_to_app_to_idx,
# used to mean undefined, zero, or disconnected. # used to mean undefined, zero, or disconnected.
# We therefore don't allow it because its usage has become # We therefore don't allow it because its usage has become
# so muddied. # so muddied.
raise TypeError(('%s.grad returned None for' + raise TypeError(
('%s.grad returned None for' +
' a gradient term, ' ' a gradient term, '
'this is prohibited. Instead of None,' 'this is prohibited. Instead of None,'
'return zeros_like(input), disconnected_type(),' 'return zeros_like(input), disconnected_type(),'
...@@ -1145,18 +1145,18 @@ def _populate_grad_dict(var_to_app_to_idx, ...@@ -1145,18 +1145,18 @@ def _populate_grad_dict(var_to_app_to_idx,
i_shape = orig_ipt_v.shape i_shape = orig_ipt_v.shape
t_shape = term_v.shape t_shape = term_v.shape
if i_shape != t_shape: if i_shape != t_shape:
raise ValueError("%s.grad returned object of " raise ValueError(
"%s.grad returned object of "
"shape %s as gradient term on input %d " "shape %s as gradient term on input %d "
"of shape %s" % (node.op, t_shape, i, "of shape %s" % (node.op, t_shape, i, i_shape))
i_shape))
if not isinstance(term.type, if not isinstance(term.type,
(NullType, DisconnectedType)): (NullType, DisconnectedType)):
if term.type.dtype not in theano.tensor.float_dtypes: if term.type.dtype not in theano.tensor.float_dtypes:
raise TypeError(str(node.op) + '.grad illegally ' raise TypeError(str(node.op) + '.grad illegally '
' returned an integer-valued variable.' ' returned an integer-valued variable.'
' (Input index %d, dtype %s)' % (i, ' (Input index %d, dtype %s)' % (
term.type.dtype)) i, term.type.dtype))
if only_connected_to_nan[i]: if only_connected_to_nan[i]:
assert isinstance(term.type, NullType) assert isinstance(term.type, NullType)
...@@ -1241,7 +1241,8 @@ def _populate_grad_dict(var_to_app_to_idx, ...@@ -1241,7 +1241,8 @@ def _populate_grad_dict(var_to_app_to_idx,
term = access_term_cache(node)[idx] term = access_term_cache(node)[idx]
if not isinstance(term, gof.Variable): if not isinstance(term, gof.Variable):
raise TypeError("%s.grad returned %s, expected" raise TypeError(
"%s.grad returned %s, expected"
" Variable instance." % (str(node.op), " Variable instance." % (str(node.op),
type(term))) type(term)))
...@@ -1255,7 +1256,8 @@ def _populate_grad_dict(var_to_app_to_idx, ...@@ -1255,7 +1256,8 @@ def _populate_grad_dict(var_to_app_to_idx,
continue continue
if hasattr(var, 'ndim') and term.ndim != var.ndim: if hasattr(var, 'ndim') and term.ndim != var.ndim:
raise ValueError(("%s.grad returned a term with" raise ValueError(
("%s.grad returned a term with"
" %d dimensions, but %d are required.") % ( " %d dimensions, but %d are required.") % (
str(node.op), term.ndim, var.ndim)) str(node.op), term.ndim, var.ndim))
...@@ -1569,7 +1571,8 @@ def verify_grad(fun, pt, n_tests=2, rng=None, eps=None, ...@@ -1569,7 +1571,8 @@ def verify_grad(fun, pt, n_tests=2, rng=None, eps=None,
for i, p in enumerate(pt): for i, p in enumerate(pt):
if p.dtype not in ('float32', 'float64'): if p.dtype not in ('float32', 'float64'):
raise TypeError(('verify_grad can work only with floating point ' raise TypeError(
('verify_grad can work only with floating point '
'inputs, but input %i has dtype "%s".') % (i, p.dtype)) 'inputs, but input %i has dtype "%s".') % (i, p.dtype))
_type_tol = dict( # relative error tolerances for different types _type_tol = dict( # relative error tolerances for different types
...@@ -1601,7 +1604,8 @@ def verify_grad(fun, pt, n_tests=2, rng=None, eps=None, ...@@ -1601,7 +1604,8 @@ def verify_grad(fun, pt, n_tests=2, rng=None, eps=None,
on_unused_input='ignore') on_unused_input='ignore')
return f return f
tensor_pt = [TensorType( tensor_pt = [
TensorType(
as_tensor_variable(p).dtype, as_tensor_variable(p).dtype,
as_tensor_variable(p).broadcastable)(name='input %i' % i) as_tensor_variable(p).broadcastable)(name='input %i' % i)
for i, p in enumerate(pt)] for i, p in enumerate(pt)]
...@@ -1620,7 +1624,8 @@ def verify_grad(fun, pt, n_tests=2, rng=None, eps=None, ...@@ -1620,7 +1624,8 @@ def verify_grad(fun, pt, n_tests=2, rng=None, eps=None,
o_fn_out = o_fn(*[p.copy() for p in pt]) o_fn_out = o_fn(*[p.copy() for p in pt])
if isinstance(o_fn_out, tuple) or isinstance(o_fn_out, list): if isinstance(o_fn_out, tuple) or isinstance(o_fn_out, list):
raise TypeError('It seems like you are trying to use verify_grad ' raise TypeError(
'It seems like you are trying to use verify_grad '
'on an op or a function which outputs a list: there should' 'on an op or a function which outputs a list: there should'
' be a single (array-like) output instead') ' be a single (array-like) output instead')
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论