提交 087493ec authored 作者: Razvan Pascanu's avatar Razvan Pascanu

PEP8

上级 fdd6f5ab
...@@ -1356,6 +1356,7 @@ class Scan(PureOp): ...@@ -1356,6 +1356,7 @@ class Scan(PureOp):
# Applying Floyd-Warshall to find all paths connecting inputs to # Applying Floyd-Warshall to find all paths connecting inputs to
# outputs. Note that if `x` is an input to `y_t` and `y_tm1` is an # outputs. Note that if `x` is an input to `y_t` and `y_tm1` is an
# input to `z_t` then `x` is an input to `z_t`. # input to `z_t` then `x` is an input to `z_t`.
n_outs = len(node.outputs) n_outs = len(node.outputs)
for steps in xrange(n_outs): for steps in xrange(n_outs):
for iidx in xrange(n_outs): for iidx in xrange(n_outs):
...@@ -1446,13 +1447,15 @@ class Scan(PureOp): ...@@ -1446,13 +1447,15 @@ class Scan(PureOp):
odx = get_out_idx(self_outputs.index(y)) odx = get_out_idx(self_outputs.index(y))
wrt = [x for x in theano.gof.graph.inputs([y]) wrt = [x for x in theano.gof.graph.inputs([y])
if (x in diff_inputs) and if (x in diff_inputs) and
connection_pattern[get_inp_idx(self_inputs.index(x))][odx]] (connection_pattern[
get_inp_idx(self_inputs.index(x))][odx])]
grads = gradient.grad( grads = gradient.grad(
cost=None, cost=None,
known_grads={y: g_y}, known_grads={y: g_y},
wrt=wrt, consider_constant=wrt, wrt=wrt,
disconnected_inputs='ignore', consider_constant=wrt,
return_disconnected='None') disconnected_inputs='ignore',
return_disconnected='None')
gmp = dict(zip(wrt, grads)) gmp = dict(zip(wrt, grads))
rval = [gmp.get(p, None) for p in diff_inputs] rval = [gmp.get(p, None) for p in diff_inputs]
return rval return rval
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论