提交 87d55476 authored 作者: carriepl's avatar carriepl

Make ScanInplaceOptimizer handle GpuAlloc and GpuAllocEmpty from both backends

上级 e6aecd41
......@@ -1010,7 +1010,7 @@ class ScanInplaceOptimizer(Optimizer):
fgraph.attach_feature(toolbox.ReplaceValidate())
fgraph.attach_feature(DestroyHandler())
def attempt_scan_inplace(self, fgraph, node, output_indices):
def attempt_scan_inplace(self, fgraph, node, output_indices, alloc_ops):
"""Attempts to replace a Scan node by one which computes the specified
outputs inplace.
......@@ -1022,6 +1022,10 @@ class ScanInplaceOptimizer(Optimizer):
Scan node to replace by an inplace version
output_indices : list of integers
Indices of the outputs to attempt to compute inplace
alloc_ops : list of Op classes
Classes that represent operations that allocate new memory and
that the optimization should duplicate so it can operate inplace
on them.
"""
op = node.op
......@@ -1043,11 +1047,11 @@ class ScanInplaceOptimizer(Optimizer):
ls_end += op.outer_non_seqs(node.inputs)
# In `ls`, duplicate any input which has more than one client and is
# an Alloc or an AllocEmpty
# the output of an eligible allocation op
for i in range(len(ls)):
inp = ls[i]
if (len(inp.clients) > 1 and inp.owner and
isinstance(inp.owner.op, (Alloc, AllocEmpty))):
isinstance(inp.owner.op, alloc_ops)):
ls[i] = inp.owner.op(*inp.owner.inputs)
n_outs = len(ls)
......@@ -1080,6 +1084,21 @@ class ScanInplaceOptimizer(Optimizer):
def apply(self, fgraph):
# Depending on the values of gpu_flag and gpua_flag, get the list of
# memory allocation ops that the optimization should be able to handle
alloc_ops = (Alloc, AllocEmpty)
if self.gpu_flag:
alloc_ops += (theano.sandbox.cuda.GpuAlloc,
theano.sandbox.cuda.GpuAllocEmpty)
if self.gpua_flag:
# gpuarray might be imported but not its GpuAlloc and
# GpuAllocEmpty ops.
try:
alloc_ops += (theano.sandbox.gpuarray.GpuAlloc,
theano.sandbox.gpuarray.GpuAllocEmpty)
except:
pass
nodes = fgraph.toposort()[::-1]
scan_nodes = [x for x in nodes
if (isinstance(x.op, scan_op.Scan) and
......@@ -1104,14 +1123,14 @@ class ScanInplaceOptimizer(Optimizer):
inp_idx = 1 + op.n_seqs + out_idx
inp = original_node.inputs[inp_idx]
# If the input is an Alloc or an AllocEmpty, attempt to be
# inplace on it, even if other nodes are modifying it
# If the input is from an eligible allocation node, attempt to
# be inplace on it, even if other nodes are modifying it
# inplace.
if inp.owner and isinstance(inp.owner.op, (Alloc, AllocEmpty)):
if inp.owner and isinstance(inp.owner.op, alloc_ops):
out_indices.append(out_idx)
continue
# If the input is neither an Alloc nor an AllocEmpty, only
# If the input is not from an eligible allocation node, only
# attempt to be inplace on it if nothing else is currently
# inplace on it.
input_used_inplace = False
......@@ -1134,14 +1153,15 @@ class ScanInplaceOptimizer(Optimizer):
out_indices.append(out_idx)
node = self.attempt_scan_inplace(fgraph, scan_nodes[scan_idx],
out_indices)
out_indices, alloc_ops)
if node is original_node:
# Making the scan compute all plausible recurrent outputs
# inplace has failed. Attempt all plausible recurrent output
# individually.
for pos in out_indices:
node = self.attempt_scan_inplace(fgraph, node, [pos])
node = self.attempt_scan_inplace(fgraph, node, [pos],
alloc_ops)
class ScanSaveMem(gof.Optimizer):
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论