提交 185f40d8 authored 作者: --global's avatar --global

Implement input preference when making elemwise ops inplace

上级 429b4469
......@@ -320,7 +320,50 @@ def inplace_elemwise_optimizer_op(OP):
raised_warning = not verbose
for candidate_output in candidate_outputs:
for candidate_input in candidate_inputs:
# If the output of the node can be established as an update
# output of the fgraph, visit the candidate_inputs in an order
# that will improve the chances of making the node operate
# inplace on the input it's meant to update
candidate_out_var = node.outputs[candidate_output]
sorted_candidate_inputs = candidate_inputs
if fgraph.update_mapping:
update_outs = [fgraph.outputs[i]
for i in fgraph.update_mapping]
if candidate_out_var in update_outs:
# The candidate output is an update. Sort the
# variables in candidate_inputs in the following order:
# - Vars corresponding to the actual updated input
# - Vars computed inplace on the updated input
# - Remaining variables
fgraph_out_idx = fgraph.outputs.index(candidate_out_var)
updated_inp_idx = fgraph.update_mapping[fgraph_out_idx]
updated_inp = fgraph.inputs[updated_inp_idx]
updated_vars = []
vars_from_inplace = []
other_vars = []
for inp_idx in candidate_inputs:
inp = node.inputs[inp_idx]
if inp is updated_inp:
updated_vars.append(inp_idx)
elif (hasattr(fgraph, 'destroy_handler') and
inp.owner and
updated_inp in fgraph.destroy_handler.root_destroyer and
fgraph.destroy_handler.root_destroyer[updated_inp] == inp.owner):
vars_from_inplace.append(inp_idx)
else:
other_vars.append(inp_idx)
sorted_candidate_inputs = (updated_vars +
vars_from_inplace +
other_vars)
for candidate_input in sorted_candidate_inputs:
# remove inputs that don't have the same dtype as the output
if node.inputs[candidate_input].type != node.outputs[
candidate_output].type:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论