提交 c07dcdb5 authored 作者: Kyung Hyun Cho's avatar Kyung Hyun Cho

strict mode to avoid nested scan

上级 27ba127c
...@@ -45,6 +45,7 @@ __contact__ = "Razvan Pascanu <r.pascanu@gmail>" ...@@ -45,6 +45,7 @@ __contact__ = "Razvan Pascanu <r.pascanu@gmail>"
import itertools import itertools
import logging import logging
import numpy import numpy
import warnings
from theano.compile import SharedVariable, function from theano.compile import SharedVariable, function
from theano import compile from theano import compile
...@@ -75,7 +76,8 @@ def scan(fn, ...@@ -75,7 +76,8 @@ def scan(fn,
mode=None, mode=None,
name=None, name=None,
profile=False, profile=False,
allow_gc=None): allow_gc=None,
strict=False):
""" """
This function constructs and applies a Scan op to the provided This function constructs and applies a Scan op to the provided
arguments. arguments.
...@@ -314,6 +316,10 @@ def scan(fn, ...@@ -314,6 +316,10 @@ def scan(fn,
Set the value of allow gc for the internal graph of scan. If Set the value of allow gc for the internal graph of scan. If
set to None, this will use the value of config.scan.allow_gc. set to None, this will use the value of config.scan.allow_gc.
:param strict:
If true, all the shared variables used in ``fn`` must be provided as a
part of ``non_sequences``.
:rtype: tuple :rtype: tuple
:return: tuple of the form (outputs, updates); ``outputs`` is either a :return: tuple of the form (outputs, updates); ``outputs`` is either a
Theano variable or a list of Theano variables representing the Theano variable or a list of Theano variables representing the
...@@ -910,6 +916,21 @@ def scan(fn, ...@@ -910,6 +916,21 @@ def scan(fn,
not isinstance(arg, tensor.Constant))] not isinstance(arg, tensor.Constant))]
givens.update(OrderedDict(zip(other_scan_args, other_inner_args))) givens.update(OrderedDict(zip(other_scan_args, other_inner_args)))
if strict:
non_seqs_set = set(non_sequences)
other_shared_scan_args = [arg.variable for arg
in dummy_f.maker.expanded_inputs
if (isinstance(arg.variable, SharedVariable) and
not arg.update and
arg.variable in non_seqs_set)]
other_shared_inner_args = [safe_new(arg.variable, '_copy') for arg
in dummy_f.maker.expanded_inputs
if (isinstance(arg.variable, SharedVariable) and
not arg.update and
arg.variable in non_seqs_set)]
else:
other_shared_scan_args = [arg.variable for arg other_shared_scan_args = [arg.variable for arg
in dummy_f.maker.expanded_inputs in dummy_f.maker.expanded_inputs
if (isinstance(arg.variable, SharedVariable) and if (isinstance(arg.variable, SharedVariable) and
...@@ -990,6 +1011,10 @@ def scan(fn, ...@@ -990,6 +1011,10 @@ def scan(fn,
info['as_while'] = as_while info['as_while'] = as_while
info['profile'] = profile info['profile'] = profile
info['allow_gc'] = allow_gc info['allow_gc'] = allow_gc
info['strict'] = strict
if strict:
warnings.warn('In the strict mode, all neccessary shared variables '
'must be passed as a part of non_sequences', Warning)
local_op = scan_op.Scan(inner_inputs, new_outs, info) local_op = scan_op.Scan(inner_inputs, new_outs, info)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论