提交 8ae84a3c authored 作者: carriepl's avatar carriepl

Speed up ScanInplaceOptimizer

上级 1548d7e0
......@@ -1081,29 +1081,63 @@ class ScanInplaceOptimizer(Optimizer):
def apply(self, fgraph):
nodes = fgraph.toposort()
nodes = fgraph.toposort()[::-1]
scan_nodes = [x for x in nodes
if (isinstance(x.op, scan_op.Scan) and
x.op.info['gpu'] == self.gpu_flag and
x.op.info['gpua'] == self.gpua_flag)]
for scan_idx in xrange(len(scan_nodes)):
# First attempt to make the Scan compute every recurrent output
# inplace. If that fails, go through these outputs individually,
# trying each of them
# First attempt to make the Scan compute inplace every recurrent
# output that seems like it could be computed inplace. If that
# fails, go through these outputs individually, trying each of
# them.
original_node = scan_nodes[scan_idx]
op = original_node.op
n_outs = (op.info['n_mit_mot'] +
op.info['n_mit_sot'] +
op.info['n_sit_sot'])
# Generate a list of outputs on which the node could potentially
# operate inplace.
out_indices = []
for out_idx in range(n_outs):
inp_idx = 1 + op.n_seqs + out_idx
input_used_inplace = False
for c in original_node.inputs[inp_idx].clients:
client = c[0]
# Get the indices of this client's inputs on which it
# operates inplace
inplace_inp_indices = []
if hasattr(client.op, 'view_map'):
inplace_inp_indices = sum(client.op.view_map.values(),
inplace_inp_indices)
if hasattr(client.op, 'destroy_map'):
inplace_inp_indices = sum(client.op.destroy_map.values(),
inplace_inp_indices)
for inplace_inp_idx in inplace_inp_indices:
inplace_inp = client.inputs[inplace_inp_idx]
if inplace_inp is original_node.inputs[inp_idx]:
input_used_inplace = True
break
if input_used_inplace:
break
if not input_used_inplace:
out_indices.append(out_idx)
node = self.attempt_scan_inplace(fgraph, scan_nodes[scan_idx],
range(n_outs))
out_indices)
if node is original_node:
# Making the scan compute all recurrent outputs inplace has
# failed. Attempt all recurrent outputs individually.
for pos in xrange(n_outs):
# Making the scan compute all plausible recurrent outputs
# inplace has failed. Attempt all plausible recurrent outputs
# individually.
for pos in out_indices:
node = self.attempt_scan_inplace(fgraph, node, [pos])
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论