提交 08ffec41 authored 作者: carriepl's avatar carriepl

Generalize AddNoOutputFromInplace to AddFeatureOptimizer

上级 cc2abb87
......@@ -161,22 +161,16 @@ class AddDestroyHandler(gof.Optimizer):
fgraph.attach_feature(gof.DestroyHandler())
class AddNoOutputFromInplace(gof.Optimizer):
class AddFeatureOptimizer(gof.Optimizer):
"""
This optimizer adds to the fgraph a feature that will prevent outputs
of a fgraph to be created by performing inplace operations on intermediary
variables. This is useful when the outputs of the fgraph are preallocated
to prevent useless copying of the data. Currently, scan preallocates its
outputs
This optimizer adds a provided feature to the function graph.
"""
def __init__(self, first_output_idx=0, last_output_idx=None):
self.feature = gof.NoOutputFromInplace(first_output_idx,
last_output_idx)
def __init__(self, feature):
self.feature = feature
def add_requirements(self, fgraph):
super(AddNoOutputFromInplace, self).add_requirements(fgraph)
super(AddFeatureOptimizer, self).add_requirements(fgraph)
fgraph.attach_feature(self.feature)
......@@ -234,9 +228,6 @@ optdb.register('specialize_device', gof.EquilibriumDB(),
optdb.register('merge2', gof.MergeOptimizer(),
49, 'fast_run', 'merge')
optdb.register('add_no_output_from_inplace', AddNoOutputFromInplace(),
49.4)
optdb.register('add_destroy_handler', AddDestroyHandler(),
49.5, 'fast_run', 'inplace')
......
......@@ -68,10 +68,11 @@ from six.moves import xrange
import theano
from theano.compat import exc_message
from theano.compile import function, In, Param, Out
from theano.compile.mode import AddNoOutputFromInplace
from theano.compile.mode import AddFeatureOptimizer
from theano import compile, config, gradient, gof, tensor
from theano.gof import PureOp, Apply
from theano.gof.graph import io_connection_pattern
from theano.gof.toolbox import NoOutputFromInplace
from theano.compat import OrderedDict, izip
from theano.tensor import TensorType
from theano.tensor.opt import Shape_i
......@@ -802,16 +803,18 @@ class Scan(PureOp):
self.mitmots_preallocated = [i in preallocated_mitmot_outs
for i in range(self.n_mit_mot_outs)]
# Add an optimization to the compilation mode to prevent mitsot,
# sitsot and nitsot outputs from being computed inplace (to allow
# their preallocation). This optimization is added such that it
# will run just before the inplace optimizations
# Add an optimization to the compilation mode to attach a feature
# to the function graph just before the inplace optimizations are
# applied. This feature will prevent mitsot, sitsot and nitsot
# outputs from being computed inplace (to allow their
# preallocation).
mitsot_start = self.n_mit_mot_outs - len(preallocated_mitmot_outs)
nitsot_end = (mitsot_start + self.n_mit_sot + self.n_sit_sot +
self.n_nit_sot)
no_inplace_opt = AddNoOutputFromInplace(mitsot_start, nitsot_end)
compilation_mode = self.mode_instance.register((no_inplace_opt,
49.9))
feature = NoOutputFromInplace(mitsot_start, nitsot_end)
opt = AddFeatureOptimizer(feature)
compilation_mode = self.mode_instance.register((opt, 49.9))
else:
# Output preallocation is not activated. Mark every mitmot output
......@@ -1691,7 +1694,7 @@ class Scan(PureOp):
return connection_pattern
def get_oinp_iinp_iout_oout_mappings(self):
"""
"""
Compute and return dictionary mappings between the inputs and
outputs of the inner function and the inputs and outputs of the Scan
node in the outer graph.
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论