提交 5d051ce9 authored 作者: Arnaud Bergeron's avatar Arnaud Bergeron

Add a separate allow_gc flag for the Scans.

上级 26a22806
...@@ -9,6 +9,8 @@ from theano.gof import utils ...@@ -9,6 +9,8 @@ from theano.gof import utils
from theano.gof import graph from theano.gof import graph
from theano.gof.type import Type from theano.gof.type import Type
from .utils import MethodNotDefined, undef
__excepthook = sys.excepthook __excepthook = sys.excepthook
...@@ -181,6 +183,12 @@ raise_with_op.print_thunk_trace = False ...@@ -181,6 +183,12 @@ raise_with_op.print_thunk_trace = False
class Linker(object): class Linker(object):
"""WRITEME""" """WRITEME"""
def clone(allow_gc=undef):
new = copy(self)
if allow_gc is not undef:
new.allow_gc = allow_gc
return new
def make_thunk(self): def make_thunk(self):
""" """
This function must return a triplet (function, input_variables, output_variables) This function must return a triplet (function, input_variables, output_variables)
...@@ -689,6 +697,11 @@ class WrapLinker(Linker): ...@@ -689,6 +697,11 @@ class WrapLinker(Linker):
wrapper=self.wrapper) wrapper=self.wrapper)
return other return other
def clone(allow_gc=undef):
return self.__class__(
linkers=[l.clone(allow_gc=allow_gc)],
wrapper=self.wrapper)
def accept(self, fgraph, no_recycling=None): def accept(self, fgraph, no_recycling=None):
""" """
@type fgraph: gof.FunctionGraph @type fgraph: gof.FunctionGraph
......
...@@ -39,6 +39,11 @@ def hashtype(self): ...@@ -39,6 +39,11 @@ def hashtype(self):
return hash(t.__name__) ^ hash(t.__module__) return hash(t.__name__) ^ hash(t.__module__)
# Object to mark that a parameter is undefined (useful in cases where
# None is a valid value with defined semantics)
undef = object()
class MethodNotDefined(Exception): class MethodNotDefined(Exception):
""" """
To be raised by functions defined as part of an interface. To be raised by functions defined as part of an interface.
......
...@@ -22,17 +22,12 @@ import numpy ...@@ -22,17 +22,12 @@ import numpy
import theano import theano
from theano.compat import exc_message from theano.compat import exc_message
from theano.compile import function, Param, Out from theano.compile import function, Param, Out
from theano import compile from theano import compile, config, gradient, gof, tensor
from theano import gradient
from theano.gof.python25 import any, OrderedDict
from theano.gof import PureOp, Apply from theano.gof import PureOp, Apply
from theano import gof from theano.gof.python25 import any, OrderedDict
from theano.tensor import TensorType from theano.tensor import TensorType
from theano import tensor
from theano.tensor.opt import Shape_i from theano.tensor.opt import Shape_i
from theano.gradient import grad_undefined from theano.gradient import grad_undefined, DisconnectedType, NullType
from theano.gradient import DisconnectedType
from theano.gradient import NullType
from theano.compile.profiling import ScanProfileStats from theano.compile.profiling import ScanProfileStats
from theano.scan_module import scan_utils from theano.scan_module import scan_utils
...@@ -42,11 +37,18 @@ from theano.scan_module.scan_utils import safe_new, forced_replace ...@@ -42,11 +37,18 @@ from theano.scan_module.scan_utils import safe_new, forced_replace
_logger = logging.getLogger('theano.scan_module.scan_op') _logger = logging.getLogger('theano.scan_module.scan_op')
from theano.configparser import AddConfigVar, BoolParam
AddConfigVar('scan.allow_gc',
"Allow/disallow gc inside of Scan",
BoolParam(False))
class Scan(PureOp): class Scan(PureOp):
def __init__(self, def __init__(self,
inputs, inputs,
outputs, outputs,
info, info,
allow_gc=None
): ):
""" """
:param inputs: inputs of the inner function of scan :param inputs: inputs of the inner function of scan
...@@ -55,9 +57,13 @@ class Scan(PureOp): ...@@ -55,9 +57,13 @@ class Scan(PureOp):
the scan op (like number of different types of the scan op (like number of different types of
arguments, name, mode, if it should run on GPU or arguments, name, mode, if it should run on GPU or
not, etc.) not, etc.)
:param allow_gc: Use the gc in the inner function or not
(independant of the outer function)
""" """
if 'gpua' not in info: if 'gpua' not in info:
info['gpua'] = False info['gpua'] = False
if allow_gc is None:
allow_gc = config.scan.allow_gc
# adding properties into self # adding properties into self
self.inputs = inputs self.inputs = inputs
self.outputs = outputs self.outputs = outputs
...@@ -113,6 +119,9 @@ class Scan(PureOp): ...@@ -113,6 +119,9 @@ class Scan(PureOp):
else: else:
self.mode_instance.message = "Scan sub profile" self.mode_instance.message = "Scan sub profile"
else: else:
mode_instance = mode_instance.__type__(
optimizer=mode_instance.provided_optimizer,
linker=mode_instance.provided_linker.clone(allow_gc=allow_gc))
self.mode_instance = mode_instance self.mode_instance = mode_instance
if not hasattr(self, 'name') or self.name is None: if not hasattr(self, 'name') or self.name is None:
...@@ -426,10 +435,10 @@ class Scan(PureOp): ...@@ -426,10 +435,10 @@ class Scan(PureOp):
if not 'destroy_map' in other.info: if not 'destroy_map' in other.info:
other.info['destroy_map'] = OrderedDict() other.info['destroy_map'] = OrderedDict()
keys_to_check = ['truncate_gradient', 'profile', keys_to_check = ['truncate_gradient', 'profile',
'n_seqs', 'tap_array', 'name', 'n_seqs', 'tap_array',
'as_while', 'n_mit_sot', 'destroy_map', 'as_while', 'n_mit_sot', 'destroy_map',
'n_nit_sot', 'n_shared_outs', 'n_nit_sot', 'n_shared_outs',
'n_sit_sot', 'gpu', 'n_mit_mot_outs', 'n_sit_sot', 'gpu', 'gpua', 'n_mit_mot_outs',
'n_mit_mot', 'mit_mot_out_slices'] 'n_mit_mot', 'mit_mot_out_slices']
# This are some safety checks ( namely that the inner graph has the # This are some safety checks ( namely that the inner graph has the
# same number of inputs and same number of outputs ) # same number of inputs and same number of outputs )
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论