提交 e7732cfc authored 作者: Razvan Pascanu's avatar Razvan Pascanu

PEP8

上级 816d2e75
......@@ -1375,18 +1375,18 @@ class Scan(PureOp):
def compute_gradient(y, g_y):
if 'int' in str(g_y.dtype):
raise TypeError("Gradients may never be integers but g_y "
"has type "+str(g_y.type))
"has type " + str(g_y.type))
wrt = [x for x in theano.gof.graph.inputs([y])
wrt = [x for x in theano.gof.graph.inputs([y])
if x in diff_inputs]
grads = gradient.grad(
cost = None,
known_grads = {y : g_y },
grads = gradient.grad(
cost=None,
known_grads={y: g_y},
wrt=wrt, consider_constant=wrt,
disconnected_inputs='ignore',
return_disconnected='None')
gmp = dict(zip(wrt, grads))
rval = [gmp.get(p, None) for p in diff_inputs]
rval = [gmp.get(p, None) for p in diff_inputs]
return rval
dC_dinps_t = [None for inp in diff_inputs]
disconnected_dC_dinps_t = [True for inp in diff_inputs]
......@@ -1727,7 +1727,7 @@ class Scan(PureOp):
node = outs[0].owner
for idx in xrange(self.n_shared_outs):
disconnected = True
connected_flags = self.connection_pattern(node)[idx+start]
connected_flags = self.connection_pattern(node)[idx + start]
for dC_dout, connected in zip(dC_douts, connected_flags):
if (not isinstance(dC_dout.type, DisconnectedType) and
connected):
......
......@@ -1342,7 +1342,7 @@ def scan_merge_inouts(node):
return o
def map_nitsot_out(i, o, sh, seen):
for p,(si, so, ssh) in enumerate(seen):
for p, (si, so, ssh) in enumerate(seen):
if equal_computations([i], [si], left, right):
if equal_computations([sh], [ssh]):
return so
......@@ -1354,7 +1354,7 @@ def scan_merge_inouts(node):
if vsh == vssh:
return so
elif vsh > vssh:
seen[p] = (i,o,sh)
seen[p] = (i, o, sh)
return o
else:
return so[:vsh]
......@@ -1384,9 +1384,6 @@ def scan_merge_inouts(node):
na.outer_out_nit_sot,
shapes)]
seen = []
na.outer_out_sit_sot = [map_out(i, o, seen)
for i, o in zip(na.inner_out_sit_sot,
......
......@@ -191,8 +191,6 @@ def get_updates_and_outputs(ls):
this function know how to put it in that order?
"""
def is_outputs(elem):
if (isinstance(elem, (list, tuple)) and
all([isinstance(x, theano.Variable) for x in elem])):
......@@ -206,7 +204,7 @@ def get_updates_and_outputs(ls):
# Make sure the updates will be applied in a deterministic order
if not isinstance(elem, gof.python25.OrderedDict):
warnings.warn("Expected OrderedDict or OrderedUpdates, got "\
+str(type(elem))+". This can make your script non-"
+ str(type(elem)) + ". This can make your script non-"
"deterministic.")
return True
# Dictionaries can be given as lists of tuples
......@@ -253,7 +251,6 @@ def get_updates_and_outputs(ls):
'values, you can use `tensor.constant` to turn them into '
'Theano variables.')
if is_outputs(ls):
return None, _list(ls), OrderedDict()
if is_updates(ls):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论