提交 994f4e91 authored 作者: Pascal Lamblin's avatar Pascal Lamblin 提交者: GitHub

Merge pull request #5058 from xoltar/careful-insert_deepcopy

Tuned insert_deepcopy
...@@ -1044,7 +1044,6 @@ copyreg.pickle(Function, _pickle_Function) ...@@ -1044,7 +1044,6 @@ copyreg.pickle(Function, _pickle_Function)
### ###
# FunctionMaker # FunctionMaker
### ###
def insert_deepcopy(fgraph, wrapped_inputs, wrapped_outputs): def insert_deepcopy(fgraph, wrapped_inputs, wrapped_outputs):
""" """
Insert deepcopy in the fgraph to break aliasing of outputs Insert deepcopy in the fgraph to break aliasing of outputs
...@@ -1060,17 +1059,18 @@ def insert_deepcopy(fgraph, wrapped_inputs, wrapped_outputs): ...@@ -1060,17 +1059,18 @@ def insert_deepcopy(fgraph, wrapped_inputs, wrapped_outputs):
# memory contract # memory contract
# We don't insert deep copy when the output.borrow is True for all # We don't insert deep copy when the output.borrow is True for all
# conserned outputs. # concerned outputs.
assert len(wrapped_inputs) == len(fgraph.inputs) assert len(wrapped_inputs) == len(fgraph.inputs)
assert len(wrapped_outputs) == len(fgraph.outputs) assert len(wrapped_outputs) == len(fgraph.outputs)
reason = "insert_deepcopy" reason = "insert_deepcopy"
updated_fgraph_inputs = [fgraph_i for i, fgraph_i in updated_fgraph_inputs = set([fgraph_i for i, fgraph_i in
zip(wrapped_inputs, fgraph.inputs) zip(wrapped_inputs, fgraph.inputs)
if getattr(i, 'update', False)] if getattr(i, 'update', False)])
# We can't use fgraph.inputs as this don't include Constant Value. # We can't use fgraph.inputs as this don't include Constant Value.
all_graph_inputs = gof.graph.inputs(fgraph.outputs) all_graph_inputs = gof.graph.inputs(fgraph.outputs)
has_destroyers = hasattr(fgraph, 'get_destroyers_of')
for i in xrange(len(fgraph.outputs)): for i in xrange(len(fgraph.outputs)):
views_of_output_i = set() views_of_output_i = set()
...@@ -1099,12 +1099,9 @@ def insert_deepcopy(fgraph, wrapped_inputs, wrapped_outputs): ...@@ -1099,12 +1099,9 @@ def insert_deepcopy(fgraph, wrapped_inputs, wrapped_outputs):
# e.g. in-place computations # e.g. in-place computations
# b) that j'th input is a shared variable that is also # b) that j'th input is a shared variable that is also
# being updated # being updated
if (hasattr(fgraph, 'get_destroyers_of') and
fgraph.get_destroyers_of(input_j)):
continue
if input_j in updated_fgraph_inputs: if input_j in updated_fgraph_inputs:
continue continue
if input_j in views_of_output_i: if input_j in views_of_output_i and not (has_destroyers and fgraph.get_destroyers_of(input_j)):
# We don't put deep_copy_op if the input and the # We don't put deep_copy_op if the input and the
# output have borrow==True # output have borrow==True
if input_j in fgraph.inputs: if input_j in fgraph.inputs:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论