提交 6b4b7cd3 authored 作者: --global's avatar --global

Correct inner function compilation to allow mitmot preallocation

上级 366bedec
......@@ -323,10 +323,9 @@ class Scan(PureOp):
if (hasattr(self, 'fn') and
not hasattr(self, 'thunk_mit_mot_out_slices')):
# The thunk has been compiled before mit_mot preallocation feature
# was implemented. Set up the value of self.thunk_mit_mot_outs and
# self.thunk_mit_mot_out_slices to reflect this.
self.thunk_mit_mot_out_slices = self.mit_mot_out_slices
self.thunk_mit_mot_outs = self.n_mit_mot_outs
# was implemented. Mark every mit_mot output tap as not having
# been preallocated
self.mitmots_preallocated = [False] * self.n_mit_mot_outs
# Ensure that the graph associated with the inner function is valid.
self.validate_inner_graph()
......@@ -758,18 +757,16 @@ class Scan(PureOp):
if theano.config.scan.allow_output_prealloc:
# Go through the mitmots. Whenever a mitmot has a tap both as an
# input and an output, do the following :
# - Wrap the input such that the corresponding output variable
# becomes an update to be performed on it, possibly inplace,
# at the end of the function's execution.
# - Remove the corresponding output
# Also keep track of the updated list of output taps for mitmots
wrapped_inputs = []
# input and an output, wrap the input such that the corresponding
# output variable becomes an update to be performed on it, possibly
# inplace, at the end of the function's execution.
wrapped_inputs = [In(x, borrow=False)
for x in self.inputs[:self.n_seqs]]
new_outputs = [x for x in self.outputs]
useless_outputs = []
preallocated_outputs = []
new_mit_mot_out_slices = copy.deepcopy(self.mit_mot_out_slices)
input_idx = 0
input_idx = self.n_seqs
for mitmot_idx in range(self.n_mit_mot):
for inp_tap in self.tap_array[mitmot_idx]:
if inp_tap in self.mit_mot_out_slices[mitmot_idx]:
......@@ -781,12 +778,12 @@ class Scan(PureOp):
# Make it so the input is automatically updated to the
# output value, possibly inplace, at the end of the
# function execution and mark the output for deletion
# function execution
wrapped_inp = In(variable=self.inputs[input_idx],
update=self.outputs[output_idx],
borrow=False)
mutable=False, borrow=True)
wrapped_inputs.append(wrapped_inp)
useless_outputs.append(output_idx)
preallocated_outputs.append(output_idx)
new_mit_mot_out_slices[mitmot_idx].remove(inp_tap)
else:
# Wrap the corresponding input as usual. Leave the
......@@ -803,17 +800,17 @@ class Scan(PureOp):
new_outputs[:slices]]
wrapped_outputs += new_outputs[slices:]
# Delete the outputs that are not needed anymore (start from
# the last so as not to alter the position of other outputs that
# need to be deleted)
for out_idx in useless_outputs[::-1]:
del wrapped_outputs[out_idx]
# Remove now useless outputs from the output list (start from the
# end to avoid altering the indices of the other outputs to be
# deleted).
preallocated_outputs.sort()
for p in preallocated_outputs[::-1]:
del wrapped_outputs[p]
# Store the list of mitmot output taps that the compiled thunk
# actually uses.
self.thunk_mit_mot_out_slices = new_mit_mot_out_slices
self.thunk_mit_mot_outs = sum([len(m)
for m in new_mit_mot_out_slices])
# Store the list of mitmot output taps that have been altered
# so they can be preallocated
self.mitmots_preallocated = [i in preallocated_outputs
for i in range(self.n_mit_mot_outs)]
"""
wrapped_inputs = [Param(x, borrow=False) for x in
......@@ -823,11 +820,9 @@ class Scan(PureOp):
wrapped_outputs += self.outputs[slices:]
"""
else:
# Without output preallocation, there is no manipulation of the
# mitmot output taps. Hence, the output taps used by the compiled
# thunk are the same as self.mit_mot_out_slices
self.thunk_mit_mot_out_slices = self.mit_mot_out_slices
self.thunk_mit_mot_outs = self.n_mit_mot_outs
# Output preallocation is not activated. Mark every mitmot output
# tap as not being preallocated
self.mitmots_preallocated = [False] * self.n_mit_mot_outs
wrapped_inputs = [Param(x, borrow=True) for x in
self.inputs]
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论