提交 11964f0a authored 作者: --global's avatar --global

Fix and apply memory reuse in Scan's python backend

上级 9dc07802
...@@ -674,8 +674,8 @@ class Scan(PureOp): ...@@ -674,8 +674,8 @@ class Scan(PureOp):
self.n_mit_sot + self.n_mit_sot +
self.n_sit_sot + self.n_sit_sot +
self.n_nit_sot) self.n_nit_sot)
wrapped_inputs = [Param(x, borrow=True) for x in self.inputs] wrapped_inputs = [Param(x, borrow=False) for x in self.inputs]
wrapped_outputs = [Out(x, borrow=False) for x in wrapped_outputs = [Out(x, borrow=(x not in self.inputs)) for x in
self.outputs[:slices]] self.outputs[:slices]]
wrapped_outputs += self.outputs[slices:] wrapped_outputs += self.outputs[slices:]
profile = None profile = None
...@@ -1021,6 +1021,9 @@ class Scan(PureOp): ...@@ -1021,6 +1021,9 @@ class Scan(PureOp):
other_args = args[offset:] other_args = args[offset:]
input_storage = self.fn.input_storage input_storage = self.fn.input_storage
output_storage = self.fn.output_storage output_storage = self.fn.output_storage
old_output_storage = [None] * len(output_storage)
old_output_data = [None] * len(output_storage)
output_reused = [None] * len(output_storage)
fn = self.fn.fn fn = self.fn.fn
offset = (self.n_seqs + sum(map(len, self.tap_array[:self.n_outs])) + offset = (self.n_seqs + sum(map(len, self.tap_array[:self.n_outs])) +
self.n_shared_outs) self.n_shared_outs)
...@@ -1100,9 +1103,19 @@ class Scan(PureOp): ...@@ -1100,9 +1103,19 @@ class Scan(PureOp):
output_storage[pdx].storage[0] = None output_storage[pdx].storage[0] = None
# 4.5. Keep a reference to the variables currently in the # 4.5. Keep a reference to the variables currently in the
# output_storage to be able to compare them with the actual # output_storage, and their data, to be able to compare them with
# outputs of the inner function after its execution # the actual outputs of the inner function after its execution
old_output_storage = [o.storage[0] for o in output_storage] for idx in xrange(len(output_storage)):
var = output_storage[idx].storage[0]
old_output_storage[idx] = var
if hasattr(var, 'gpudata'):
old_output_data[idx] = var.gpudata
elif hasattr(var, 'data'):
old_output_data[idx] = var.data
else:
old_output_data[idx] = None
# 5. compute outputs # 5. compute outputs
t0_fn = time.time() t0_fn = time.time()
...@@ -1134,9 +1147,26 @@ class Scan(PureOp): ...@@ -1134,9 +1147,26 @@ class Scan(PureOp):
# Check which of the pre-allocated outputs (if applicable) have # Check which of the pre-allocated outputs (if applicable) have
# been reused by the inner function # been reused by the inner function
output_reused = [old_output_storage[o] is for idx in xrange(len(output_storage)):
output_storage[o].storage[0] # If the storage map does not contain the same object, then
for o in range(len(output_storage))] # the pre-allocated output has not been reused
new_var = output_storage[idx].storage[0]
if old_output_storage[idx] is new_var:
# The pre-allocated output is only considered as having
# been reused if it still points to the same data as it
# did before the execution of the inner function
if old_output_data[idx] is None:
output_reused[idx] = False
else:
if hasattr(new_var, 'gpudata'):
output_reused[idx] = (new_var.gpudata ==
old_output_data[idx])
elif hasattr(new_var, 'data'):
output_reused[idx] = (new_var.data ==
old_output_data[idx])
else:
output_reused[idx] = False
t_fn += dt_fn t_fn += dt_fn
offset_out = 0 offset_out = 0
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论