提交 87d55476 authored 作者: carriepl's avatar carriepl

Make ScanInplaceOptimizer handle GpuAlloc and GpuAllocEmpty from both backends

上级 e6aecd41
...@@ -1010,7 +1010,7 @@ class ScanInplaceOptimizer(Optimizer): ...@@ -1010,7 +1010,7 @@ class ScanInplaceOptimizer(Optimizer):
fgraph.attach_feature(toolbox.ReplaceValidate()) fgraph.attach_feature(toolbox.ReplaceValidate())
fgraph.attach_feature(DestroyHandler()) fgraph.attach_feature(DestroyHandler())
def attempt_scan_inplace(self, fgraph, node, output_indices): def attempt_scan_inplace(self, fgraph, node, output_indices, alloc_ops):
"""Attempts to replace a Scan node by one which computes the specified """Attempts to replace a Scan node by one which computes the specified
outputs inplace. outputs inplace.
...@@ -1022,6 +1022,10 @@ class ScanInplaceOptimizer(Optimizer): ...@@ -1022,6 +1022,10 @@ class ScanInplaceOptimizer(Optimizer):
Scan node to replace by an inplace version Scan node to replace by an inplace version
output_indices : list of integers output_indices : list of integers
Indices of the outputs to attempt to compute inplace Indices of the outputs to attempt to compute inplace
alloc_ops : list of Op classes
Classes that represent operation that allocate new memory and
that the optimization should duplicate so it can operate inplace
on them.
""" """
op = node.op op = node.op
...@@ -1043,11 +1047,11 @@ class ScanInplaceOptimizer(Optimizer): ...@@ -1043,11 +1047,11 @@ class ScanInplaceOptimizer(Optimizer):
ls_end += op.outer_non_seqs(node.inputs) ls_end += op.outer_non_seqs(node.inputs)
# In `ls`, duplicate any input which has more then one client and is # In `ls`, duplicate any input which has more then one client and is
# an Alloc or an AllocEmpty # the output of an eligible allocation op
for i in range(len(ls)): for i in range(len(ls)):
inp = ls[i] inp = ls[i]
if (len(inp.clients) > 1 and inp.owner and if (len(inp.clients) > 1 and inp.owner and
isinstance(inp.owner.op, (Alloc, AllocEmpty))): isinstance(inp.owner.op, alloc_ops)):
ls[i] = inp.owner.op(*inp.owner.inputs) ls[i] = inp.owner.op(*inp.owner.inputs)
n_outs = len(ls) n_outs = len(ls)
...@@ -1080,6 +1084,21 @@ class ScanInplaceOptimizer(Optimizer): ...@@ -1080,6 +1084,21 @@ class ScanInplaceOptimizer(Optimizer):
def apply(self, fgraph): def apply(self, fgraph):
# Depending on the values of gpu_flag and gpua_flag, get the list of
# memory allocation ops that the optimization should be able to handle
alloc_ops = (Alloc, AllocEmpty)
if self.gpu_flag:
alloc_ops += (theano.sandbox.cuda.GpuAlloc,
theano.sandbox.cuda.GpuAllocEmpty)
if self.gpua_flag:
# gpuarray might be imported but not its GpuAlloc and
# GpuAllopEmpty ops.
try:
alloc_ops += (theano.sandbox.gpuarray.GpuAlloc,
theano.sandbox.gpuarray.GpuAllocEmpty)
except:
pass
nodes = fgraph.toposort()[::-1] nodes = fgraph.toposort()[::-1]
scan_nodes = [x for x in nodes scan_nodes = [x for x in nodes
if (isinstance(x.op, scan_op.Scan) and if (isinstance(x.op, scan_op.Scan) and
...@@ -1104,14 +1123,14 @@ class ScanInplaceOptimizer(Optimizer): ...@@ -1104,14 +1123,14 @@ class ScanInplaceOptimizer(Optimizer):
inp_idx = 1 + op.n_seqs + out_idx inp_idx = 1 + op.n_seqs + out_idx
inp = original_node.inputs[inp_idx] inp = original_node.inputs[inp_idx]
# If the input is an Alloc or an AllocEmpty, attempt to be # If the input is from an eligible allocation node, attempt to
# inplace on it, even if other nodes are modifying it # be inplace on it, even if other nodes are modifying it
# inplace. # inplace.
if inp.owner and isinstance(inp.owner.op, (Alloc, AllocEmpty)): if inp.owner and isinstance(inp.owner.op, alloc_ops):
out_indices.append(out_idx) out_indices.append(out_idx)
continue continue
# If the input is neither an Alloc or an AllocEmpty, only # If the input is not from an eligible allocation node, only
# attempt to be inplace on it if nothing else is currently # attempt to be inplace on it if nothing else is currently
# inplace on it. # inplace on it.
input_used_inplace = False input_used_inplace = False
...@@ -1134,14 +1153,15 @@ class ScanInplaceOptimizer(Optimizer): ...@@ -1134,14 +1153,15 @@ class ScanInplaceOptimizer(Optimizer):
out_indices.append(out_idx) out_indices.append(out_idx)
node = self.attempt_scan_inplace(fgraph, scan_nodes[scan_idx], node = self.attempt_scan_inplace(fgraph, scan_nodes[scan_idx],
out_indices) out_indices, alloc_ops)
if node is original_node: if node is original_node:
# Making the scan compute all plausible recurrent outputs # Making the scan compute all plausible recurrent outputs
# inplace has failed. Attempt all plausible recurrent output # inplace has failed. Attempt all plausible recurrent output
# individually. # individually.
for pos in out_indices: for pos in out_indices:
node = self.attempt_scan_inplace(fgraph, node, [pos]) node = self.attempt_scan_inplace(fgraph, node, [pos],
alloc_ops)
class ScanSaveMem(gof.Optimizer): class ScanSaveMem(gof.Optimizer):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论