提交 fcc7442a 作者: Olivier Delalleau

PEP8 (+ small code simplification)

上级 18f81012
......@@ -869,17 +869,19 @@ def _populate_grad_dict(var_to_app_to_idx,
for o, og in zip(node.outputs, output_grads):
o_dt = getattr(o.type, 'dtype', None)
og_dt = getattr(og.type, 'dtype', None)
if o_dt not in theano.tensor.discrete_dtypes and og_dt and o_dt != og_dt:
if (o_dt not in theano.tensor.discrete_dtypes and
og_dt and o_dt != og_dt):
new_output_grads.append(og.astype(o_dt))
else:
new_output_grads.append(og)
# Make sure that, if new_output_grads[i] has a floating point dtype,
# it is the same dtype as outputs[i]
# Make sure that, if new_output_grads[i] has a floating point
# dtype, it is the same dtype as outputs[i]
for o, ng in zip(node.outputs, new_output_grads):
o_dt = getattr(o.type, 'dtype', None)
ng_dt = getattr(ng.type, 'dtype', None)
if ng_dt is not None and o_dt not in theano.tensor.discrete_dtypes:
if (ng_dt is not None and
o_dt not in theano.tensor.discrete_dtypes):
assert ng_dt == o_dt
# Someone who had obviously not read the Op contract tried
......@@ -890,7 +892,8 @@ def _populate_grad_dict(var_to_app_to_idx,
# 2) Talk to Ian Goodfellow
# (Both of these sources will tell you not to do it)
for ng in new_output_grads:
assert getattr(ng.type, 'dtype', None) not in theano.tensor.discrete_dtypes
assert (getattr(ng.type, 'dtype', None)
not in theano.tensor.discrete_dtypes)
input_grads = node.op.grad(inputs, new_output_grads)
......@@ -908,7 +911,6 @@ def _populate_grad_dict(var_to_app_to_idx,
# Do type checking on the result
# List of bools indicating if each input only has integer outputs
only_connected_to_int = [(True not in
[in_to_out and out_to_cost and not out_int
......@@ -916,7 +918,6 @@ def _populate_grad_dict(var_to_app_to_idx,
zip(in_to_outs, outputs_connected, output_is_int)])
for in_to_outs in connection_pattern]
for i, term in enumerate(input_grads):
# Disallow Nones
......@@ -933,7 +934,6 @@ def _populate_grad_dict(var_to_app_to_idx,
'the grad_undefined or grad_unimplemented helper '
'functions.') % node.op)
if not isinstance(term.type,
(NullType, DisconnectedType)):
if term.type.dtype not in theano.tensor.float_dtypes:
......@@ -973,8 +973,8 @@ def _populate_grad_dict(var_to_app_to_idx,
msg += "evaluate to zeros, but it evaluates to"
msg += "%s."
msg % (str(node.op), str(term), str(type(term)),
i, str(theano.get_scalar_constant_value(term)))
msg % (node.op, term, type(term), i,
theano.get_scalar_constant_value(term))
raise ValueError(msg)
......@@ -1010,8 +1010,6 @@ def _populate_grad_dict(var_to_app_to_idx,
#cache the result
term_dict[node] = input_grads
return term_dict[node]
# populate grad_dict[var] and return it
......@@ -1040,7 +1038,7 @@ def _populate_grad_dict(var_to_app_to_idx,
if isinstance(term.type, DisconnectedType):
continue
if hasattr(var,'ndim') and term.ndim != var.ndim:
if hasattr(var, 'ndim') and term.ndim != var.ndim:
raise ValueError(("%s.grad returned a term with"
" %d dimensions, but %d are required.") % (
str(node.op), term.ndim, var.ndim))
......@@ -1058,8 +1056,8 @@ def _populate_grad_dict(var_to_app_to_idx,
if cost_name is not None and var.name is not None:
grad_dict[var].name = '(d%s/d%s)' % (cost_name, var.name)
else:
# this variable isn't connected to the cost in the computational
# graph
# this variable isn't connected to the cost in the
# computational graph
grad_dict[var] = DisconnectedType()()
# end if cache miss
return grad_dict[var]
......@@ -1068,6 +1066,7 @@ def _populate_grad_dict(var_to_app_to_idx,
return rval
def _float_zeros_like(x):
""" Like zeros_like, but forces the object to have a
a floating point dtype """
......@@ -1599,6 +1598,7 @@ def hessian(cost, wrt, consider_constant=None,
hessians.append(hess)
return format_as(using_list, using_tuple, hessians)
def _is_zero(x):
"""
Returns 'yes', 'no', or 'maybe' indicating whether x
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论