提交 fcc7442a authored 作者: Olivier Delalleau's avatar Olivier Delalleau

PEP8 (+ small code simplification)

上级 18f81012
...@@ -869,17 +869,19 @@ def _populate_grad_dict(var_to_app_to_idx, ...@@ -869,17 +869,19 @@ def _populate_grad_dict(var_to_app_to_idx,
for o, og in zip(node.outputs, output_grads): for o, og in zip(node.outputs, output_grads):
o_dt = getattr(o.type, 'dtype', None) o_dt = getattr(o.type, 'dtype', None)
og_dt = getattr(og.type, 'dtype', None) og_dt = getattr(og.type, 'dtype', None)
if o_dt not in theano.tensor.discrete_dtypes and og_dt and o_dt != og_dt: if (o_dt not in theano.tensor.discrete_dtypes and
og_dt and o_dt != og_dt):
new_output_grads.append(og.astype(o_dt)) new_output_grads.append(og.astype(o_dt))
else: else:
new_output_grads.append(og) new_output_grads.append(og)
# Make sure that, if new_output_grads[i] has a floating point dtype, # Make sure that, if new_output_grads[i] has a floating point
# it is the same dtype as outputs[i] # dtype, it is the same dtype as outputs[i]
for o, ng in zip(node.outputs, new_output_grads): for o, ng in zip(node.outputs, new_output_grads):
o_dt = getattr(o.type, 'dtype', None) o_dt = getattr(o.type, 'dtype', None)
ng_dt = getattr(ng.type, 'dtype', None) ng_dt = getattr(ng.type, 'dtype', None)
if ng_dt is not None and o_dt not in theano.tensor.discrete_dtypes: if (ng_dt is not None and
o_dt not in theano.tensor.discrete_dtypes):
assert ng_dt == o_dt assert ng_dt == o_dt
# Someone who had obviously not read the Op contract tried # Someone who had obviously not read the Op contract tried
...@@ -890,7 +892,8 @@ def _populate_grad_dict(var_to_app_to_idx, ...@@ -890,7 +892,8 @@ def _populate_grad_dict(var_to_app_to_idx,
# 2) Talk to Ian Goodfellow # 2) Talk to Ian Goodfellow
# (Both of these sources will tell you not to do it) # (Both of these sources will tell you not to do it)
for ng in new_output_grads: for ng in new_output_grads:
assert getattr(ng.type, 'dtype', None) not in theano.tensor.discrete_dtypes assert (getattr(ng.type, 'dtype', None)
not in theano.tensor.discrete_dtypes)
input_grads = node.op.grad(inputs, new_output_grads) input_grads = node.op.grad(inputs, new_output_grads)
...@@ -908,7 +911,6 @@ def _populate_grad_dict(var_to_app_to_idx, ...@@ -908,7 +911,6 @@ def _populate_grad_dict(var_to_app_to_idx,
# Do type checking on the result # Do type checking on the result
# List of bools indicating if each input only has integer outputs # List of bools indicating if each input only has integer outputs
only_connected_to_int = [(True not in only_connected_to_int = [(True not in
[in_to_out and out_to_cost and not out_int [in_to_out and out_to_cost and not out_int
...@@ -916,7 +918,6 @@ def _populate_grad_dict(var_to_app_to_idx, ...@@ -916,7 +918,6 @@ def _populate_grad_dict(var_to_app_to_idx,
zip(in_to_outs, outputs_connected, output_is_int)]) zip(in_to_outs, outputs_connected, output_is_int)])
for in_to_outs in connection_pattern] for in_to_outs in connection_pattern]
for i, term in enumerate(input_grads): for i, term in enumerate(input_grads):
# Disallow Nones # Disallow Nones
...@@ -933,7 +934,6 @@ def _populate_grad_dict(var_to_app_to_idx, ...@@ -933,7 +934,6 @@ def _populate_grad_dict(var_to_app_to_idx,
'the grad_undefined or grad_unimplemented helper ' 'the grad_undefined or grad_unimplemented helper '
'functions.') % node.op) 'functions.') % node.op)
if not isinstance(term.type, if not isinstance(term.type,
(NullType, DisconnectedType)): (NullType, DisconnectedType)):
if term.type.dtype not in theano.tensor.float_dtypes: if term.type.dtype not in theano.tensor.float_dtypes:
...@@ -973,8 +973,8 @@ def _populate_grad_dict(var_to_app_to_idx, ...@@ -973,8 +973,8 @@ def _populate_grad_dict(var_to_app_to_idx,
msg += "evaluate to zeros, but it evaluates to" msg += "evaluate to zeros, but it evaluates to"
msg += "%s." msg += "%s."
msg % (str(node.op), str(term), str(type(term)), msg % (node.op, term, type(term), i,
i, str(theano.get_scalar_constant_value(term))) theano.get_scalar_constant_value(term))
raise ValueError(msg) raise ValueError(msg)
...@@ -1010,8 +1010,6 @@ def _populate_grad_dict(var_to_app_to_idx, ...@@ -1010,8 +1010,6 @@ def _populate_grad_dict(var_to_app_to_idx,
#cache the result #cache the result
term_dict[node] = input_grads term_dict[node] = input_grads
return term_dict[node] return term_dict[node]
# populate grad_dict[var] and return it # populate grad_dict[var] and return it
...@@ -1040,7 +1038,7 @@ def _populate_grad_dict(var_to_app_to_idx, ...@@ -1040,7 +1038,7 @@ def _populate_grad_dict(var_to_app_to_idx,
if isinstance(term.type, DisconnectedType): if isinstance(term.type, DisconnectedType):
continue continue
if hasattr(var,'ndim') and term.ndim != var.ndim: if hasattr(var, 'ndim') and term.ndim != var.ndim:
raise ValueError(("%s.grad returned a term with" raise ValueError(("%s.grad returned a term with"
" %d dimensions, but %d are required.") % ( " %d dimensions, but %d are required.") % (
str(node.op), term.ndim, var.ndim)) str(node.op), term.ndim, var.ndim))
...@@ -1058,8 +1056,8 @@ def _populate_grad_dict(var_to_app_to_idx, ...@@ -1058,8 +1056,8 @@ def _populate_grad_dict(var_to_app_to_idx,
if cost_name is not None and var.name is not None: if cost_name is not None and var.name is not None:
grad_dict[var].name = '(d%s/d%s)' % (cost_name, var.name) grad_dict[var].name = '(d%s/d%s)' % (cost_name, var.name)
else: else:
# this variable isn't connected to the cost in the computational # this variable isn't connected to the cost in the
# graph # computational graph
grad_dict[var] = DisconnectedType()() grad_dict[var] = DisconnectedType()()
# end if cache miss # end if cache miss
return grad_dict[var] return grad_dict[var]
...@@ -1068,6 +1066,7 @@ def _populate_grad_dict(var_to_app_to_idx, ...@@ -1068,6 +1066,7 @@ def _populate_grad_dict(var_to_app_to_idx,
return rval return rval
def _float_zeros_like(x): def _float_zeros_like(x):
""" Like zeros_like, but forces the object to have a """ Like zeros_like, but forces the object to have a
a floating point dtype """ a floating point dtype """
...@@ -1599,6 +1598,7 @@ def hessian(cost, wrt, consider_constant=None, ...@@ -1599,6 +1598,7 @@ def hessian(cost, wrt, consider_constant=None,
hessians.append(hess) hessians.append(hess)
return format_as(using_list, using_tuple, hessians) return format_as(using_list, using_tuple, hessians)
def _is_zero(x): def _is_zero(x):
""" """
Returns 'yes', 'no', or 'maybe' indicating whether x Returns 'yes', 'no', or 'maybe' indicating whether x
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论