提交 952b05e5 authored 作者: --global's avatar --global

Prevent only {m,s,n}itsots outputs from being computed through inplace ops

上级 db41bffb
...@@ -68,6 +68,7 @@ from six.moves import xrange ...@@ -68,6 +68,7 @@ from six.moves import xrange
import theano import theano
from theano.compat import exc_message from theano.compat import exc_message
from theano.compile import function, In, Param, Out from theano.compile import function, In, Param, Out
from theano.compile.mode import AddNoOutputFromInplace
from theano import compile, config, gradient, gof, tensor from theano import compile, config, gradient, gof, tensor
from theano.gof import PureOp, Apply from theano.gof import PureOp, Apply
from theano.gof.graph import io_connection_pattern from theano.gof.graph import io_connection_pattern
...@@ -811,13 +812,15 @@ class Scan(PureOp): ...@@ -811,13 +812,15 @@ class Scan(PureOp):
self.mitmots_preallocated = [i in preallocated_mitmot_outs self.mitmots_preallocated = [i in preallocated_mitmot_outs
for i in range(self.n_mit_mot_outs)] for i in range(self.n_mit_mot_outs)]
""" # Add an optimization to the compilation mode to prevent mitsot,
wrapped_inputs = [Param(x, borrow=False) for x in # sitsot and nitsot outputs from being computed inplace (to allow
self.inputs] # their preallocation)
wrapped_outputs = [Out(x, borrow=True) for x in mitsot_start = self.n_mit_mot_outs - len(preallocated_mitmot_outs)
self.outputs[:slices]] nitsot_end = (mitsot_start + self.n_mit_sot + self.n_sit_sot +
wrapped_outputs += self.outputs[slices:] self.n_nit_sot)
""" no_inplace_opt = AddNoOutputFromInplace(0, 4)
compilation_mode = self.mode_instance.register(no_inplace_opt)
else: else:
# Output preallocation is not activated. Mark every mitmot output # Output preallocation is not activated. Mark every mitmot output
# tap as not being preallocated # tap as not being preallocated
...@@ -829,6 +832,8 @@ class Scan(PureOp): ...@@ -829,6 +832,8 @@ class Scan(PureOp):
self.outputs[:slices]] self.outputs[:slices]]
wrapped_outputs += self.outputs[slices:] wrapped_outputs += self.outputs[slices:]
compilation_mode = self.mode_instance
profile = None profile = None
if (theano.config.profile or if (theano.config.profile or
(isinstance(self.profile, (string_types, bool, int)) (isinstance(self.profile, (string_types, bool, int))
...@@ -844,7 +849,7 @@ class Scan(PureOp): ...@@ -844,7 +849,7 @@ class Scan(PureOp):
if not getattr(self, 'fn', None): if not getattr(self, 'fn', None):
self.fn = function(wrapped_inputs, self.fn = function(wrapped_inputs,
wrapped_outputs, wrapped_outputs,
mode=self.mode_instance, mode=compilation_mode,
name=self.name, name=self.name,
profile=profile, profile=profile,
on_unused_input='ignore') on_unused_input='ignore')
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论