提交 d676f50b authored 作者: carriepl's avatar carriepl

Speed up ScanInplaceOptimizer

上级 e25b4c7c
...@@ -709,7 +709,7 @@ class PushOutScanOutput(gof.Optimizer): ...@@ -709,7 +709,7 @@ class PushOutScanOutput(gof.Optimizer):
The Dot product is pushed out of the scan and its inputs are The Dot product is pushed out of the scan and its inputs are
now the original matrix and a new matrix obtained by now the original matrix and a new matrix obtained by
concatenating the vectors into a matrix. concatenating the vectors into a matrix.
""" """
# Ensure that the output of the Dot is used in the outer # Ensure that the output of the Dot is used in the outer
# graph to avoid applying the optimization needlessly # graph to avoid applying the optimization needlessly
...@@ -723,7 +723,7 @@ class PushOutScanOutput(gof.Optimizer): ...@@ -723,7 +723,7 @@ class PushOutScanOutput(gof.Optimizer):
non-sequence input to scan and that the other input is a non-sequence input to scan and that the other input is a
vector and either a sequence input to scan or the result vector and either a sequence input to scan or the result
of computation in the inner function of scan. of computation in the inner function of scan.
""" """
valid_inputs = False valid_inputs = False
idx_matrix_input = -1 idx_matrix_input = -1
...@@ -1013,6 +1013,62 @@ class ScanInplaceOptimizer(Optimizer): ...@@ -1013,6 +1013,62 @@ class ScanInplaceOptimizer(Optimizer):
fgraph.attach_feature(toolbox.ReplaceValidate()) fgraph.attach_feature(toolbox.ReplaceValidate())
fgraph.attach_feature(DestroyHandler()) fgraph.attach_feature(DestroyHandler())
def attempt_scan_inplace(self, fgraph, node, output_indices):
    """Attempt to replace a Scan node by one which computes the specified
    outputs inplace.

    Parameters
    ----------
    fgraph : FunctionGraph
        Function graph in which to attempt the replacement
    node : Apply node
        Scan node to replace by an inplace version
    output_indices : list of integers
        Indices of the outputs to attempt to compute inplace

    Returns
    -------
    Apply node
        The owner of the first output of the new Scan if the replacement
        was accepted by the function graph; otherwise `node` itself, so
        callers can detect failure with an identity check.
    """
    op = node.op

    # Work on a deep copy so the info of the original op is not modified
    # when the replacement ends up being rejected.
    info = copy.deepcopy(op.info)
    if 'destroy_map' not in info:
        info['destroy_map'] = OrderedDict()

    # Output `out_idx` is declared to overwrite the outer input located
    # after the leading n_steps and sequence inputs (see `ls_begin`).
    for out_idx in output_indices:
        info['destroy_map'][out_idx] = [out_idx + 1 + op.info['n_seqs']]

    # inputs corresponding to sequences and n_steps
    ls_begin = node.inputs[:1 + op.n_seqs]

    # Outer inputs for the recurrent outputs (mit_mot, mit_sot, sit_sot).
    ls = op.outer_mitmot(node.inputs)
    ls += op.outer_mitsot(node.inputs)
    ls += op.outer_sitsot(node.inputs)

    # Remaining outer inputs (shared, nit_sot, non-sequences).
    ls_end = op.outer_shared(node.inputs)
    ls_end += op.outer_nitsot(node.inputs)
    ls_end += op.outer_non_seqs(node.inputs)

    # If the same variable feeds several recurrent outputs, every
    # occurrence after the first is wrapped in `deep_copy_op`.
    # NOTE(review): presumably so that two inplace outputs never destroy
    # the same buffer -- confirm against DestroyHandler's requirements.
    n_outs = len(ls)
    for idx in xrange(n_outs):
        if ls[idx] in ls[:idx]:
            ls[idx] = deep_copy_op(ls[idx])

    inputs = ls_begin + ls + ls_end

    new_op = scan_op.Scan(op.inputs,
                          op.outputs,
                          info,
                          typeConstructor=self.typeConstructor)

    # Do not call make_node for test_value
    new_outs = new_op(*inputs, return_list=True)

    try:
        fgraph.replace_all_validate_remove(
            list(zip(node.outputs, new_outs)),
            remove=[node],
            reason='scanOp_make_inplace')
    except InconsistencyError:
        # Failed moving output to be computed inplace
        return node

    return new_outs[0].owner
def apply(self, fgraph): def apply(self, fgraph):
nodes = fgraph.toposort() nodes = fgraph.toposort()
...@@ -1021,48 +1077,24 @@ class ScanInplaceOptimizer(Optimizer): ...@@ -1021,48 +1077,24 @@ class ScanInplaceOptimizer(Optimizer):
x.op.info['gpu'] == self.gpu_flag and x.op.info['gpu'] == self.gpu_flag and
x.op.info['gpua'] == self.gpua_flag)] x.op.info['gpua'] == self.gpua_flag)]
for scan_idx in xrange(len(scan_nodes)): for scan_idx in xrange(len(scan_nodes)):
node = scan_nodes[scan_idx]
op = node.op # First attempt to make the Scan compute every recurrent output
# inplace. If that fails, go through these outputs individually,
# trying each of them
original_node = scan_nodes[scan_idx]
op = original_node.op
n_outs = (op.info['n_mit_mot'] + n_outs = (op.info['n_mit_mot'] +
op.info['n_mit_sot'] + op.info['n_mit_sot'] +
op.info['n_sit_sot']) op.info['n_sit_sot'])
for pos in xrange(n_outs):
info = copy.deepcopy(op.info)
if not 'destroy_map' in info:
info['destroy_map'] = OrderedDict()
info['destroy_map'][pos] = [pos + 1 + op.info['n_seqs']]
# inputs corresponding to sequences and n_steps
ls_begin = node.inputs[:1 + op.n_seqs]
ls = op.outer_mitmot(node.inputs)
ls += op.outer_mitsot(node.inputs)
ls += op.outer_sitsot(node.inputs)
ls_end = op.outer_shared(node.inputs)
ls_end += op.outer_nitsot(node.inputs)
ls_end += op.outer_non_seqs(node.inputs)
n_outs = len(ls)
for idx in xrange(n_outs):
if ls[idx] in ls[:idx]:
ls[idx] = deep_copy_op(ls[idx])
inputs = ls_begin + ls + ls_end
new_op = scan_op.Scan(op.inputs,
op.outputs,
info,
typeConstructor=self.typeConstructor)
# Do not call make_node for test_value node = self.attempt_scan_inplace(fgraph, scan_nodes[scan_idx],
new_outs = new_op(*inputs, **dict(return_list=True)) range(n_outs))
try:
fgraph.replace_all_validate_remove( if node is original_node:
list(zip(node.outputs, new_outs)), # Making the scan compute all recurrent outputs inplace has
remove=[node], # failed. Attempt all recurrent outputs individually.
reason='scanOp_make_inplace') for pos in xrange(n_outs):
op = new_op node = self.attempt_scan_inplace(fgraph, node, [pos])
node = new_outs[0].owner
except InconsistencyError as e:
# Failed moving output to be comptued inplace
pass
class ScanSaveMem(gof.Optimizer): class ScanSaveMem(gof.Optimizer):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论