提交 08ffec41 authored 作者: carriepl's avatar carriepl

Generalize AddNoOutputFromInplace to AddFeatureOptimizer

上级 cc2abb87
...@@ -161,22 +161,16 @@ class AddDestroyHandler(gof.Optimizer): ...@@ -161,22 +161,16 @@ class AddDestroyHandler(gof.Optimizer):
fgraph.attach_feature(gof.DestroyHandler()) fgraph.attach_feature(gof.DestroyHandler())
class AddNoOutputFromInplace(gof.Optimizer): class AddFeatureOptimizer(gof.Optimizer):
""" """
This optimizer adds to the fgraph a feature that will prevent outputs This optimizer adds a provided feature to the function graph.
of a fgraph to be created by performing inplace operations on intermediary
variables. This is useful when the outputs of the fgraph are preallocated
to prevent useless copying of the data. Currently, scan preallocates its
outputs
""" """
def __init__(self, first_output_idx=0, last_output_idx=None): def __init__(self, feature):
self.feature = gof.NoOutputFromInplace(first_output_idx, self.feature = feature
last_output_idx)
def add_requirements(self, fgraph): def add_requirements(self, fgraph):
super(AddNoOutputFromInplace, self).add_requirements(fgraph) super(AddFeatureOptimizer, self).add_requirements(fgraph)
fgraph.attach_feature(self.feature) fgraph.attach_feature(self.feature)
...@@ -234,9 +228,6 @@ optdb.register('specialize_device', gof.EquilibriumDB(), ...@@ -234,9 +228,6 @@ optdb.register('specialize_device', gof.EquilibriumDB(),
optdb.register('merge2', gof.MergeOptimizer(), optdb.register('merge2', gof.MergeOptimizer(),
49, 'fast_run', 'merge') 49, 'fast_run', 'merge')
optdb.register('add_no_output_from_inplace', AddNoOutputFromInplace(),
49.4)
optdb.register('add_destroy_handler', AddDestroyHandler(), optdb.register('add_destroy_handler', AddDestroyHandler(),
49.5, 'fast_run', 'inplace') 49.5, 'fast_run', 'inplace')
......
...@@ -68,10 +68,11 @@ from six.moves import xrange ...@@ -68,10 +68,11 @@ from six.moves import xrange
import theano import theano
from theano.compat import exc_message from theano.compat import exc_message
from theano.compile import function, In, Param, Out from theano.compile import function, In, Param, Out
from theano.compile.mode import AddNoOutputFromInplace from theano.compile.mode import AddFeatureOptimizer
from theano import compile, config, gradient, gof, tensor from theano import compile, config, gradient, gof, tensor
from theano.gof import PureOp, Apply from theano.gof import PureOp, Apply
from theano.gof.graph import io_connection_pattern from theano.gof.graph import io_connection_pattern
from theano.gof.toolbox import NoOutputFromInplace
from theano.compat import OrderedDict, izip from theano.compat import OrderedDict, izip
from theano.tensor import TensorType from theano.tensor import TensorType
from theano.tensor.opt import Shape_i from theano.tensor.opt import Shape_i
...@@ -802,16 +803,18 @@ class Scan(PureOp): ...@@ -802,16 +803,18 @@ class Scan(PureOp):
self.mitmots_preallocated = [i in preallocated_mitmot_outs self.mitmots_preallocated = [i in preallocated_mitmot_outs
for i in range(self.n_mit_mot_outs)] for i in range(self.n_mit_mot_outs)]
# Add an optimization to the compilation mode to prevent mitsot, # Add an optimization to the compilation mode to attach a feature
# sitsot and nitsot outputs from being computed inplace (to allow # to the function graph just before the inplace optimizations are
# their preallocation). This optimization is added such that it # applied. This feature will prevent mitsot, sitsot and nitsot
# will run just before the inplace optimizations # outputs from being computed inplace (to allow their
# preallocation).
mitsot_start = self.n_mit_mot_outs - len(preallocated_mitmot_outs) mitsot_start = self.n_mit_mot_outs - len(preallocated_mitmot_outs)
nitsot_end = (mitsot_start + self.n_mit_sot + self.n_sit_sot + nitsot_end = (mitsot_start + self.n_mit_sot + self.n_sit_sot +
self.n_nit_sot) self.n_nit_sot)
no_inplace_opt = AddNoOutputFromInplace(mitsot_start, nitsot_end)
compilation_mode = self.mode_instance.register((no_inplace_opt, feature = NoOutputFromInplace(mitsot_start, nitsot_end)
49.9)) opt = AddFeatureOptimizer(feature)
compilation_mode = self.mode_instance.register((opt, 49.9))
else: else:
# Output preallocation is not activated. Mark every mitmot output # Output preallocation is not activated. Mark every mitmot output
...@@ -1691,7 +1694,7 @@ class Scan(PureOp): ...@@ -1691,7 +1694,7 @@ class Scan(PureOp):
return connection_pattern return connection_pattern
def get_oinp_iinp_iout_oout_mappings(self): def get_oinp_iinp_iout_oout_mappings(self):
""" """
Compute and return dictionary mappings between the inputs and Compute and return dictionary mappings between the inputs and
outputs of the inner function and the inputs and outputs of the Scan outputs of the inner function and the inputs and outputs of the Scan
node in the outer graph. node in the outer graph.
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论