提交 644d8bce authored 作者: Frederic Bastien's avatar Frederic Bastien

inplace elemwise optimizer speed up.

上级 fce0b82f
......@@ -237,6 +237,11 @@ def inplace_elemwise_optimizer_op(OP):
else:
update_outs = []
protected_inputs = [
f.protected for f in fgraph._features if
isinstance(f, theano.compile.function_module.Supervisor)]
protected_inputs = sum(protected_inputs, []) # flatten the list
protected_inputs.extend(fgraph.outputs)
for node in list(graph.io_toposort(fgraph.inputs, fgraph.outputs)):
op = node.op
# gpuarray GpuElemwise inherit from Elemwise
......@@ -245,27 +250,37 @@ def inplace_elemwise_optimizer_op(OP):
# If big graph and the outputs are scalar, do not make it
# inplace.
if (check_each_change != 1 and
all([getattr(o.type, 'ndim', -1) == 0
for o in node.outputs])):
# If multiple outputs, they must all have the same size,
# so only check the first.
getattr(node.outputs[0].type, 'ndim', -1) == 0):
continue
baseline = op.inplace_pattern
protected_inputs = [
f.protected for f in node.fgraph._features if
isinstance(f, theano.compile.function_module.Supervisor)]
protected_inputs = sum(protected_inputs, []) # flatten the list
protected_inputs.extend(fgraph.outputs)
candidate_outputs = [i for i in xrange(len(node.outputs))
if i not in baseline]
# node inputs that are Constant, already destroyed,
# fgraph protected inputs and fgraph outputs can't be used as inplace
# target.
# Remove here as faster.
candidate_inputs = [i for i in xrange(len(node.inputs))
if i not in baseline.values() and
not isinstance(node.inputs[i], Constant) and
not fgraph.destroyers(node.inputs[i]) and
node.inputs[i] not in protected_inputs]
if op.inplace_pattern:
baseline = op.inplace_pattern
candidate_outputs = [i for i in xrange(len(node.outputs))
if i not in baseline]
# node inputs that are Constant, already destroyed,
# or fgraph protected inputs and fgraph outputs can't be used as
# inplace target.
# Remove here as faster.
candidate_inputs = [i for i in xrange(len(node.inputs))
if i not in baseline.values() and
not isinstance(node.inputs[i], Constant) and
# Is next line costly?
not fgraph.destroyers(node.inputs[i]) and
node.inputs[i] not in protected_inputs]
else:
baseline = []
candidate_outputs = list(range(len(node.outputs)))
# node inputs that are Constant, already destroyed,
# fgraph protected inputs and fgraph outputs can't be used as inplace
# target.
# Remove here as faster.
candidate_inputs = [i for i in xrange(len(node.inputs))
if not isinstance(node.inputs[i], Constant) and
not fgraph.destroyers(node.inputs[i]) and
node.inputs[i] not in protected_inputs]
verbose = False
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论