提交 3f734be4 authored 作者: Ian Goodfellow's avatar Ian Goodfellow

pep8 theano/gradient.py

上级 84a3f6f1
......@@ -266,7 +266,7 @@ def Rop(f, wrt, eval_points):
# we have to make it be wrong for Rop to keep working
# Rop should eventually be upgraded to handle integers
# correctly, the same as grad
y = theano.tensor.cast(y,x.type.dtype)
y = theano.tensor.cast(y, x.type.dtype)
y = x.type.filter_variable(y)
assert x.type == y.type
same_type_eval_points.append(y)
......@@ -493,7 +493,7 @@ def grad(cost, wrt, g_cost=None, consider_constant=None,
# Make sure we didn't initialize the grad_dict with any ints
for var in grad_dict:
g = grad_dict[var]
if hasattr(g.type,'dtype'):
if hasattr(g.type, 'dtype'):
assert g.type.dtype.find('float') != -1
rval = _populate_grad_dict(var_to_node_to_idx,
......@@ -509,6 +509,7 @@ def grad(cost, wrt, g_cost=None, consider_constant=None,
rval, = rval
return rval
def _node_to_pattern(node):
""" given an apply node, obtain its connection pattern
this is just a wrapper around Op.connection_pattern
......@@ -516,7 +517,7 @@ def _node_to_pattern(node):
if the method is not implemented
"""
if hasattr(node.op,'connection_pattern'):
if hasattr(node.op, 'connection_pattern'):
connection_pattern = node.op.connection_pattern(node)
if not isinstance(connection_pattern, list):
......@@ -538,7 +539,7 @@ def _node_to_pattern(node):
connection_pattern = \
[[True for output in node.outputs]
for ipt in node.inputs]
assert isinstance(connection_pattern,list)
assert isinstance(connection_pattern, list)
assert len(connection_pattern) == len(node.inputs)
for ii in xrange(len(node.inputs)):
assert isinstance(connection_pattern[ii], list)
......@@ -546,6 +547,7 @@ def _node_to_pattern(node):
len(node.outputs)
return connection_pattern
def _populate_var_to_node_to_idx(outputs, wrt):
"""
Common code shared between grad and grad_sources_inputs
......@@ -583,7 +585,6 @@ def _populate_var_to_node_to_idx(outputs, wrt):
# connection_pattern)
accounted_for = set([])
def account_for(var):
if var in accounted_for:
return
......@@ -693,16 +694,16 @@ def _populate_grad_dict(var_to_node_to_idx,
output_grads = [access_grad_cache(var) for var in node.outputs]
# list of bools indicating if each output is connected to the cost
outputs_connected = [ not isinstance(g.type, DisconnectedType)
for g in output_grads ]
outputs_connected = [not isinstance(g.type, DisconnectedType)
for g in output_grads]
connection_pattern = _node_to_pattern(node)
# list of bools indicating if each input is connected to the cost
inputs_connected = [
(True in [ input_to_output and output_to_cost for
(True in [input_to_output and output_to_cost for
input_to_output, output_to_cost in
zip(input_to_outputs, outputs_connected) ]) for
zip(input_to_outputs, outputs_connected)]) for
input_to_outputs in connection_pattern
]
......@@ -752,16 +753,16 @@ def _populate_grad_dict(var_to_node_to_idx,
# Do type checking on the result
#List of bools indicating if each output is an integer dtype
output_is_int = [ hasattr(output.type,'dtype') and
output_is_int = [hasattr(output.type, 'dtype') and
output.type.dtype.find('int') != -1
for output in node.outputs]
#List of bools indicating if each input only has integer outputs
only_connected_to_int = [ (True not in
[ in_to_out and out_to_cost and not out_int
only_connected_to_int = [(True not in
[in_to_out and out_to_cost and not out_int
for in_to_out, out_to_cost, out_int in
zip(in_to_outs, outputs_connected, output_is_int) ])
for in_to_outs in connection_pattern ]
zip(in_to_outs, outputs_connected, output_is_int)])
for in_to_outs in connection_pattern]
for i, term in enumerate(input_grads):
......@@ -780,9 +781,9 @@ def _populate_grad_dict(var_to_node_to_idx,
'functions.') % node.op)
if not isinstance(term.type,
(NullType,DisconnectedType)):
(NullType, DisconnectedType)):
if term.type.dtype.find('float') == -1:
raise TypeError(str(node.op)+'.grad illegally '
raise TypeError(str(node.op) + '.grad illegally '
' returned an integer-valued variable.'
' (Input index %d, dtype %s)' % (i,
term.type.dtype))
......@@ -851,7 +852,6 @@ def _populate_grad_dict(var_to_node_to_idx,
raise ValueError(msg)
#Check that op.connection_pattern matches the connectivity
#logic driving the op.grad method
for i, packed in \
......@@ -872,7 +872,7 @@ def _populate_grad_dict(var_to_node_to_idx,
msg = "%s.grad returned DisconnectedType for input"
msg += " %d."
msg = msg % (str(node.op), i)
if hasattr(node.op,'connection_pattern'):
if hasattr(node.op, 'connection_pattern'):
msg += ' Its connection_pattern method does not'
msg += ' allow this.'
raise TypeError(msg)
......@@ -917,7 +917,7 @@ def _populate_grad_dict(var_to_node_to_idx,
if len(terms) > 0:
# the next line is like sum(terms) but doesn't add an
# extraneous TensorConstant(0)
grad_dict[var] = reduce(lambda x,y: x+y, terms)
grad_dict[var] = reduce(lambda x, y: x + y, terms)
else:
grad_dict[var] = DisconnectedType()()
......@@ -1029,6 +1029,7 @@ def grad_sources_inputs(sources, graph_inputs):
return grad_dict
def _float_zeros_like(x):
""" Like zeros_like, but forces the object to have a
a floating point dtype """
......@@ -1040,6 +1041,7 @@ def _float_zeros_like(x):
return rval.astype(theano.config.floatX)
def _float_ones_like(x):
""" Like ones_like, but forces the object to have a
floating point dtype """
......@@ -1051,6 +1053,7 @@ def _float_ones_like(x):
return rval.astype(theano.config.floatX)
class numeric_grad(object):
"""
Compute the numeric derivative of a scalar-valued function at a particular
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论