提交 8a5d41da authored 作者: Frédéric Bastien's avatar Frédéric Bastien

Merge pull request #2336 from abergeron/scan_allow_gc

Separate allow_gc flag for scan
......@@ -1176,20 +1176,22 @@ class FunctionMaker(object):
# 1) We preload the cache here to don't have its timming
# included in optimization that compile function.
# 2) If other repo that import Theano have Theano ops defined,
# we need to refresh the cache here. Otherwise, their is import
# we need to refresh the cache here. Otherwise, there are import
# order problems.
# When device=gpu, we compile during Theano import. This trigger
# the loading of the cache. But unpickling the cache ask that the
# other repos Ops are completly loaded, which isn't always the
# case!
# If a module isn't completly loaded and their unpickling fail,
# it mean it is safe for this function compilation to skip them,
# but not for futur compilation. So reloading the cache at each
# compilation fix this problem.
# 3) This help propagate knowledge of newly compiled module to
# concurrent process.
# When device=gpu, we compile during Theano
# import. This triggers the loading of the cache. But
# unpickling the cache asks that the external Ops are
completely loaded, which isn't always the case!
# If a module isn't completely loaded and its unpickling
# fails, it means it is safe for this function
# compilation to skip them, but not for future
# compilations. So reloading the cache at each
# compilation fixes this problem.
# 3) This helps propagate knowledge of newly compiled modules to
# concurrent processes.
theano.gof.cc.get_module_cache().refresh()
# Handle the case where inputs and/or outputs is a single Variable (not in a list)
# Handle the case where inputs and/or outputs is a single
# Variable (not in a list)
self.orig_outputs = outputs
unpack_single = False
return_none = False
......
......@@ -9,6 +9,8 @@ from theano.gof import utils
from theano.gof import graph
from theano.gof.type import Type
from .utils import MethodNotDefined, undef
__excepthook = sys.excepthook
......@@ -181,6 +183,12 @@ raise_with_op.print_thunk_trace = False
class Linker(object):
"""WRITEME"""
def clone(self, allow_gc=undef):
    """Return a shallow copy of this linker.

    :param allow_gc: when provided (i.e. not the ``undef`` sentinel),
        the copy's ``allow_gc`` attribute is set to this value;
        otherwise the attribute is whatever the shallow copy carries.
    """
    duplicate = copy(self)
    if allow_gc is undef:
        return duplicate
    duplicate.allow_gc = allow_gc
    return duplicate
def make_thunk(self):
"""
This function must return a triplet (function, input_variables, output_variables)
......@@ -689,6 +697,11 @@ class WrapLinker(Linker):
wrapper=self.wrapper)
return other
def clone(self, allow_gc=undef):
    """Return a new WrapLinker wrapping clones of all sub-linkers.

    Bug fixes versus the previous version: the method was missing the
    ``self`` parameter, and it referenced an unbound name ``l`` instead
    of iterating over ``self.linkers`` (so only calling it would have
    raised a TypeError/NameError and, even if it ran, it would have
    cloned a single linker rather than all of them).

    :param allow_gc: forwarded to each sub-linker's ``clone``; the
        ``undef`` sentinel default means "keep each linker's setting".
    """
    return self.__class__(
        linkers=[l.clone(allow_gc=allow_gc) for l in self.linkers],
        wrapper=self.wrapper)
def accept(self, fgraph, no_recycling=None):
"""
@type fgraph: gof.FunctionGraph
......
......@@ -39,6 +39,11 @@ def hashtype(self):
return hash(t.__name__) ^ hash(t.__module__)
# Object to mark that a parameter is undefined (useful in cases where
# None is a valid value with defined semantics)
undef = object()
class MethodNotDefined(Exception):
"""
To be raised by functions defined as part of an interface.
......
......@@ -40,7 +40,8 @@ def scan(fn,
n_steps=None,
mode=None,
name=None,
profile=False):
profile=False,
allow_gc=None):
"""
Similar to Theano's official scan, this function gives the user more
control over the scan op, avoiding certain difficulties that arose from
......@@ -115,6 +116,8 @@ def scan(fn,
seqs = wrap_into_list(sequences)
outs_info = wrap_into_list(states)
if allow_gc is None:
allow_gc = config.scan.allow_gc
# Make sure we get rid of numpy arrays or ints or anything like that
# passed as inputs to scan
......@@ -627,6 +630,7 @@ def scan(fn,
info['as_while'] = as_while
info['profile'] = profile
info['_scan_savemem_visited'] = True
info['allow_gc'] = allow_gc
local_op = scan_op.Scan(inner_inputs, new_outs, info)
......
......@@ -74,7 +74,8 @@ def scan(fn,
go_backwards=False,
mode=None,
name=None,
profile=False):
profile=False,
allow_gc=None):
"""
This function constructs and applies a Scan op to the provided
arguments.
......@@ -309,6 +310,10 @@ def scan(fn,
inner graph with the new cvm linker ( with default modes,
other linkers this argument is useless)
:param allow_gc:
Set the value of ``allow_gc`` for the internal graph of scan. If
set to None, this will use the value of ``config.scan.allow_gc``.
:rtype: tuple
:return: tuple of the form (outputs, updates); ``outputs`` is either a
Theano variable or a list of Theano variables representing the
......@@ -962,6 +967,8 @@ def scan(fn,
##
tap_array = mit_sot_tap_array + [[-1] for x in xrange(n_sit_sot)]
if allow_gc is None:
allow_gc = config.scan.allow_gc
info = OrderedDict()
info['tap_array'] = tap_array
......@@ -980,6 +987,7 @@ def scan(fn,
info['gpu'] = False
info['as_while'] = as_while
info['profile'] = profile
info['allow_gc'] = allow_gc
local_op = scan_op.Scan(inner_inputs, new_outs, info)
......
......@@ -22,17 +22,12 @@ import numpy
import theano
from theano.compat import exc_message
from theano.compile import function, Param, Out
from theano import compile
from theano import gradient
from theano.gof.python25 import any, OrderedDict
from theano import compile, config, gradient, gof, tensor
from theano.gof import PureOp, Apply
from theano import gof
from theano.gof.python25 import any, OrderedDict
from theano.tensor import TensorType
from theano import tensor
from theano.tensor.opt import Shape_i
from theano.gradient import grad_undefined
from theano.gradient import DisconnectedType
from theano.gradient import NullType
from theano.gradient import grad_undefined, DisconnectedType, NullType
from theano.compile.profiling import ScanProfileStats
from theano.scan_module import scan_utils
......@@ -42,6 +37,13 @@ from theano.scan_module.scan_utils import safe_new, forced_replace
_logger = logging.getLogger('theano.scan_module.scan_op')
from theano.configparser import AddConfigVar, BoolParam
AddConfigVar('scan.allow_gc',
"Allow/disallow gc inside of Scan (default: config.allow_gc)",
BoolParam(lambda: config.allow_gc))
class Scan(PureOp):
def __init__(self,
inputs,
......@@ -66,7 +68,6 @@ class Scan(PureOp):
# since info contains all tunable parameters of the op, so for two
# scan to be equal this tunable parameters should be the same
self.info = info
# build a list of output types for any Apply node using this op.
self.output_types = []
idx = 0
......@@ -104,7 +105,7 @@ class Scan(PureOp):
isinstance(mode_instance, compile.profilemode.ProfileMode)):
mode_instance = compile.profilemode.ProfileMode(
optimizer=mode_instance.provided_optimizer,
linker=mode_instance.provided_linker)
linker=mode_instance.linker.clone(allow_gc=self.allow_gc))
compile.profilemode.prof_mode_instance_to_print.append(
mode_instance)
self.mode_instance = mode_instance
......@@ -113,7 +114,9 @@ class Scan(PureOp):
else:
self.mode_instance.message = "Scan sub profile"
else:
self.mode_instance = mode_instance
self.mode_instance = type(mode_instance)(
optimizer=mode_instance.provided_optimizer,
linker=mode_instance.linker.clone(allow_gc=self.allow_gc))
if not hasattr(self, 'name') or self.name is None:
self.name = 'scan_fn'
......@@ -426,10 +429,10 @@ class Scan(PureOp):
if not 'destroy_map' in other.info:
other.info['destroy_map'] = OrderedDict()
keys_to_check = ['truncate_gradient', 'profile',
'n_seqs', 'tap_array', 'name',
'n_seqs', 'tap_array',
'as_while', 'n_mit_sot', 'destroy_map',
'n_nit_sot', 'n_shared_outs',
'n_sit_sot', 'gpu', 'n_mit_mot_outs',
'n_sit_sot', 'gpu', 'gpua', 'n_mit_mot_outs',
'n_mit_mot', 'mit_mot_out_slices']
# These are some safety checks (namely that the inner graph has the
# same number of inputs and same number of outputs )
......@@ -447,15 +450,11 @@ class Scan(PureOp):
if self_in.type != other_in.type:
return False
if not scan_utils.equal_computations(self.outputs,
return scan_utils.equal_computations(self.outputs,
other.outputs,
self.inputs,
other.inputs):
return False
other.inputs)
# If they do, then they need to match in other small details
# like name, mode, etc.
return True
def __str__(self):
if self.gpu:
......@@ -623,10 +622,17 @@ class Scan(PureOp):
p = self.execute
# default arguments are stored in the closure of `rval`
def rval(p=p, i=node_input_storage, o=node_output_storage, n=node):
# Big ugly hack: we cannot access the real value of allow_gc
# for the enclosing function here, so approximate it.
allow_gc = config.allow_gc and not self.allow_gc
def rval(p=p, i=node_input_storage, o=node_output_storage, n=node,
allow_gc=allow_gc):
r = p(n, [x[0] for x in i], o)
for o in node.outputs:
compute_map[o][0] = True
if allow_gc:
self.fn.free()
return r
rval.inputs = node_input_storage
rval.outputs = node_output_storage
......@@ -1876,6 +1882,7 @@ class Scan(PureOp):
else:
info['name'] = None
info['mode'] = self.mode
info['allow_gc'] = self.allow_gc
outer_inputs = ([grad_steps] +
outer_inp_seqs +
......@@ -2041,6 +2048,7 @@ class Scan(PureOp):
else:
info['name'] = None
info['mode'] = self.mode
info['allow_gc'] = self.allow_gc
info['mit_mot_out_slices'] = self.mit_mot_out_slices * 2
info['destroy_map'] = OrderedDict()
new_tap_array = []
......
......@@ -1500,6 +1500,7 @@ class ScanMerge(gof.Optimizer):
info['gpu'] = False
info['as_while'] = as_while
info['profile'] = nodes[0].op.profile
info['allow_gc'] = nodes[0].op.allow_gc
# We keep the inner_ins and inner_outs of each original node separated.
# To be able to recombine them in the right order after the clone,
......
......@@ -654,9 +654,11 @@ def compress_outs(op, not_required, inputs):
info['truncate_gradient'] = op.info['truncate_gradient']
info['name'] = op.info['name']
info['gpu'] = op.info['gpu']
info['gpua'] = op.info['gpua']
info['mode'] = op.info['mode']
info['as_while'] = op.info['as_while']
info['profile'] = op.info['profile']
info['allow_gc'] = op.info['allow_gc']
op_inputs = op.inputs[:op.n_seqs]
op_outputs = []
......@@ -919,7 +921,7 @@ class scan_args(object):
self.other_info = OrderedDict()
for k in ('truncate_gradient', 'name', 'mode', 'destroy_map',
'gpu', 'as_while', 'profile'):
'gpu', 'gpua', 'as_while', 'profile', 'allow_gc'):
if k in info:
self.other_info[k] = info[k]
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论