提交 4e91569d authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Merge pull request #2622 from kyunghyuncho/scan_mode_strict

strict mode to avoid nested scan
......@@ -45,6 +45,7 @@ __contact__ = "Razvan Pascanu <r.pascanu@gmail>"
import itertools
import logging
import numpy
import warnings
from theano.compile import SharedVariable, function
from theano import compile
......@@ -75,7 +76,8 @@ def scan(fn,
mode=None,
name=None,
profile=False,
allow_gc=None):
allow_gc=None,
strict=False):
"""
This function constructs and applies a Scan op to the provided
arguments.
......@@ -314,6 +316,10 @@ def scan(fn,
Set the value of allow gc for the internal graph of scan. If
set to None, this will use the value of config.scan.allow_gc.
:param strict:
If true, all the shared variables used in ``fn`` must be provided as a
part of ``non_sequences`` or ``sequences``.
:rtype: tuple
:return: tuple of the form (outputs, updates); ``outputs`` is either a
Theano variable or a list of Theano variables representing the
......@@ -910,14 +916,29 @@ def scan(fn,
not isinstance(arg, tensor.Constant))]
givens.update(OrderedDict(zip(other_scan_args, other_inner_args)))
other_shared_scan_args = [arg.variable for arg
in dummy_f.maker.expanded_inputs
if (isinstance(arg.variable, SharedVariable) and
not arg.update)]
other_shared_inner_args = [safe_new(arg.variable, '_copy') for arg
in dummy_f.maker.expanded_inputs
if (isinstance(arg.variable, SharedVariable) and
not arg.update)]
if strict:
non_seqs_set = set(non_sequences if non_sequences != None else [])
other_shared_scan_args = [arg.variable for arg
in dummy_f.maker.expanded_inputs
if (isinstance(arg.variable, SharedVariable) and
not arg.update and
arg.variable in non_seqs_set)]
other_shared_inner_args = [safe_new(arg.variable, '_copy') for arg
in dummy_f.maker.expanded_inputs
if (isinstance(arg.variable, SharedVariable) and
not arg.update and
arg.variable in non_seqs_set)]
else:
other_shared_scan_args = [arg.variable for arg
in dummy_f.maker.expanded_inputs
if (isinstance(arg.variable, SharedVariable) and
not arg.update)]
other_shared_inner_args = [safe_new(arg.variable, '_copy') for arg
in dummy_f.maker.expanded_inputs
if (isinstance(arg.variable, SharedVariable) and
not arg.update)]
givens.update(OrderedDict(zip(other_shared_scan_args,
other_shared_inner_args)))
......@@ -990,6 +1011,10 @@ def scan(fn,
info['as_while'] = as_while
info['profile'] = profile
info['allow_gc'] = allow_gc
info['strict'] = strict
if strict:
warnings.warn('In the strict mode, all neccessary shared variables '
'must be passed as a part of non_sequences', Warning)
local_op = scan_op.Scan(inner_inputs, new_outs, info)
......
......@@ -11,6 +11,7 @@ import numpy
from nose.plugins.skip import SkipTest
from nose.plugins.attrib import attr
from nose.tools import assert_raises
from nose.tools import raises
from numpy.testing import dec
import theano
......@@ -46,11 +47,11 @@ else:
mode_with_gpu = mode_with_opt.including('gpu', 'scan')
type_eps = {'float64': 1e-7,
'float32': 3e-3}
class multiple_outputs_numeric_grad:
"""WRITEME"""
type_eps = {'float64': 1e-7,
'float32': 3e-3}
def __init__(self, f, pt, ndarray_mask=None, eps=None):
"""Return the gradient of f at pt.
......@@ -78,13 +79,12 @@ class multiple_outputs_numeric_grad:
if not ndarray_mask:
ndarray_mask = [True for x in pt]
dtype_eps = multiple_outputs_numeric_grad.type_eps['float64']
dtype_eps = type_eps['float64']
for i, p in enumerate(pt):
if ndarray_mask[i]:
pt[i] = numpy.array(p)
_eps = multiple_outputs_numeric_grad.type_eps[str(
pt[i].dtype)]
_eps = type_eps[str(pt[i].dtype)]
if _eps > dtype_eps:
dtype_eps = _eps
......@@ -3980,6 +3980,63 @@ class T_Scan(unittest.TestCase):
f = theano.function([W, n_steps], H)
f(numpy.ones((8,), dtype='float32'), 1)
def test_strict_mode(self):
    # Compare a scan whose step function closes over a shared variable
    # (strict=False) with one that receives the same shared variable
    # through ``non_sequences`` (strict=True); both must give the same
    # final state up to the tolerance for the configured float type.
    floatX = theano.config.floatX
    num_steps = 10

    weights = theano.shared(
        numpy.array([[-1, 2], [3, -4]]).astype(floatX))
    init_value = numpy.array([1, 2]).astype(floatX)
    init_var = tensor.vector(name='x0', dtype=floatX)

    def step_closure(state):
        # The shared weights are picked up implicitly from the closure.
        return tensor.dot(state, weights)

    def step_explicit(state, w_arg):
        # The weights arrive explicitly as a non-sequence argument.
        return tensor.dot(state, w_arg)

    loose_outputs, _ = theano.scan(step_closure,
                                   sequences=[],
                                   outputs_info=[init_var],
                                   n_steps=num_steps,
                                   strict=False)
    run_loose = theano.function([init_var], loose_outputs[-1])

    strict_outputs, _ = theano.scan(step_explicit,
                                    sequences=[],
                                    outputs_info=[init_var],
                                    non_sequences=[weights],
                                    n_steps=num_steps,
                                    strict=True)
    run_strict = theano.function([init_var], strict_outputs[-1])

    loose_value = run_loose(init_value)
    strict_value = run_strict(init_value)

    mean_abs_diff = abs(loose_value - strict_value).mean()
    assert mean_abs_diff <= type_eps[floatX]
@raises(theano.gof.fg.MissingInputError)
def test_strict_mode_ex(self):
    # With strict=True, a step function that uses a shared variable which
    # was not passed through ``non_sequences`` is expected to fail with
    # MissingInputError (checked by the ``raises`` decorator).
    floatX = theano.config.floatX
    num_steps = 10

    weights = theano.shared(
        numpy.array([[-1, 2], [3, -4]]).astype(floatX))
    init_value = numpy.array([1, 2]).astype(floatX)
    init_var = tensor.vector(name='x0', dtype=floatX)

    def step_closure(state):
        # Deliberately relies on the closure instead of non_sequences.
        return tensor.dot(state, weights)

    strict_outputs, _ = theano.scan(step_closure,
                                    sequences=[],
                                    outputs_info=[init_var],
                                    n_steps=num_steps,
                                    strict=True)

    run_strict = theano.function([init_var], strict_outputs[-1])
    run_strict(init_value)
def test_speed():
#
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论