提交 e7732cfc authored 作者: Razvan Pascanu's avatar Razvan Pascanu

PEP8

上级 816d2e75
...@@ -1375,18 +1375,18 @@ class Scan(PureOp): ...@@ -1375,18 +1375,18 @@ class Scan(PureOp):
def compute_gradient(y, g_y): def compute_gradient(y, g_y):
if 'int' in str(g_y.dtype): if 'int' in str(g_y.dtype):
raise TypeError("Gradients may never be integers but g_y " raise TypeError("Gradients may never be integers but g_y "
"has type "+str(g_y.type)) "has type " + str(g_y.type))
wrt = [x for x in theano.gof.graph.inputs([y]) wrt = [x for x in theano.gof.graph.inputs([y])
if x in diff_inputs] if x in diff_inputs]
grads = gradient.grad( grads = gradient.grad(
cost = None, cost=None,
known_grads = {y : g_y }, known_grads={y: g_y},
wrt=wrt, consider_constant=wrt, wrt=wrt, consider_constant=wrt,
disconnected_inputs='ignore', disconnected_inputs='ignore',
return_disconnected='None') return_disconnected='None')
gmp = dict(zip(wrt, grads)) gmp = dict(zip(wrt, grads))
rval = [gmp.get(p, None) for p in diff_inputs] rval = [gmp.get(p, None) for p in diff_inputs]
return rval return rval
dC_dinps_t = [None for inp in diff_inputs] dC_dinps_t = [None for inp in diff_inputs]
disconnected_dC_dinps_t = [True for inp in diff_inputs] disconnected_dC_dinps_t = [True for inp in diff_inputs]
...@@ -1727,7 +1727,7 @@ class Scan(PureOp): ...@@ -1727,7 +1727,7 @@ class Scan(PureOp):
node = outs[0].owner node = outs[0].owner
for idx in xrange(self.n_shared_outs): for idx in xrange(self.n_shared_outs):
disconnected = True disconnected = True
connected_flags = self.connection_pattern(node)[idx+start] connected_flags = self.connection_pattern(node)[idx + start]
for dC_dout, connected in zip(dC_douts, connected_flags): for dC_dout, connected in zip(dC_douts, connected_flags):
if (not isinstance(dC_dout.type, DisconnectedType) and if (not isinstance(dC_dout.type, DisconnectedType) and
connected): connected):
......
...@@ -1342,7 +1342,7 @@ def scan_merge_inouts(node): ...@@ -1342,7 +1342,7 @@ def scan_merge_inouts(node):
return o return o
def map_nitsot_out(i, o, sh, seen): def map_nitsot_out(i, o, sh, seen):
for p,(si, so, ssh) in enumerate(seen): for p, (si, so, ssh) in enumerate(seen):
if equal_computations([i], [si], left, right): if equal_computations([i], [si], left, right):
if equal_computations([sh], [ssh]): if equal_computations([sh], [ssh]):
return so return so
...@@ -1354,7 +1354,7 @@ def scan_merge_inouts(node): ...@@ -1354,7 +1354,7 @@ def scan_merge_inouts(node):
if vsh == vssh: if vsh == vssh:
return so return so
elif vsh > vssh: elif vsh > vssh:
seen[p] = (i,o,sh) seen[p] = (i, o, sh)
return o return o
else: else:
return so[:vsh] return so[:vsh]
...@@ -1384,9 +1384,6 @@ def scan_merge_inouts(node): ...@@ -1384,9 +1384,6 @@ def scan_merge_inouts(node):
na.outer_out_nit_sot, na.outer_out_nit_sot,
shapes)] shapes)]
seen = [] seen = []
na.outer_out_sit_sot = [map_out(i, o, seen) na.outer_out_sit_sot = [map_out(i, o, seen)
for i, o in zip(na.inner_out_sit_sot, for i, o in zip(na.inner_out_sit_sot,
......
...@@ -191,8 +191,6 @@ def get_updates_and_outputs(ls): ...@@ -191,8 +191,6 @@ def get_updates_and_outputs(ls):
this function know how to put it in that order? this function know how to put it in that order?
""" """
def is_outputs(elem): def is_outputs(elem):
if (isinstance(elem, (list, tuple)) and if (isinstance(elem, (list, tuple)) and
all([isinstance(x, theano.Variable) for x in elem])): all([isinstance(x, theano.Variable) for x in elem])):
...@@ -206,7 +204,7 @@ def get_updates_and_outputs(ls): ...@@ -206,7 +204,7 @@ def get_updates_and_outputs(ls):
# Make sure the updates will be applied in a deterministic order # Make sure the updates will be applied in a deterministic order
if not isinstance(elem, gof.python25.OrderedDict): if not isinstance(elem, gof.python25.OrderedDict):
warnings.warn("Expected OrderedDict or OrderedUpdates, got "\ warnings.warn("Expected OrderedDict or OrderedUpdates, got "\
+str(type(elem))+". This can make your script non-" + str(type(elem)) + ". This can make your script non-"
"deterministic.") "deterministic.")
return True return True
# Dictionaries can be given as lists of tuples # Dictionaries can be given as lists of tuples
...@@ -253,7 +251,6 @@ def get_updates_and_outputs(ls): ...@@ -253,7 +251,6 @@ def get_updates_and_outputs(ls):
'values, you can use `tensor.constant` to turn them into ' 'values, you can use `tensor.constant` to turn them into '
'Theano variables.') 'Theano variables.')
if is_outputs(ls): if is_outputs(ls):
return None, _list(ls), OrderedDict() return None, _list(ls), OrderedDict()
if is_updates(ls): if is_updates(ls):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论