提交 c07dcdb5 authored 作者: Kyung Hyun Cho's avatar Kyung Hyun Cho

strict mode to avoid nested scan

上级 27ba127c
......@@ -45,6 +45,7 @@ __contact__ = "Razvan Pascanu <r.pascanu@gmail>"
import itertools
import logging
import numpy
import warnings
from theano.compile import SharedVariable, function
from theano import compile
......@@ -75,7 +76,8 @@ def scan(fn,
mode=None,
name=None,
profile=False,
allow_gc=None):
allow_gc=None,
strict=False):
"""
This function constructs and applies a Scan op to the provided
arguments.
......@@ -314,6 +316,10 @@ def scan(fn,
Set the value of allow gc for the internal graph of scan. If
set to None, this will use the value of config.scan.allow_gc.
:param strict:
If true, all the shared variables used in ``fn`` must be provided as a
part of ``non_sequences``.
:rtype: tuple
:return: tuple of the form (outputs, updates); ``outputs`` is either a
Theano variable or a list of Theano variables representing the
......@@ -910,14 +916,29 @@ def scan(fn,
not isinstance(arg, tensor.Constant))]
givens.update(OrderedDict(zip(other_scan_args, other_inner_args)))
other_shared_scan_args = [arg.variable for arg
in dummy_f.maker.expanded_inputs
if (isinstance(arg.variable, SharedVariable) and
not arg.update)]
other_shared_inner_args = [safe_new(arg.variable, '_copy') for arg
in dummy_f.maker.expanded_inputs
if (isinstance(arg.variable, SharedVariable) and
not arg.update)]
if strict:
non_seqs_set = set(non_sequences)
other_shared_scan_args = [arg.variable for arg
in dummy_f.maker.expanded_inputs
if (isinstance(arg.variable, SharedVariable) and
not arg.update and
arg.variable in non_seqs_set)]
other_shared_inner_args = [safe_new(arg.variable, '_copy') for arg
in dummy_f.maker.expanded_inputs
if (isinstance(arg.variable, SharedVariable) and
not arg.update and
arg.variable in non_seqs_set)]
else:
other_shared_scan_args = [arg.variable for arg
in dummy_f.maker.expanded_inputs
if (isinstance(arg.variable, SharedVariable) and
not arg.update)]
other_shared_inner_args = [safe_new(arg.variable, '_copy') for arg
in dummy_f.maker.expanded_inputs
if (isinstance(arg.variable, SharedVariable) and
not arg.update)]
givens.update(OrderedDict(zip(other_shared_scan_args,
other_shared_inner_args)))
......@@ -990,6 +1011,10 @@ def scan(fn,
info['as_while'] = as_while
info['profile'] = profile
info['allow_gc'] = allow_gc
info['strict'] = strict
if strict:
warnings.warn('In the strict mode, all neccessary shared variables '
'must be passed as a part of non_sequences', Warning)
local_op = scan_op.Scan(inner_inputs, new_outs, info)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论