提交 d7271b5e authored 作者: Bryn Keller's avatar Bryn Keller

Tuned insert_deepcopy

上级 2f0ab791
......@@ -1040,11 +1040,9 @@ def _constructor_Function(maker, input_storage, inputs_data):
copyreg.pickle(Function, _pickle_Function)
###
# FunctionMaker
###
def insert_deepcopy(fgraph, wrapped_inputs, wrapped_outputs):
"""
Insert deepcopy in the fgraph to break aliasing of outputs
......@@ -1060,17 +1058,18 @@ def insert_deepcopy(fgraph, wrapped_inputs, wrapped_outputs):
# memory contract
# We don't insert deep copy when the output.borrow is True for all
# conserned outputs.
# concerned outputs.
assert len(wrapped_inputs) == len(fgraph.inputs)
assert len(wrapped_outputs) == len(fgraph.outputs)
reason = "insert_deepcopy"
updated_fgraph_inputs = [fgraph_i for i, fgraph_i in
updated_fgraph_inputs = set([fgraph_i for i, fgraph_i in
zip(wrapped_inputs, fgraph.inputs)
if getattr(i, 'update', False)]
if getattr(i, 'update', False)])
# We can't use fgraph.inputs as this don't include Constant Value.
all_graph_inputs = gof.graph.inputs(fgraph.outputs)
has_destroyers = hasattr(fgraph, 'get_destroyers_of')
for i in xrange(len(fgraph.outputs)):
views_of_output_i = set()
......@@ -1099,12 +1098,9 @@ def insert_deepcopy(fgraph, wrapped_inputs, wrapped_outputs):
# e.g. in-place computations
# b) that j'th input is a shared variable that is also
# being updated
if (hasattr(fgraph, 'get_destroyers_of') and
fgraph.get_destroyers_of(input_j)):
continue
if input_j in updated_fgraph_inputs:
continue
if input_j in views_of_output_i:
if input_j in views_of_output_i and not (has_destroyers and fgraph.get_destroyers_of(input_j)):
# We don't put deep_copy_op if the input and the
# output have borrow==True
if input_j in fgraph.inputs:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论