提交 bae54705 authored 作者: abergeron's avatar abergeron

Merge pull request #3382 from carriepl/scan_inplace_opt

Scan inplace opt
...@@ -1010,25 +1010,29 @@ class ScanInplaceOptimizer(Optimizer): ...@@ -1010,25 +1010,29 @@ class ScanInplaceOptimizer(Optimizer):
fgraph.attach_feature(toolbox.ReplaceValidate()) fgraph.attach_feature(toolbox.ReplaceValidate())
fgraph.attach_feature(DestroyHandler()) fgraph.attach_feature(DestroyHandler())
def apply(self, fgraph): def attempt_scan_inplace(self, fgraph, node, output_indices):
"""Attempts to replace a Scan node by one which computes the specified
outputs inplace.
Parameters
----------
fgraph : FunctionGraph
Function graph in which to attempt the replacement
node : Apply node
Scan node to replace by an inplace version
output_indices : list of integers
Indices of the outputs to attempt to compute inplace
"""
nodes = fgraph.toposort()
scan_nodes = [x for x in nodes
if (isinstance(x.op, scan_op.Scan) and
x.op.info['gpu'] == self.gpu_flag and
x.op.info['gpua'] == self.gpua_flag)]
for scan_idx in xrange(len(scan_nodes)):
node = scan_nodes[scan_idx]
op = node.op op = node.op
n_outs = (op.info['n_mit_mot'] +
op.info['n_mit_sot'] +
op.info['n_sit_sot'])
for pos in xrange(n_outs):
info = copy.deepcopy(op.info) info = copy.deepcopy(op.info)
if not 'destroy_map' in info: if 'destroy_map' not in info:
info['destroy_map'] = OrderedDict() info['destroy_map'] = OrderedDict()
info['destroy_map'][pos] = [pos + 1 + op.info['n_seqs']] for out_idx in output_indices:
info['destroy_map'][out_idx] = [out_idx + 1 + op.info['n_seqs']]
# inputs corresponding to sequences and n_steps # inputs corresponding to sequences and n_steps
ls_begin = node.inputs[:1 + op.n_seqs] ls_begin = node.inputs[:1 + op.n_seqs]
ls = op.outer_mitmot(node.inputs) ls = op.outer_mitmot(node.inputs)
...@@ -1037,6 +1041,7 @@ class ScanInplaceOptimizer(Optimizer): ...@@ -1037,6 +1041,7 @@ class ScanInplaceOptimizer(Optimizer):
ls_end = op.outer_shared(node.inputs) ls_end = op.outer_shared(node.inputs)
ls_end += op.outer_nitsot(node.inputs) ls_end += op.outer_nitsot(node.inputs)
ls_end += op.outer_non_seqs(node.inputs) ls_end += op.outer_non_seqs(node.inputs)
n_outs = len(ls) n_outs = len(ls)
for idx in xrange(n_outs): for idx in xrange(n_outs):
if ls[idx] in ls[:idx]: if ls[idx] in ls[:idx]:
...@@ -1055,11 +1060,37 @@ class ScanInplaceOptimizer(Optimizer): ...@@ -1055,11 +1060,37 @@ class ScanInplaceOptimizer(Optimizer):
list(zip(node.outputs, new_outs)), list(zip(node.outputs, new_outs)),
remove=[node], remove=[node],
reason='scanOp_make_inplace') reason='scanOp_make_inplace')
op = new_op return new_outs[0].owner
node = new_outs[0].owner except InconsistencyError:
except InconsistencyError as e: # Failed moving output to be computed inplace
# Failed moving output to be comptued inplace return node
pass
def apply(self, fgraph):
    """Try to make every matching Scan node in ``fgraph`` work inplace.

    A node is considered when its op is a ``scan_op.Scan`` whose
    ``info['gpu']`` and ``info['gpua']`` entries equal this optimizer's
    ``gpu_flag`` and ``gpua_flag``. The list of candidates is taken from
    a single toposort done before any replacement is attempted.

    For each candidate, one call to ``attempt_scan_inplace`` first asks
    for all recurrent outputs (mit_mot, mit_sot and sit_sot) to be
    computed inplace at once. If that call hands back the very same
    node, the attempt is treated as failed and every recurrent output is
    then tried individually, each attempt starting from the node
    returned by the previous one.

    Parameters
    ----------
    fgraph : FunctionGraph
        Function graph whose Scan nodes should be made inplace.
    """
    def is_target(candidate):
        # isinstance is checked first so that ``info`` is only read on
        # Scan ops.
        candidate_op = candidate.op
        return (isinstance(candidate_op, scan_op.Scan) and
                candidate_op.info['gpu'] == self.gpu_flag and
                candidate_op.info['gpua'] == self.gpua_flag)

    targets = [candidate for candidate in fgraph.toposort()
               if is_target(candidate)]

    for original_node in targets:
        scan_info = original_node.op.info
        # Number of recurrent outputs of this Scan.
        n_recurrent = (scan_info['n_mit_mot'] +
                       scan_info['n_mit_sot'] +
                       scan_info['n_sit_sot'])

        # First try: every recurrent output inplace in one replacement.
        current = self.attempt_scan_inplace(fgraph, original_node,
                                            range(n_recurrent))
        if current is not original_node:
            # A different node came back: the all-at-once attempt
            # succeeded, nothing more to do for this Scan.
            continue

        # Fallback: try the recurrent outputs one by one, always
        # continuing from the most recently returned node.
        for pos in xrange(n_recurrent):
            current = self.attempt_scan_inplace(fgraph, current, [pos])
class ScanSaveMem(gof.Optimizer): class ScanSaveMem(gof.Optimizer):
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
请 注册 或者 登录 后发表评论