提交 9755603b authored 作者: Frederic's avatar Frederic

pep8

上级 aca35acc
...@@ -327,8 +327,8 @@ class PushOutSeqScan(gof.Optimizer): ...@@ -327,8 +327,8 @@ class PushOutSeqScan(gof.Optimizer):
fgraph.attach_feature(gof.toolbox.ReplaceValidate()) fgraph.attach_feature(gof.toolbox.ReplaceValidate())
def apply(self, fgraph): def apply(self, fgraph):
nodelist = [x for x in fgraph.toposort() if isinstance(x.op, nodelist = [x for x in fgraph.toposort()
scan_op.Scan)] if isinstance(x.op, scan_op.Scan)]
for node in nodelist: for node in nodelist:
self.process_node(fgraph, node) self.process_node(fgraph, node)
...@@ -376,7 +376,7 @@ class PushOutSeqScan(gof.Optimizer): ...@@ -376,7 +376,7 @@ class PushOutSeqScan(gof.Optimizer):
elif x in inner_seqs: elif x in inner_seqs:
outside_ins += [outer_seqs[inner_seqs.index(x)]] outside_ins += [outer_seqs[inner_seqs.index(x)]]
elif x in to_replace: elif x in to_replace:
outside_ins += [replace_with_out[\ outside_ins += [replace_with_out[
to_replace.index(x)]] to_replace.index(x)]]
elif isinstance(x, theano.Constant): elif isinstance(x, theano.Constant):
outside_ins += [x.clone()] outside_ins += [x.clone()]
...@@ -847,9 +847,8 @@ class ScanSaveMem(gof.Optimizer): ...@@ -847,9 +847,8 @@ class ScanSaveMem(gof.Optimizer):
nw_inputs[0] = nw_steps nw_inputs[0] = nw_steps
# 3.2 check orphane outputs to see if we can eliminate any # 3.2 check orphane outputs to see if we can eliminate any
required, not_required = \ required, not_required = scan_utils.scan_can_remove_outs(
scan_utils.scan_can_remove_outs(node.op, node.op, orphane_outs)
orphane_outs)
# 3.3. compose replace pairs for those nodes that need not # 3.3. compose replace pairs for those nodes that need not
# to store everything in memory ( or ar orphane and required # to store everything in memory ( or ar orphane and required
# by the inner function .. ) # by the inner function .. )
...@@ -1011,9 +1010,8 @@ class ScanSaveMem(gof.Optimizer): ...@@ -1011,9 +1010,8 @@ class ScanSaveMem(gof.Optimizer):
position = (cnf_slice[0] - nw_steps - position = (cnf_slice[0] - nw_steps -
init_l[pos] + store_steps[pos]) init_l[pos] + store_steps[pos])
nw_slice = (sanitize(position),) + \ nw_slice = (sanitize(position),) + tuple(
tuple(old_slices[1:]) old_slices[1:])
subtens = tensor.Subtensor(nw_slice) subtens = tensor.Subtensor(nw_slice)
sl_ins = tensor.Subtensor.collapse( sl_ins = tensor.Subtensor.collapse(
nw_slice, nw_slice,
...@@ -1592,10 +1590,8 @@ class PushOutDot1(gof.Optimizer): ...@@ -1592,10 +1590,8 @@ class PushOutDot1(gof.Optimizer):
old = node.outputs[pos].clients[0][0].outputs[0] old = node.outputs[pos].clients[0][0].outputs[0]
old_new.append((old, new_out)) old_new.append((old, new_out))
old_new += zip(node.outputs[pos+1:], new_outs[pos:]) old_new += zip(node.outputs[pos+1:], new_outs[pos:])
fgraph.replace_all_validate_remove(old_new, fgraph.replace_all_validate_remove(
remove = [node], old_new, remove=[node], reason='scan_pushout_dot1')
reason='scan_pushout_dot1')
# I've added an equilibrium because later scan optimization in the sequence # I've added an equilibrium because later scan optimization in the sequence
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论