提交 3f734be4 authored 作者: Ian Goodfellow's avatar Ian Goodfellow

pep8 theano/gradient.py

上级 84a3f6f1
...@@ -266,7 +266,7 @@ def Rop(f, wrt, eval_points): ...@@ -266,7 +266,7 @@ def Rop(f, wrt, eval_points):
# we have to make it be wrong for Rop to keep working # we have to make it be wrong for Rop to keep working
# Rop should eventually be upgraded to handle integers # Rop should eventually be upgraded to handle integers
# correctly, the same as grad # correctly, the same as grad
y = theano.tensor.cast(y,x.type.dtype) y = theano.tensor.cast(y, x.type.dtype)
y = x.type.filter_variable(y) y = x.type.filter_variable(y)
assert x.type == y.type assert x.type == y.type
same_type_eval_points.append(y) same_type_eval_points.append(y)
...@@ -493,7 +493,7 @@ def grad(cost, wrt, g_cost=None, consider_constant=None, ...@@ -493,7 +493,7 @@ def grad(cost, wrt, g_cost=None, consider_constant=None,
# Make sure we didn't initialize the grad_dict with any ints # Make sure we didn't initialize the grad_dict with any ints
for var in grad_dict: for var in grad_dict:
g = grad_dict[var] g = grad_dict[var]
if hasattr(g.type,'dtype'): if hasattr(g.type, 'dtype'):
assert g.type.dtype.find('float') != -1 assert g.type.dtype.find('float') != -1
rval = _populate_grad_dict(var_to_node_to_idx, rval = _populate_grad_dict(var_to_node_to_idx,
...@@ -509,6 +509,7 @@ def grad(cost, wrt, g_cost=None, consider_constant=None, ...@@ -509,6 +509,7 @@ def grad(cost, wrt, g_cost=None, consider_constant=None,
rval, = rval rval, = rval
return rval return rval
def _node_to_pattern(node): def _node_to_pattern(node):
""" given an apply node, obtain its connection pattern """ given an apply node, obtain its connection pattern
this is just a wrapper around Op.connection_pattern this is just a wrapper around Op.connection_pattern
...@@ -516,7 +517,7 @@ def _node_to_pattern(node): ...@@ -516,7 +517,7 @@ def _node_to_pattern(node):
if the method is not implemented if the method is not implemented
""" """
if hasattr(node.op,'connection_pattern'): if hasattr(node.op, 'connection_pattern'):
connection_pattern = node.op.connection_pattern(node) connection_pattern = node.op.connection_pattern(node)
if not isinstance(connection_pattern, list): if not isinstance(connection_pattern, list):
...@@ -538,7 +539,7 @@ def _node_to_pattern(node): ...@@ -538,7 +539,7 @@ def _node_to_pattern(node):
connection_pattern = \ connection_pattern = \
[[True for output in node.outputs] [[True for output in node.outputs]
for ipt in node.inputs] for ipt in node.inputs]
assert isinstance(connection_pattern,list) assert isinstance(connection_pattern, list)
assert len(connection_pattern) == len(node.inputs) assert len(connection_pattern) == len(node.inputs)
for ii in xrange(len(node.inputs)): for ii in xrange(len(node.inputs)):
assert isinstance(connection_pattern[ii], list) assert isinstance(connection_pattern[ii], list)
...@@ -546,6 +547,7 @@ def _node_to_pattern(node): ...@@ -546,6 +547,7 @@ def _node_to_pattern(node):
len(node.outputs) len(node.outputs)
return connection_pattern return connection_pattern
def _populate_var_to_node_to_idx(outputs, wrt): def _populate_var_to_node_to_idx(outputs, wrt):
""" """
Common code shared between grad and grad_sources_inputs Common code shared between grad and grad_sources_inputs
...@@ -583,7 +585,6 @@ def _populate_var_to_node_to_idx(outputs, wrt): ...@@ -583,7 +585,6 @@ def _populate_var_to_node_to_idx(outputs, wrt):
# connection_pattern) # connection_pattern)
accounted_for = set([]) accounted_for = set([])
def account_for(var): def account_for(var):
if var in accounted_for: if var in accounted_for:
return return
...@@ -693,16 +694,16 @@ def _populate_grad_dict(var_to_node_to_idx, ...@@ -693,16 +694,16 @@ def _populate_grad_dict(var_to_node_to_idx,
output_grads = [access_grad_cache(var) for var in node.outputs] output_grads = [access_grad_cache(var) for var in node.outputs]
# list of bools indicating if each output is connected to the cost # list of bools indicating if each output is connected to the cost
outputs_connected = [ not isinstance(g.type, DisconnectedType) outputs_connected = [not isinstance(g.type, DisconnectedType)
for g in output_grads ] for g in output_grads]
connection_pattern = _node_to_pattern(node) connection_pattern = _node_to_pattern(node)
# list of bools indicating if each input is connected to the cost # list of bools indicating if each input is connected to the cost
inputs_connected = [ inputs_connected = [
(True in [ input_to_output and output_to_cost for (True in [input_to_output and output_to_cost for
input_to_output, output_to_cost in input_to_output, output_to_cost in
zip(input_to_outputs, outputs_connected) ]) for zip(input_to_outputs, outputs_connected)]) for
input_to_outputs in connection_pattern input_to_outputs in connection_pattern
] ]
...@@ -752,16 +753,16 @@ def _populate_grad_dict(var_to_node_to_idx, ...@@ -752,16 +753,16 @@ def _populate_grad_dict(var_to_node_to_idx,
# Do type checking on the result # Do type checking on the result
#List of bools indicating if each output is an integer dtype #List of bools indicating if each output is an integer dtype
output_is_int = [ hasattr(output.type,'dtype') and output_is_int = [hasattr(output.type, 'dtype') and
output.type.dtype.find('int') != -1 output.type.dtype.find('int') != -1
for output in node.outputs] for output in node.outputs]
#List of bools indicating if each input only has integer outputs #List of bools indicating if each input only has integer outputs
only_connected_to_int = [ (True not in only_connected_to_int = [(True not in
[ in_to_out and out_to_cost and not out_int [in_to_out and out_to_cost and not out_int
for in_to_out, out_to_cost, out_int in for in_to_out, out_to_cost, out_int in
zip(in_to_outs, outputs_connected, output_is_int) ]) zip(in_to_outs, outputs_connected, output_is_int)])
for in_to_outs in connection_pattern ] for in_to_outs in connection_pattern]
for i, term in enumerate(input_grads): for i, term in enumerate(input_grads):
...@@ -780,9 +781,9 @@ def _populate_grad_dict(var_to_node_to_idx, ...@@ -780,9 +781,9 @@ def _populate_grad_dict(var_to_node_to_idx,
'functions.') % node.op) 'functions.') % node.op)
if not isinstance(term.type, if not isinstance(term.type,
(NullType,DisconnectedType)): (NullType, DisconnectedType)):
if term.type.dtype.find('float') == -1: if term.type.dtype.find('float') == -1:
raise TypeError(str(node.op)+'.grad illegally ' raise TypeError(str(node.op) + '.grad illegally '
' returned an integer-valued variable.' ' returned an integer-valued variable.'
' (Input index %d, dtype %s)' % (i, ' (Input index %d, dtype %s)' % (i,
term.type.dtype)) term.type.dtype))
...@@ -851,7 +852,6 @@ def _populate_grad_dict(var_to_node_to_idx, ...@@ -851,7 +852,6 @@ def _populate_grad_dict(var_to_node_to_idx,
raise ValueError(msg) raise ValueError(msg)
#Check that op.connection_pattern matches the connectivity #Check that op.connection_pattern matches the connectivity
#logic driving the op.grad method #logic driving the op.grad method
for i, packed in \ for i, packed in \
...@@ -872,7 +872,7 @@ def _populate_grad_dict(var_to_node_to_idx, ...@@ -872,7 +872,7 @@ def _populate_grad_dict(var_to_node_to_idx,
msg = "%s.grad returned DisconnectedType for input" msg = "%s.grad returned DisconnectedType for input"
msg += " %d." msg += " %d."
msg = msg % (str(node.op), i) msg = msg % (str(node.op), i)
if hasattr(node.op,'connection_pattern'): if hasattr(node.op, 'connection_pattern'):
msg += ' Its connection_pattern method does not' msg += ' Its connection_pattern method does not'
msg += ' allow this.' msg += ' allow this.'
raise TypeError(msg) raise TypeError(msg)
...@@ -917,7 +917,7 @@ def _populate_grad_dict(var_to_node_to_idx, ...@@ -917,7 +917,7 @@ def _populate_grad_dict(var_to_node_to_idx,
if len(terms) > 0: if len(terms) > 0:
# the next line is like sum(terms) but doesn't add an # the next line is like sum(terms) but doesn't add an
# extraneous TensorConstant(0) # extraneous TensorConstant(0)
grad_dict[var] = reduce(lambda x,y: x+y, terms) grad_dict[var] = reduce(lambda x, y: x + y, terms)
else: else:
grad_dict[var] = DisconnectedType()() grad_dict[var] = DisconnectedType()()
...@@ -1029,6 +1029,7 @@ def grad_sources_inputs(sources, graph_inputs): ...@@ -1029,6 +1029,7 @@ def grad_sources_inputs(sources, graph_inputs):
return grad_dict return grad_dict
def _float_zeros_like(x): def _float_zeros_like(x):
""" Like zeros_like, but forces the object to have a """ Like zeros_like, but forces the object to have a
a floating point dtype """ a floating point dtype """
...@@ -1040,6 +1041,7 @@ def _float_zeros_like(x): ...@@ -1040,6 +1041,7 @@ def _float_zeros_like(x):
return rval.astype(theano.config.floatX) return rval.astype(theano.config.floatX)
def _float_ones_like(x): def _float_ones_like(x):
""" Like ones_like, but forces the object to have a """ Like ones_like, but forces the object to have a
floating point dtype """ floating point dtype """
...@@ -1051,6 +1053,7 @@ def _float_ones_like(x): ...@@ -1051,6 +1053,7 @@ def _float_ones_like(x):
return rval.astype(theano.config.floatX) return rval.astype(theano.config.floatX)
class numeric_grad(object): class numeric_grad(object):
""" """
Compute the numeric derivative of a scalar-valued function at a particular Compute the numeric derivative of a scalar-valued function at a particular
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论