提交 644d8bce authored 作者: Frederic Bastien's avatar Frederic Bastien

inplace elemwise optimizer speed up.

上级 fce0b82f
...@@ -237,6 +237,11 @@ def inplace_elemwise_optimizer_op(OP): ...@@ -237,6 +237,11 @@ def inplace_elemwise_optimizer_op(OP):
else: else:
update_outs = [] update_outs = []
protected_inputs = [
f.protected for f in fgraph._features if
isinstance(f, theano.compile.function_module.Supervisor)]
protected_inputs = sum(protected_inputs, []) # flatten the list
protected_inputs.extend(fgraph.outputs)
for node in list(graph.io_toposort(fgraph.inputs, fgraph.outputs)): for node in list(graph.io_toposort(fgraph.inputs, fgraph.outputs)):
op = node.op op = node.op
# gpuarray GpuElemwise inherit from Elemwise # gpuarray GpuElemwise inherit from Elemwise
...@@ -245,27 +250,37 @@ def inplace_elemwise_optimizer_op(OP): ...@@ -245,27 +250,37 @@ def inplace_elemwise_optimizer_op(OP):
# If big graph and the outputs are scalar, do not make it # If big graph and the outputs are scalar, do not make it
# inplace. # inplace.
if (check_each_change != 1 and if (check_each_change != 1 and
all([getattr(o.type, 'ndim', -1) == 0 # If multiple outputs, they must all have the same size,
for o in node.outputs])): # so only check the first.
getattr(node.outputs[0].type, 'ndim', -1) == 0):
continue continue
baseline = op.inplace_pattern if op.inplace_pattern:
protected_inputs = [ baseline = op.inplace_pattern
f.protected for f in node.fgraph._features if candidate_outputs = [i for i in xrange(len(node.outputs))
isinstance(f, theano.compile.function_module.Supervisor)] if i not in baseline]
protected_inputs = sum(protected_inputs, []) # flatten the list # node inputs that are Constant, already destroyed,
protected_inputs.extend(fgraph.outputs) # or fgraph protected inputs and fgraph outputs can't be used as
candidate_outputs = [i for i in xrange(len(node.outputs)) # inplace target.
if i not in baseline] # Remove here as faster.
# node inputs that are Constant, already destroyed, candidate_inputs = [i for i in xrange(len(node.inputs))
# fgraph protected inputs and fgraph outputs can't be used as inplace if i not in baseline.values() and
# target. not isinstance(node.inputs[i], Constant) and
# Remove here as faster. # Is next line costly?
candidate_inputs = [i for i in xrange(len(node.inputs)) not fgraph.destroyers(node.inputs[i]) and
if i not in baseline.values() and node.inputs[i] not in protected_inputs]
not isinstance(node.inputs[i], Constant) and else:
not fgraph.destroyers(node.inputs[i]) and baseline = []
node.inputs[i] not in protected_inputs] candidate_outputs = list(range(len(node.outputs)))
# node inputs that are Constant, already destroyed,
# fgraph protected inputs and fgraph outputs can't be used as inplace
# target.
# Remove here as faster.
candidate_inputs = [i for i in xrange(len(node.inputs))
if not isinstance(node.inputs[i], Constant) and
not fgraph.destroyers(node.inputs[i]) and
node.inputs[i] not in protected_inputs]
verbose = False verbose = False
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论