提交 94366583 authored 作者: --global's avatar --global

Alter scan.make_thunk to handle preallocation of mitmot taps that are both inputs and outputs

上级 1e1f5426
......@@ -56,6 +56,7 @@ __authors__ = ("Razvan Pascanu "
__copyright__ = "(c) 2010, Universite de Montreal"
__contact__ = "Razvan Pascanu <r.pascanu@gmail>"
import copy
import itertools
import logging
import time
......@@ -66,7 +67,7 @@ from six.moves import xrange
import theano
from theano.compat import exc_message
from theano.compile import function, Param, Out
from theano.compile import function, In, Param, Out
from theano import compile, config, gradient, gof, tensor
from theano.gof import PureOp, Apply
from theano.gof.graph import io_connection_pattern
......@@ -746,17 +747,87 @@ class Scan(PureOp):
self.n_mit_sot +
self.n_sit_sot +
self.n_nit_sot)
if theano.config.scan.allow_output_prealloc:
# Go through the mitmots. Whenever a mitmot has a tap both as an
# input and an output, do the following :
# - Wrap the input such that the corresponding output variable
# becomes an update to be performed on it, possibly inplace,
# at the end of the function's execution.
# - Remove the corresponding output
# Also keep track of the updated list of output taps for mitmots
wrapped_inputs = []
new_outputs = [x for x in self.outputs]
useless_outputs = []
new_mit_mot_out_slices = copy.deepcopy(self.mit_mot_out_slices)
input_idx = 0
for mitmot_idx in range(self.n_mit_mot):
for inp_tap in self.tap_array[mitmot_idx]:
if inp_tap in self.mit_mot_out_slices[mitmot_idx]:
# Figure out the index of the corresponding output
output_idx = sum([len(m) for m in
self.mit_mot_out_slices[:mitmot_idx]])
output_idx += self.mit_mot_out_slices[mitmot_idx].index(inp_tap)
# Make it so the input is automatically updated to the
# output value, possibly inplace, at the end of the
# function execution and mark the output for deletion
wrapped_inp = In(variable=self.inputs[input_idx],
update=self.outputs[output_idx],
borrow=False)
wrapped_inputs.append(wrapped_inp)
useless_outputs.append(output_idx)
new_mit_mot_out_slices[mitmot_idx].remove(inp_tap)
else:
# Wrap the corresponding input as usual. Leave the
# output as-is.
wrapped_inputs.append(In(self.inputs[input_idx],
borrow=False))
input_idx += 1
# Wrap the inputs not associated to mitmots and wrap the remaining
# outputs
wrapped_inputs += [In(x, borrow=False) for x in
self.inputs[input_idx:]]
wrapped_outputs = [Out(x, borrow=True) for x in
new_outputs[:slices]]
wrapped_outputs += new_outputs[slices:]
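# Delete the outputs that are not needed anymore (start from
# the last so as not to alter the position of other outputs that
# need to be deleted).
# NOTE(review): this relies on useless_outputs being in ascending
# order; indices are appended in tap_array order but computed from
# mit_mot_out_slices order -- confirm, or iterate over
# sorted(useless_outputs, reverse=True) instead.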
for out_idx in useless_outputs[::-1]:
del wrapped_outputs[out_idx]
# Store the list of mitmot output taps that the compiled thunk
# actually uses.
self.thunk_mit_mot_out_slices = new_mit_mot_out_slices
self.thunk_mit_mot_outs = sum([len(m)
for m in new_mit_mot_out_slices])
"""
wrapped_inputs = [Param(x, borrow=False) for x in
self.inputs]
wrapped_outputs = [Out(x, borrow=True) for x in
self.outputs[:slices]]
wrapped_outputs += self.outputs[slices:]
"""
else:
# Without output preallocation, there is no manipulation of the
# mitmot output taps. Hence, the output taps used by the compiled
# thunk are the same as self.mit_mot_out_slices
self.thunk_mit_mot_out_slices = self.mit_mot_out_slices
self.thunk_mit_mot_outs = self.n_mit_mot_outs
wrapped_inputs = [Param(x, borrow=True) for x in
self.inputs]
wrapped_outputs = [Out(x, borrow=False) for x in
self.outputs[:slices]]
wrapped_outputs += self.outputs[slices:]
wrapped_outputs += self.outputs[slices:]
profile = None
if (theano.config.profile or
(isinstance(self.profile, (string_types, bool, int))
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论