提交 11e7d647 authored 作者: abergeron's avatar abergeron

Merge pull request #2803 from carriepl/scan_segfault

[CRASH] Scan segfault with inconsistent inner graphs
...@@ -179,8 +179,81 @@ class Scan(PureOp): ...@@ -179,8 +179,81 @@ class Scan(PureOp):
self._cmodule_key = gof.CLinker().cmodule_key_(local_fgraph, []) self._cmodule_key = gof.CLinker().cmodule_key_(local_fgraph, [])
self._hash_inner_graph = hash(self._cmodule_key) self._hash_inner_graph = hash(self._cmodule_key)
def validate_inner_graph(self):
""" Perform some elementary validations on the inner graph to ensure
that it is coherent.
"""
# For every recurrent output, iterate over the associated inner
# inputs and output and ensure that they have the same dtype
nb_recurr_outputs = self.n_mit_mot + self.n_mit_sot + self.n_sit_sot
outer_iidx_from_outer_oidx = self.get_outer_iidx_from_outer_oidx_seq()
for outer_oidx in range(nb_recurr_outputs):
outer_iidx = outer_iidx_from_outer_oidx[outer_oidx]
inner_iidxs = self.get_inner_iidx_from_outer_iidx(outer_iidx)
inner_oidxs = self.get_inner_oidx_from_outer_oidx(outer_oidx)
for (inner_iidx, inner_oidx) in itertools.product(inner_iidxs,
inner_oidxs):
type_input = self.inputs[inner_iidx].type
type_output = self.outputs[inner_oidx].type
if (type_input != type_output):
raise TypeError("Inconsistency in the inner graph of "
"scan '%s' : an input and an output are "
"associated with the same recurrent state "
"and should have the same type but have "
"type '%s' and '%s' respectively." %
(self.name, type_input, type_output))
# If scan has the flag 'gpu' set to false (meaning that is shouldn't
# use the CUDA gpu backend ), ensure that is has no input and no
# output with type CudaNdarrayType
from theano.sandbox.cuda import CudaNdarrayType
if not self.info.get("gpu", False):
for inp in self.inputs:
if isinstance(inp.type, CudaNdarrayType):
raise TypeError("Inconsistency in the inner graph of "
"scan '%s' : one of the inputs to the "
"inner graph is of type CudaNdarray but "
"the attributes of the scan op indicate "
"that it shouldn't be the case")
for out in self.outputs:
if isinstance(out.type, CudaNdarrayType):
raise TypeError("Inconsistency in the inner graph of "
"scan '%s' : one of the outputs to the "
"inner graph is of type CudaNdarray but "
"the attributes of the scan op indicate "
"that it shouldn't be the case")
# If scan has the flag 'gpua' set to false (meaning that is shouldn't
# use the gpuarray gpu backend ), ensure that is has no input and no
# output with type GpuArrayType
from theano.sandbox.gpuarray import GpuArrayType
if not self.info.get("gpua", False):
for inp in self.inputs:
if isinstance(inp.type, GpuArrayType):
raise TypeError("Inconsistency in the inner graph of "
"scan '%s' : one of the inputs to the "
"inner graph is of type GpuArrayType but "
"the attributes of the scan op indicate "
"that it shouldn't be the case")
for out in self.outputs:
if isinstance(inp.type, GpuArrayType):
raise TypeError("Inconsistency in the inner graph of "
"scan '%s' : one of the outputs to the "
"inner graph is of type GpuArrayType but "
"the attributes of the scan op indicate "
"that it shouldn't be the case")
def __setstate__(self, d): def __setstate__(self, d):
self.__dict__.update(d) self.__dict__.update(d)
self.validate_inner_graph()
if "allow_gc" not in self.__dict__: if "allow_gc" not in self.__dict__:
self.allow_gc = True self.allow_gc = True
self.info['allow_gc'] = True self.info['allow_gc'] = True
...@@ -554,6 +627,11 @@ class Scan(PureOp): ...@@ -554,6 +627,11 @@ class Scan(PureOp):
the thunk can potentially cache return values (like CLinker does), the thunk can potentially cache return values (like CLinker does),
then it must not do so for variables in the no_recycling list. then it must not do so for variables in the no_recycling list.
""" """
# Before building the thunk, validate that the inner graph is
# coherent
self.validate_inner_graph()
# Setting up all my variables in what I believe is a more Cython # Setting up all my variables in what I believe is a more Cython
# friendly form # friendly form
...@@ -1408,37 +1486,6 @@ class Scan(PureOp): ...@@ -1408,37 +1486,6 @@ class Scan(PureOp):
if hasattr(node.tag, 'connection_pattern'): if hasattr(node.tag, 'connection_pattern'):
return node.tag.connection_pattern return node.tag.connection_pattern
# Define helper functions
def _get_inner_outs_idx(oidx):
"""Given the index of an outer output, return the indices of the
corresponding inner output(s) in a sequence.
"""
s = 0
e = 0
for p in xrange(oidx + 1):
s = e
if p < self.n_mit_mot:
e += len(self.mitmot_out_taps()[p])
else:
e += 1
return range(s, e)
def _get_inner_inps_idx(outer_iidx):
"""Given the index of an outer input, return the indices of the
corresponding inner input(s) in a sequence.
"""
outer_iidx_from_inner_iidx = self.get_outer_iidx_from_inner_iidx_seq()
# For every inner input, if the corresponding outer input is the
# desired one, store the index
inner_iidxs = []
for i in xrange(len(outer_iidx_from_inner_iidx)):
if outer_iidx_from_inner_iidx[i] == outer_iidx:
inner_iidxs.append(i)
return inner_iidxs
# Obtain the connection pattern of the inner function. # Obtain the connection pattern of the inner function.
inner_connect_pattern = self.inner_connection_pattern() inner_connect_pattern = self.inner_connection_pattern()
...@@ -1451,10 +1498,10 @@ class Scan(PureOp): ...@@ -1451,10 +1498,10 @@ class Scan(PureOp):
# and inner outputs and, if one such pair of inner variables is # and inner outputs and, if one such pair of inner variables is
# connected than the pair of outer variables is connected. # connected than the pair of outer variables is connected.
for outer_oidx in range(len(node.outputs)): for outer_oidx in range(len(node.outputs)):
inner_oidxs = _get_inner_outs_idx(outer_oidx) inner_oidxs = self.get_inner_oidx_from_outer_oidx(outer_oidx)
for outer_iidx in range(len(node.inputs)): for outer_iidx in range(len(node.inputs)):
inner_iidxs = _get_inner_inps_idx(outer_iidx) inner_iidxs = self.get_inner_iidx_from_outer_iidx(outer_iidx)
for inner_oidx in inner_oidxs: for inner_oidx in inner_oidxs:
for inner_iidx in inner_iidxs: for inner_iidx in inner_iidxs:
...@@ -1490,6 +1537,36 @@ class Scan(PureOp): ...@@ -1490,6 +1537,36 @@ class Scan(PureOp):
node.tag.connection_pattern = connection_pattern node.tag.connection_pattern = connection_pattern
return connection_pattern return connection_pattern
def get_inner_oidx_from_outer_oidx(self, outer_oidx):
"""Given the index of an outer output, return the indices of the
corresponding inner output(s) in a sequence.
"""
s = 0
e = 0
for p in xrange(outer_oidx + 1):
s = e
if p < self.n_mit_mot:
e += len(self.mitmot_out_taps()[p])
else:
e += 1
return range(s, e)
def get_inner_iidx_from_outer_iidx(self, outer_oidx):
"""Given the index of an outer input, return the indices of the
corresponding inner input(s) in a sequence.
"""
outer_iidx_from_inner_iidx = self.get_outer_iidx_from_inner_iidx_seq()
# For every inner input, if the corresponding outer input is the
# desired one, store the index
inner_iidxs = []
for i in xrange(len(outer_iidx_from_inner_iidx)):
if outer_iidx_from_inner_iidx[i] == outer_oidx:
inner_iidxs.append(i)
return inner_iidxs
def get_outer_iidx_from_outer_oidx_seq(self): def get_outer_iidx_from_outer_oidx_seq(self):
""" Return a sequence where the value at the i-th position is the """ Return a sequence where the value at the i-th position is the
index of the outer input corresponding to the i-th outer output index of the outer input corresponding to the i-th outer output
......
This source diff could not be displayed because it is too large. You can view the blob instead.
...@@ -1337,6 +1337,32 @@ class T_Scan(unittest.TestCase): ...@@ -1337,6 +1337,32 @@ class T_Scan(unittest.TestCase):
theano_v = my_f() theano_v = my_f()
utt.assert_allclose(theano_v, numpy_v[5:, :]) utt.assert_allclose(theano_v, numpy_v[5:, :])
def test_inconsistent_inner_fct(self):
# Test that scan can detect inconsistencies in the inner graph and
# raises an appropriate exception.
# This test has not been extensively tested for Python 3 so it should
# be skipped if python version is >=3
version = sys.version_info
if version >= (3,):
raise SkipTest("This test relies on a pickled file produced with "
"Python 2. The current python version "
"(%i.%i.%i.%i) is >= 3 so the test will be "
"skipped." % (version.major, version.minor,
version.micro, version.serial))
# The pickled scan op used in this test requires the use of a gpu
from theano.sandbox import cuda
if not cuda.cuda_available:
raise SkipTest('Optional package cuda disabled')
# When unpickled, the scan op should perform validation on its inner
# graph, detect the inconsistencies and raise a TypeError
folder = os.path.dirname(os.path.abspath(__file__))
path = os.path.join(folder, "inconsistent_scan.pkl")
assert_raises(TypeError, cPickle.load, open(path, "r"))
def test_cuda_gibbs_chain(self): def test_cuda_gibbs_chain(self):
from theano.sandbox import cuda from theano.sandbox import cuda
if not cuda.cuda_available: if not cuda.cuda_available:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论