提交 918c0106 authored 作者: carriepl's avatar carriepl

Code refactoring

上级 92462fe8
...@@ -291,6 +291,11 @@ def inplace_elemwise_optimizer_op(OP): ...@@ -291,6 +291,11 @@ def inplace_elemwise_optimizer_op(OP):
nb_change_no_validate = 0 nb_change_no_validate = 0
chk = fgraph.checkpoint() chk = fgraph.checkpoint()
if fgraph.update_mapping:
update_outs = [fgraph.outputs[i] for i in fgraph.update_mapping]
else:
update_outs = []
for node in list(graph.io_toposort(fgraph.inputs, fgraph.outputs)): for node in list(graph.io_toposort(fgraph.inputs, fgraph.outputs)):
op = node.op op = node.op
# gpuarray GpuElemwise inherit from Elemwise # gpuarray GpuElemwise inherit from Elemwise
...@@ -319,7 +324,6 @@ def inplace_elemwise_optimizer_op(OP): ...@@ -319,7 +324,6 @@ def inplace_elemwise_optimizer_op(OP):
raised_warning = not verbose raised_warning = not verbose
update_outs = [fgraph.outputs[i] for i in fgraph.update_mapping]
for candidate_output in candidate_outputs: for candidate_output in candidate_outputs:
# If the output of the node can be established as an update # If the output of the node can be established as an update
...@@ -328,49 +332,47 @@ def inplace_elemwise_optimizer_op(OP): ...@@ -328,49 +332,47 @@ def inplace_elemwise_optimizer_op(OP):
# inplace on the input it's meant to update # inplace on the input it's meant to update
candidate_out_var = node.outputs[candidate_output] candidate_out_var = node.outputs[candidate_output]
sorted_candidate_inputs = candidate_inputs sorted_candidate_inputs = candidate_inputs
if fgraph.update_mapping:
if candidate_out_var in update_outs:
if candidate_out_var in update_outs:
# The candidate output is an update. Sort the
# The candidate output is an update. Sort the # variables in candidate_inputs in the following order:
# variables in candidate_inputs in the following order: # - Vars corresponding to the actual updated input
# - Vars corresponding to the actual updated input # (best case scenario is for the node that procudes
# (best case scenario is for the node that procudes # an update to operate inplace on the variable to
# an update to operate inplace on the variable to # update)
# update) # - Vars computed inplace on the updates input (second
# - Vars computed inplace on the updates input (second # best scenario if for the node to work inplace on
# best scenario if for the node to work inplace on # a variable obtained by a chain of inplace on the
# a variable obtained by a chain of inplace on the # variable to update. In some cases, this will be
# variable to update. In some cases, this will be # equivalent to operating inplace on the variable to
# equivalent to operating inplace on the variable to # update)
# update) # - Remaining variables
# - Remaining variables updated_inputs = []
updated_inputs = [] for i, f_out in enumerate(fgraph.outputs):
for i, f_out in enumerate(fgraph.outputs): if (f_out is candidate_out_var and i in fgraph.update_mapping):
if (f_out is candidate_out_var and i in fgraph.update_mapping): updated_inp_idx = fgraph.update_mapping[i]
updated_inp_idx = fgraph.update_mapping[i] updated_inputs.append(fgraph.inputs[updated_inp_idx])
updated_inputs.append(fgraph.inputs[updated_inp_idx])
updated_vars = []
updated_vars = [] vars_from_inplace = []
vars_from_inplace = [] other_vars = []
other_vars = [] for inp_idx in candidate_inputs:
for inp_idx in candidate_inputs: inp = node.inputs[inp_idx]
inp = node.inputs[inp_idx] if inp in updated_inputs:
if inp in updated_inputs: updated_vars.append(inp_idx)
updated_vars.append(inp_idx) elif (hasattr(fgraph, 'destroy_handler') and
elif (hasattr(fgraph, 'destroy_handler') and inp.owner and
inp.owner and any([(up_inp in fgraph.destroy_handler.root_destroyer and
any([(up_inp in fgraph.destroy_handler.root_destroyer and fgraph.destroy_handler.root_destroyer[up_inp] is inp.owner)
fgraph.destroy_handler.root_destroyer[up_inp] is inp.owner) for up_inp in updated_inputs])):
for up_inp in updated_inputs])):
vars_from_inplace.append(inp_idx)
vars_from_inplace.append(inp_idx) else:
else: other_vars.append(inp_idx)
other_vars.append(inp_idx)
sorted_candidate_inputs = (updated_vars +
sorted_candidate_inputs = (updated_vars + vars_from_inplace + other_vars)
vars_from_inplace +
other_vars)
for candidate_input in sorted_candidate_inputs: for candidate_input in sorted_candidate_inputs:
# remove inputs that don't have the same dtype as the output # remove inputs that don't have the same dtype as the output
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论