提交 918c0106 authored 作者: carriepl's avatar carriepl

Code refactoring

上级 92462fe8
......@@ -291,6 +291,11 @@ def inplace_elemwise_optimizer_op(OP):
nb_change_no_validate = 0
chk = fgraph.checkpoint()
if fgraph.update_mapping:
update_outs = [fgraph.outputs[i] for i in fgraph.update_mapping]
else:
update_outs = []
for node in list(graph.io_toposort(fgraph.inputs, fgraph.outputs)):
op = node.op
# gpuarray GpuElemwise inherit from Elemwise
......@@ -319,7 +324,6 @@ def inplace_elemwise_optimizer_op(OP):
raised_warning = not verbose
update_outs = [fgraph.outputs[i] for i in fgraph.update_mapping]
for candidate_output in candidate_outputs:
# If the output of the node can be established as an update
......@@ -328,7 +332,6 @@ def inplace_elemwise_optimizer_op(OP):
# inplace on the input it's meant to update
candidate_out_var = node.outputs[candidate_output]
sorted_candidate_inputs = candidate_inputs
if fgraph.update_mapping:
if candidate_out_var in update_outs:
......@@ -369,8 +372,7 @@ def inplace_elemwise_optimizer_op(OP):
other_vars.append(inp_idx)
sorted_candidate_inputs = (updated_vars +
vars_from_inplace +
other_vars)
vars_from_inplace + other_vars)
for candidate_input in sorted_candidate_inputs:
# remove inputs that don't have the same dtype as the output
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论