提交 9ad8e925 authored 作者: Reyhane Askari's avatar Reyhane Askari

replaced get_destroyers and fgraph.destroyers with has_destroyers

上级 4a3e0e49
...@@ -195,7 +195,7 @@ def std_fgraph(input_specs, output_specs, accept_inplace=False): ...@@ -195,7 +195,7 @@ def std_fgraph(input_specs, output_specs, accept_inplace=False):
for spec, input in zip(input_specs, fgraph.inputs) for spec, input in zip(input_specs, fgraph.inputs)
if not (spec.mutable or if not (spec.mutable or
(hasattr(fgraph, 'destroyers') and (hasattr(fgraph, 'destroyers') and
fgraph.destroyers(input))))) fgraph.has_destroyers([input])))))
# If named nodes are replaced, keep the name # If named nodes are replaced, keep the name
for feature in std_fgraph.features: for feature in std_fgraph.features:
...@@ -1090,7 +1090,7 @@ def insert_deepcopy(fgraph, wrapped_inputs, wrapped_outputs): ...@@ -1090,7 +1090,7 @@ def insert_deepcopy(fgraph, wrapped_inputs, wrapped_outputs):
# We can't use fgraph.inputs as this don't include Constant Value. # We can't use fgraph.inputs as this don't include Constant Value.
all_graph_inputs = gof.graph.inputs(fgraph.outputs) all_graph_inputs = gof.graph.inputs(fgraph.outputs)
has_get_destroyers = hasattr(fgraph, 'get_destroyers_of') has_destroyers_attr = hasattr(fgraph, 'has_destroyers')
for i in xrange(len(fgraph.outputs)): for i in xrange(len(fgraph.outputs)):
views_of_output_i = set() views_of_output_i = set()
...@@ -1121,7 +1121,7 @@ def insert_deepcopy(fgraph, wrapped_inputs, wrapped_outputs): ...@@ -1121,7 +1121,7 @@ def insert_deepcopy(fgraph, wrapped_inputs, wrapped_outputs):
# being updated # being updated
if input_j in updated_fgraph_inputs: if input_j in updated_fgraph_inputs:
continue continue
if input_j in views_of_output_i and not (has_get_destroyers and fgraph.get_destroyers_of(input_j)): if input_j in views_of_output_i and not (has_destroyers_attr and fgraph.has_destroyers([input_j])):
# We don't put deep_copy_op if the input and the # We don't put deep_copy_op if the input and the
# output have borrow==True # output have borrow==True
if input_j in fgraph.inputs: if input_j in fgraph.inputs:
......
...@@ -250,7 +250,7 @@ def fast_inplace_check(inputs): ...@@ -250,7 +250,7 @@ def fast_inplace_check(inputs):
inputs = [i for i in inputs if inputs = [i for i in inputs if
not isinstance(i, graph.Constant) and not isinstance(i, graph.Constant) and
not fgraph.destroyers(i) and not fgraph.has_destroyers([i]) and
i not in protected_inputs] i not in protected_inputs]
return inputs return inputs
......
...@@ -265,8 +265,8 @@ class InplaceElemwiseOptimizer(Optimizer): ...@@ -265,8 +265,8 @@ class InplaceElemwiseOptimizer(Optimizer):
candidate_inputs = [i for i in xrange(len(node.inputs)) candidate_inputs = [i for i in xrange(len(node.inputs))
if i not in baseline.values() and if i not in baseline.values() and
not isinstance(node.inputs[i], Constant) and not isinstance(node.inputs[i], Constant) and
# Is next line costly? # Is next line costly? (used to be fgraph.get_destroyers)
not fgraph.destroyers(node.inputs[i]) and not fgraph.has_destroyers([node.inputs[i]]) and
node.inputs[i] not in protected_inputs] node.inputs[i] not in protected_inputs]
else: else:
baseline = [] baseline = []
...@@ -277,7 +277,7 @@ class InplaceElemwiseOptimizer(Optimizer): ...@@ -277,7 +277,7 @@ class InplaceElemwiseOptimizer(Optimizer):
# Remove here as faster. # Remove here as faster.
candidate_inputs = [i for i in xrange(len(node.inputs)) candidate_inputs = [i for i in xrange(len(node.inputs))
if not isinstance(node.inputs[i], Constant) and if not isinstance(node.inputs[i], Constant) and
not fgraph.destroyers(node.inputs[i]) and not fgraph.has_destroyers([node.inputs[i]]) and
node.inputs[i] not in protected_inputs] node.inputs[i] not in protected_inputs]
verbose = False verbose = False
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论