提交 5d051ce9 authored 作者: Arnaud Bergeron's avatar Arnaud Bergeron

Add a separate allow_gc flag for the Scans.

上级 26a22806
......@@ -9,6 +9,8 @@ from theano.gof import utils
from theano.gof import graph
from theano.gof.type import Type
from .utils import MethodNotDefined, undef
__excepthook = sys.excepthook
......@@ -181,6 +183,12 @@ raise_with_op.print_thunk_trace = False
class Linker(object):
"""WRITEME"""
def clone(allow_gc=undef):
new = copy(self)
if allow_gc is not undef:
new.allow_gc = allow_gc
return new
def make_thunk(self):
"""
This function must return a triplet (function, input_variables, output_variables)
......@@ -689,6 +697,11 @@ class WrapLinker(Linker):
wrapper=self.wrapper)
return other
def clone(allow_gc=undef):
return self.__class__(
linkers=[l.clone(allow_gc=allow_gc)],
wrapper=self.wrapper)
def accept(self, fgraph, no_recycling=None):
"""
@type fgraph: gof.FunctionGraph
......
......@@ -39,6 +39,11 @@ def hashtype(self):
return hash(t.__name__) ^ hash(t.__module__)
# Object to mark that a parameter is undefined (useful in cases where
# None is a valid value with defined semantics)
undef = object()
class MethodNotDefined(Exception):
"""
To be raised by functions defined as part of an interface.
......
......@@ -22,17 +22,12 @@ import numpy
import theano
from theano.compat import exc_message
from theano.compile import function, Param, Out
from theano import compile
from theano import gradient
from theano.gof.python25 import any, OrderedDict
from theano import compile, config, gradient, gof, tensor
from theano.gof import PureOp, Apply
from theano import gof
from theano.gof.python25 import any, OrderedDict
from theano.tensor import TensorType
from theano import tensor
from theano.tensor.opt import Shape_i
from theano.gradient import grad_undefined
from theano.gradient import DisconnectedType
from theano.gradient import NullType
from theano.gradient import grad_undefined, DisconnectedType, NullType
from theano.compile.profiling import ScanProfileStats
from theano.scan_module import scan_utils
......@@ -42,11 +37,18 @@ from theano.scan_module.scan_utils import safe_new, forced_replace
_logger = logging.getLogger('theano.scan_module.scan_op')
from theano.configparser import AddConfigVar, BoolParam
AddConfigVar('scan.allow_gc',
"Allow/disallow gc inside of Scan",
BoolParam(False))
class Scan(PureOp):
def __init__(self,
inputs,
outputs,
info,
allow_gc=None
):
"""
:param inputs: inputs of the inner function of scan
......@@ -55,9 +57,13 @@ class Scan(PureOp):
the scan op (like number of different types of
arguments, name, mode, if it should run on GPU or
not, etc.)
:param allow_gc: Use the gc in the inner function or not
(independant of the outer function)
"""
if 'gpua' not in info:
info['gpua'] = False
if allow_gc is None:
allow_gc = config.scan.allow_gc
# adding properties into self
self.inputs = inputs
self.outputs = outputs
......@@ -113,6 +119,9 @@ class Scan(PureOp):
else:
self.mode_instance.message = "Scan sub profile"
else:
mode_instance = mode_instance.__type__(
optimizer=mode_instance.provided_optimizer,
linker=mode_instance.provided_linker.clone(allow_gc=allow_gc))
self.mode_instance = mode_instance
if not hasattr(self, 'name') or self.name is None:
......@@ -426,10 +435,10 @@ class Scan(PureOp):
if not 'destroy_map' in other.info:
other.info['destroy_map'] = OrderedDict()
keys_to_check = ['truncate_gradient', 'profile',
'n_seqs', 'tap_array', 'name',
'n_seqs', 'tap_array',
'as_while', 'n_mit_sot', 'destroy_map',
'n_nit_sot', 'n_shared_outs',
'n_sit_sot', 'gpu', 'n_mit_mot_outs',
'n_sit_sot', 'gpu', 'gpua', 'n_mit_mot_outs',
'n_mit_mot', 'mit_mot_out_slices']
# This are some safety checks ( namely that the inner graph has the
# same number of inputs and same number of outputs )
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论