提交 6b4b7cd3 authored 作者: --global's avatar --global

Correct inner function compilation to allow mitmot preallocation

上级 366bedec
...@@ -323,10 +323,9 @@ class Scan(PureOp): ...@@ -323,10 +323,9 @@ class Scan(PureOp):
if (hasattr(self, 'fn') and if (hasattr(self, 'fn') and
not hasattr(self, 'thunk_mit_mot_out_slices')): not hasattr(self, 'thunk_mit_mot_out_slices')):
# The thunk has been compiled before mit_mot preallocation feature # The thunk has been compiled before mit_mot preallocation feature
# was implemented. Set up the value of self.thunk_mit_mot_outs and # was implemented. Mark every mit_mot output tap as not having
# self.thunk_mit_mot_out_slices to reflect this. # been preallocated
self.thunk_mit_mot_out_slices = self.mit_mot_out_slices self.mitmots_preallocated = [False] * self.n_mit_mot_outs
self.thunk_mit_mot_outs = self.n_mit_mot_outs
# Ensure that the graph associated with the inner function is valid. # Ensure that the graph associated with the inner function is valid.
self.validate_inner_graph() self.validate_inner_graph()
...@@ -758,18 +757,16 @@ class Scan(PureOp): ...@@ -758,18 +757,16 @@ class Scan(PureOp):
if theano.config.scan.allow_output_prealloc: if theano.config.scan.allow_output_prealloc:
# Go through the mitmots. Whenever a mitmot has a tap both as an # Go through the mitmots. Whenever a mitmot has a tap both as an
# input and an output, do the following : # input and an output, wrap the input such that the corresponding
# - Wrap the input such that the corresponding output variable # output variable becomes an update to be performed on it, possibly
# becomes an update to be performed on it, possibly inplace, # inplace at the end of the function's execution.
# at the end of the function's execution. wrapped_inputs = [In(x, borrow=False)
# - Remove the corresponding output for x in self.inputs[:self.n_seqs]]
# Also keep track of the updated list of output taps for mitmots
wrapped_inputs = []
new_outputs = [x for x in self.outputs] new_outputs = [x for x in self.outputs]
useless_outputs = [] preallocated_outputs = []
new_mit_mot_out_slices = copy.deepcopy(self.mit_mot_out_slices) new_mit_mot_out_slices = copy.deepcopy(self.mit_mot_out_slices)
input_idx = 0 input_idx = self.n_seqs
for mitmot_idx in range(self.n_mit_mot): for mitmot_idx in range(self.n_mit_mot):
for inp_tap in self.tap_array[mitmot_idx]: for inp_tap in self.tap_array[mitmot_idx]:
if inp_tap in self.mit_mot_out_slices[mitmot_idx]: if inp_tap in self.mit_mot_out_slices[mitmot_idx]:
...@@ -781,12 +778,12 @@ class Scan(PureOp): ...@@ -781,12 +778,12 @@ class Scan(PureOp):
# Make it so the input is automatically updated to the # Make it so the input is automatically updated to the
# output value, possibly inplace, at the end of the # output value, possibly inplace, at the end of the
# function execution and mark the output for deletion # function execution
wrapped_inp = In(variable=self.inputs[input_idx], wrapped_inp = In(variable=self.inputs[input_idx],
update=self.outputs[output_idx], update=self.outputs[output_idx],
borrow=False) mutable=False, borrow=True)
wrapped_inputs.append(wrapped_inp) wrapped_inputs.append(wrapped_inp)
useless_outputs.append(output_idx) preallocated_outputs.append(output_idx)
new_mit_mot_out_slices[mitmot_idx].remove(inp_tap) new_mit_mot_out_slices[mitmot_idx].remove(inp_tap)
else: else:
# Wrap the corresponding input as usual. Leave the # Wrap the corresponding input as usual. Leave the
...@@ -803,17 +800,17 @@ class Scan(PureOp): ...@@ -803,17 +800,17 @@ class Scan(PureOp):
new_outputs[:slices]] new_outputs[:slices]]
wrapped_outputs += new_outputs[slices:] wrapped_outputs += new_outputs[slices:]
# Delete the outputs that are not needed anymore (start from # Remove now useless outputs from the output list (start from the
# the last so as not to alter the position of other outputs that # end to avoid altering the indices of the other outputs to be
# need to be deleted) # deleted).
for out_idx in useless_outputs[::-1]: preallocated_outputs.sort()
del wrapped_outputs[out_idx] for p in preallocated_outputs[::-1]:
del wrapped_outputs[p]
# Store the list of mitmot output taps that the compiled thunk # Store the list of mitmot output taps that have been altered
# actually uses. # so they can be preallocated
self.thunk_mit_mot_out_slices = new_mit_mot_out_slices self.mitmots_preallocated = [i in preallocated_outputs
self.thunk_mit_mot_outs = sum([len(m) for i in range(self.n_mit_mot_outs)]
for m in new_mit_mot_out_slices])
""" """
wrapped_inputs = [Param(x, borrow=False) for x in wrapped_inputs = [Param(x, borrow=False) for x in
...@@ -823,11 +820,9 @@ class Scan(PureOp): ...@@ -823,11 +820,9 @@ class Scan(PureOp):
wrapped_outputs += self.outputs[slices:] wrapped_outputs += self.outputs[slices:]
""" """
else: else:
# Without output preallocation, there is no manipulation of the # Output preallocation is not activated. Mark every mitmot output
# mitmot output taps. Hence, the output taps used by the compiled # tap as not being preallocated
# thunk are the same as self.mit_mot_out_slices self.mitmots_preallocated = [False] * self.n_mit_mot_outs
self.thunk_mit_mot_out_slices = self.mit_mot_out_slices
self.thunk_mit_mot_outs = self.n_mit_mot_outs
wrapped_inputs = [Param(x, borrow=True) for x in wrapped_inputs = [Param(x, borrow=True) for x in
self.inputs] self.inputs]
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论