提交 952b05e5 authored 作者: --global's avatar --global

Prevent only {m,s,n}itsots outputs from being computed through inplace ops

上级 db41bffb
......@@ -68,6 +68,7 @@ from six.moves import xrange
import theano
from theano.compat import exc_message
from theano.compile import function, In, Param, Out
from theano.compile.mode import AddNoOutputFromInplace
from theano import compile, config, gradient, gof, tensor
from theano.gof import PureOp, Apply
from theano.gof.graph import io_connection_pattern
......@@ -811,13 +812,15 @@ class Scan(PureOp):
self.mitmots_preallocated = [i in preallocated_mitmot_outs
for i in range(self.n_mit_mot_outs)]
"""
wrapped_inputs = [Param(x, borrow=False) for x in
self.inputs]
wrapped_outputs = [Out(x, borrow=True) for x in
self.outputs[:slices]]
wrapped_outputs += self.outputs[slices:]
"""
# Add an optimization to the compilation mode to prevent mitsot,
# sitsot and nitsot outputs from being computed inplace (to allow
# their preallocation)
mitsot_start = self.n_mit_mot_outs - len(preallocated_mitmot_outs)
nitsot_end = (mitsot_start + self.n_mit_sot + self.n_sit_sot +
self.n_nit_sot)
no_inplace_opt = AddNoOutputFromInplace(0, 4)
compilation_mode = self.mode_instance.register(no_inplace_opt)
else:
# Output preallocation is not activated. Mark every mitmot output
# tap as not being preallocated
......@@ -829,6 +832,8 @@ class Scan(PureOp):
self.outputs[:slices]]
wrapped_outputs += self.outputs[slices:]
compilation_mode = self.mode_instance
profile = None
if (theano.config.profile or
(isinstance(self.profile, (string_types, bool, int))
......@@ -844,7 +849,7 @@ class Scan(PureOp):
if not getattr(self, 'fn', None):
self.fn = function(wrapped_inputs,
wrapped_outputs,
mode=self.mode_instance,
mode=compilation_mode,
name=self.name,
profile=profile,
on_unused_input='ignore')
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论