提交 4e91569d authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Merge pull request #2622 from kyunghyuncho/scan_mode_strict

strict mode to avoid nested scan
...@@ -45,6 +45,7 @@ __contact__ = "Razvan Pascanu <r.pascanu@gmail>" ...@@ -45,6 +45,7 @@ __contact__ = "Razvan Pascanu <r.pascanu@gmail>"
import itertools import itertools
import logging import logging
import numpy import numpy
import warnings
from theano.compile import SharedVariable, function from theano.compile import SharedVariable, function
from theano import compile from theano import compile
...@@ -75,7 +76,8 @@ def scan(fn, ...@@ -75,7 +76,8 @@ def scan(fn,
mode=None, mode=None,
name=None, name=None,
profile=False, profile=False,
allow_gc=None): allow_gc=None,
strict=False):
""" """
This function constructs and applies a Scan op to the provided This function constructs and applies a Scan op to the provided
arguments. arguments.
...@@ -314,6 +316,10 @@ def scan(fn, ...@@ -314,6 +316,10 @@ def scan(fn,
Set the value of allow gc for the internal graph of scan. If Set the value of allow gc for the internal graph of scan. If
set to None, this will use the value of config.scan.allow_gc. set to None, this will use the value of config.scan.allow_gc.
:param strict:
If true, all the shared variables used in ``fn`` must be provided as a
part of ``non_sequences`` or ``sequences``.
:rtype: tuple :rtype: tuple
:return: tuple of the form (outputs, updates); ``outputs`` is either a :return: tuple of the form (outputs, updates); ``outputs`` is either a
Theano variable or a list of Theano variables representing the Theano variable or a list of Theano variables representing the
...@@ -910,14 +916,29 @@ def scan(fn, ...@@ -910,14 +916,29 @@ def scan(fn,
not isinstance(arg, tensor.Constant))] not isinstance(arg, tensor.Constant))]
givens.update(OrderedDict(zip(other_scan_args, other_inner_args))) givens.update(OrderedDict(zip(other_scan_args, other_inner_args)))
other_shared_scan_args = [arg.variable for arg
in dummy_f.maker.expanded_inputs if strict:
if (isinstance(arg.variable, SharedVariable) and non_seqs_set = set(non_sequences if non_sequences != None else [])
not arg.update)]
other_shared_inner_args = [safe_new(arg.variable, '_copy') for arg other_shared_scan_args = [arg.variable for arg
in dummy_f.maker.expanded_inputs in dummy_f.maker.expanded_inputs
if (isinstance(arg.variable, SharedVariable) and if (isinstance(arg.variable, SharedVariable) and
not arg.update)] not arg.update and
arg.variable in non_seqs_set)]
other_shared_inner_args = [safe_new(arg.variable, '_copy') for arg
in dummy_f.maker.expanded_inputs
if (isinstance(arg.variable, SharedVariable) and
not arg.update and
arg.variable in non_seqs_set)]
else:
other_shared_scan_args = [arg.variable for arg
in dummy_f.maker.expanded_inputs
if (isinstance(arg.variable, SharedVariable) and
not arg.update)]
other_shared_inner_args = [safe_new(arg.variable, '_copy') for arg
in dummy_f.maker.expanded_inputs
if (isinstance(arg.variable, SharedVariable) and
not arg.update)]
givens.update(OrderedDict(zip(other_shared_scan_args, givens.update(OrderedDict(zip(other_shared_scan_args,
other_shared_inner_args))) other_shared_inner_args)))
...@@ -990,6 +1011,10 @@ def scan(fn, ...@@ -990,6 +1011,10 @@ def scan(fn,
info['as_while'] = as_while info['as_while'] = as_while
info['profile'] = profile info['profile'] = profile
info['allow_gc'] = allow_gc info['allow_gc'] = allow_gc
info['strict'] = strict
if strict:
warnings.warn('In the strict mode, all neccessary shared variables '
'must be passed as a part of non_sequences', Warning)
local_op = scan_op.Scan(inner_inputs, new_outs, info) local_op = scan_op.Scan(inner_inputs, new_outs, info)
......
...@@ -11,6 +11,7 @@ import numpy ...@@ -11,6 +11,7 @@ import numpy
from nose.plugins.skip import SkipTest from nose.plugins.skip import SkipTest
from nose.plugins.attrib import attr from nose.plugins.attrib import attr
from nose.tools import assert_raises from nose.tools import assert_raises
from nose.tools import raises
from numpy.testing import dec from numpy.testing import dec
import theano import theano
...@@ -46,11 +47,11 @@ else: ...@@ -46,11 +47,11 @@ else:
mode_with_gpu = mode_with_opt.including('gpu', 'scan') mode_with_gpu = mode_with_opt.including('gpu', 'scan')
type_eps = {'float64': 1e-7,
'float32': 3e-3}
class multiple_outputs_numeric_grad: class multiple_outputs_numeric_grad:
"""WRITEME""" """WRITEME"""
type_eps = {'float64': 1e-7,
'float32': 3e-3}
def __init__(self, f, pt, ndarray_mask=None, eps=None): def __init__(self, f, pt, ndarray_mask=None, eps=None):
"""Return the gradient of f at pt. """Return the gradient of f at pt.
...@@ -78,13 +79,12 @@ class multiple_outputs_numeric_grad: ...@@ -78,13 +79,12 @@ class multiple_outputs_numeric_grad:
if not ndarray_mask: if not ndarray_mask:
ndarray_mask = [True for x in pt] ndarray_mask = [True for x in pt]
dtype_eps = multiple_outputs_numeric_grad.type_eps['float64'] dtype_eps = type_eps['float64']
for i, p in enumerate(pt): for i, p in enumerate(pt):
if ndarray_mask[i]: if ndarray_mask[i]:
pt[i] = numpy.array(p) pt[i] = numpy.array(p)
_eps = multiple_outputs_numeric_grad.type_eps[str( _eps = type_eps[str(pt[i].dtype)]
pt[i].dtype)]
if _eps > dtype_eps: if _eps > dtype_eps:
dtype_eps = _eps dtype_eps = _eps
...@@ -3980,6 +3980,63 @@ class T_Scan(unittest.TestCase): ...@@ -3980,6 +3980,63 @@ class T_Scan(unittest.TestCase):
f = theano.function([W, n_steps], H) f = theano.function([W, n_steps], H)
f(numpy.ones((8,), dtype='float32'), 1) f(numpy.ones((8,), dtype='float32'), 1)
def test_strict_mode(self):
n = 10
w = numpy.array([[-1,2],[3,-4]]).astype(theano.config.floatX)
w_ = theano.shared(w)
x0 = numpy.array([1,2]).astype(theano.config.floatX)
x0_ = tensor.vector(name='x0', dtype=theano.config.floatX)
def _scan_loose(x):
return tensor.dot(x, w_)
def _scan_strict(x, w_ns):
return tensor.dot(x, w_ns)
ret_loose = theano.scan(_scan_loose,
sequences=[],
outputs_info=[x0_],
n_steps=n,
strict=False)
f_loose = theano.function([x0_], ret_loose[0][-1])
ret_strict = theano.scan(_scan_strict,
sequences=[],
outputs_info=[x0_],
non_sequences=[w_],
n_steps=n,
strict=True)
f_strict = theano.function([x0_], ret_strict[0][-1])
result_loose = f_loose(x0)
result_strict = f_strict(x0)
diff = (abs(result_loose - result_strict)).mean()
assert diff <= type_eps[theano.config.floatX]
@raises(theano.gof.fg.MissingInputError)
def test_strict_mode_ex(self):
n = 10
w = numpy.array([[-1,2],[3,-4]]).astype(theano.config.floatX)
w_ = theano.shared(w)
x0 = numpy.array([1,2]).astype(theano.config.floatX)
x0_ = tensor.vector(name='x0', dtype=theano.config.floatX)
def _scan_loose(x):
return tensor.dot(x, w_)
ret_strict = theano.scan(_scan_loose,
sequences=[],
outputs_info=[x0_],
n_steps=n,
strict=True)
f_strict = theano.function([x0_], ret_strict[0][-1])
result_strict = f_strict(x0)
def test_speed(): def test_speed():
# #
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论