提交 8ae84a3c 作者: carriepl

Speed up ScanInplaceOptimizer

上级 1548d7e0
...@@ -1081,29 +1081,63 @@ class ScanInplaceOptimizer(Optimizer): ...@@ -1081,29 +1081,63 @@ class ScanInplaceOptimizer(Optimizer):
def apply(self, fgraph): def apply(self, fgraph):
nodes = fgraph.toposort() nodes = fgraph.toposort()[::-1]
scan_nodes = [x for x in nodes scan_nodes = [x for x in nodes
if (isinstance(x.op, scan_op.Scan) and if (isinstance(x.op, scan_op.Scan) and
x.op.info['gpu'] == self.gpu_flag and x.op.info['gpu'] == self.gpu_flag and
x.op.info['gpua'] == self.gpua_flag)] x.op.info['gpua'] == self.gpua_flag)]
for scan_idx in xrange(len(scan_nodes)): for scan_idx in xrange(len(scan_nodes)):
# First attempt to make the Scan compute every recurrent output # First attempt to make the Scan compute inplace every recurrent
# inplace. If that fails, go through these outputs individually, # output that seems like it could be computed inplace. If that
# trying each of them # fails, go through these outputs individually, trying each of
# them.
original_node = scan_nodes[scan_idx] original_node = scan_nodes[scan_idx]
op = original_node.op op = original_node.op
n_outs = (op.info['n_mit_mot'] + n_outs = (op.info['n_mit_mot'] +
op.info['n_mit_sot'] + op.info['n_mit_sot'] +
op.info['n_sit_sot']) op.info['n_sit_sot'])
# Generate a list of outputs on which the node could potentially
# operate inplace.
out_indices = []
for out_idx in range(n_outs):
inp_idx = 1 + op.n_seqs + out_idx
input_used_inplace = False
for c in original_node.inputs[inp_idx].clients:
client = c[0]
# Get the indices of this client's inputs on which it
# operates inplace
inplace_inp_indices = []
if hasattr(client.op, 'view_map'):
inplace_inp_indices = sum(client.op.view_map.values(),
inplace_inp_indices)
if hasattr(client.op, 'destroy_map'):
inplace_inp_indices = sum(client.op.destroy_map.values(),
inplace_inp_indices)
for inplace_inp_idx in inplace_inp_indices:
inplace_inp = client.inputs[inplace_inp_idx]
if inplace_inp is original_node.inputs[inp_idx]:
input_used_inplace = True
break
if input_used_inplace:
break
if not input_used_inplace:
out_indices.append(out_idx)
node = self.attempt_scan_inplace(fgraph, scan_nodes[scan_idx], node = self.attempt_scan_inplace(fgraph, scan_nodes[scan_idx],
range(n_outs)) out_indices)
if node is original_node: if node is original_node:
# Making the scan compute all recurrent outputs inplace has # Making the scan compute all plausible recurrent outputs
# failed. Attempt all recurrent outputs individually. # inplace has failed. Attempt all plausible recurrent outputs
for pos in xrange(n_outs): # individually.
for pos in out_indices:
node = self.attempt_scan_inplace(fgraph, node, [pos]) node = self.attempt_scan_inplace(fgraph, node, [pos])
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论