提交 bae54705 authored 作者: abergeron's avatar abergeron

Merge pull request #3382 from carriepl/scan_inplace_opt

Scan inplace opt
...@@ -1010,6 +1010,61 @@ class ScanInplaceOptimizer(Optimizer): ...@@ -1010,6 +1010,61 @@ class ScanInplaceOptimizer(Optimizer):
fgraph.attach_feature(toolbox.ReplaceValidate()) fgraph.attach_feature(toolbox.ReplaceValidate())
fgraph.attach_feature(DestroyHandler()) fgraph.attach_feature(DestroyHandler())
def attempt_scan_inplace(self, fgraph, node, output_indices):
    """Try to swap a Scan node for a variant computing some outputs inplace.

    A new Scan op is built from a deep copy of the original op's ``info``
    whose ``destroy_map`` additionally maps every index in `output_indices`
    to the input index ``out_idx + 1 + n_seqs``. The new op is applied to
    the original node's inputs and substituted for `node` in `fgraph`.

    Parameters
    ----------
    fgraph : FunctionGraph
        Function graph in which to attempt the replacement
    node : Apply node
        Scan node to replace by an inplace version
    output_indices : list of integers
        Indices of the outputs to attempt to compute inplace

    Returns
    -------
    Apply node
        The owner of the first output of the new Scan when the graph
        accepted the replacement; `node` itself when the replacement
        raised ``InconsistencyError``.
    """
    scan = node.op
    outer_inputs = node.inputs
    n_seqs = scan.info['n_seqs']

    # Work on a private copy of the op's info so the original op is left
    # untouched if the replacement is rejected.
    inplace_info = copy.deepcopy(scan.info)
    destroy_map = inplace_info.setdefault('destroy_map', OrderedDict())
    for out_idx in output_indices:
        destroy_map[out_idx] = [out_idx + 1 + n_seqs]

    # inputs corresponding to sequences and n_steps
    leading_inputs = outer_inputs[:1 + scan.n_seqs]
    recurrent_inputs = (scan.outer_mitmot(outer_inputs) +
                        scan.outer_mitsot(outer_inputs) +
                        scan.outer_sitsot(outer_inputs))
    trailing_inputs = (scan.outer_shared(outer_inputs) +
                       scan.outer_nitsot(outer_inputs) +
                       scan.outer_non_seqs(outer_inputs))

    # Every repeated occurrence of a variable among the recurrent inputs is
    # routed through deep_copy_op, so that each position ends up holding a
    # distinct variable.
    distinct_inputs = []
    for var in recurrent_inputs:
        if var in distinct_inputs:
            var = deep_copy_op(var)
        distinct_inputs.append(var)

    inputs = leading_inputs + distinct_inputs + trailing_inputs
    inplace_op = scan_op.Scan(scan.inputs,
                              scan.outputs,
                              inplace_info,
                              typeConstructor=self.typeConstructor)

    # Do not call make_node for test_value
    new_outs = inplace_op(*inputs, return_list=True)
    replacements = list(zip(node.outputs, new_outs))

    try:
        fgraph.replace_all_validate_remove(
            replacements,
            remove=[node],
            reason='scanOp_make_inplace')
    except InconsistencyError:
        # Failed moving output to be computed inplace
        return node
    return new_outs[0].owner
def apply(self, fgraph): def apply(self, fgraph):
nodes = fgraph.toposort() nodes = fgraph.toposort()
...@@ -1018,48 +1073,24 @@ class ScanInplaceOptimizer(Optimizer): ...@@ -1018,48 +1073,24 @@ class ScanInplaceOptimizer(Optimizer):
x.op.info['gpu'] == self.gpu_flag and x.op.info['gpu'] == self.gpu_flag and
x.op.info['gpua'] == self.gpua_flag)] x.op.info['gpua'] == self.gpua_flag)]
for scan_idx in xrange(len(scan_nodes)): for scan_idx in xrange(len(scan_nodes)):
node = scan_nodes[scan_idx]
op = node.op # First attempt to make the Scan compute every recurrent output
# inplace. If that fails, go through these outputs individually,
# trying each of them
original_node = scan_nodes[scan_idx]
op = original_node.op
n_outs = (op.info['n_mit_mot'] + n_outs = (op.info['n_mit_mot'] +
op.info['n_mit_sot'] + op.info['n_mit_sot'] +
op.info['n_sit_sot']) op.info['n_sit_sot'])
for pos in xrange(n_outs):
info = copy.deepcopy(op.info)
if not 'destroy_map' in info:
info['destroy_map'] = OrderedDict()
info['destroy_map'][pos] = [pos + 1 + op.info['n_seqs']]
# inputs corresponding to sequences and n_steps
ls_begin = node.inputs[:1 + op.n_seqs]
ls = op.outer_mitmot(node.inputs)
ls += op.outer_mitsot(node.inputs)
ls += op.outer_sitsot(node.inputs)
ls_end = op.outer_shared(node.inputs)
ls_end += op.outer_nitsot(node.inputs)
ls_end += op.outer_non_seqs(node.inputs)
n_outs = len(ls)
for idx in xrange(n_outs):
if ls[idx] in ls[:idx]:
ls[idx] = deep_copy_op(ls[idx])
inputs = ls_begin + ls + ls_end
new_op = scan_op.Scan(op.inputs,
op.outputs,
info,
typeConstructor=self.typeConstructor)
# Do not call make_node for test_value node = self.attempt_scan_inplace(fgraph, scan_nodes[scan_idx],
new_outs = new_op(*inputs, **dict(return_list=True)) range(n_outs))
try:
fgraph.replace_all_validate_remove( if node is original_node:
list(zip(node.outputs, new_outs)), # Making the scan compute all recurrent outputs inplace has
remove=[node], # failed. Attempt all recurrent outputs individually.
reason='scanOp_make_inplace') for pos in xrange(n_outs):
op = new_op node = self.attempt_scan_inplace(fgraph, node, [pos])
node = new_outs[0].owner
except InconsistencyError as e:
# Failed moving output to be comptued inplace
pass
class ScanSaveMem(gof.Optimizer): class ScanSaveMem(gof.Optimizer):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论