提交 11e7d647 authored 作者: abergeron's avatar abergeron

Merge pull request #2803 from carriepl/scan_segfault

[CRASH] Scan segfault with inconsistent inner graphs
......@@ -179,8 +179,81 @@ class Scan(PureOp):
self._cmodule_key = gof.CLinker().cmodule_key_(local_fgraph, [])
self._hash_inner_graph = hash(self._cmodule_key)
def validate_inner_graph(self):
""" Perform some elementary validations on the inner graph to ensure
that it is coherent.
"""
# For every recurrent output, iterate over the associated inner
# inputs and output and ensure that they have the same dtype
nb_recurr_outputs = self.n_mit_mot + self.n_mit_sot + self.n_sit_sot
outer_iidx_from_outer_oidx = self.get_outer_iidx_from_outer_oidx_seq()
for outer_oidx in range(nb_recurr_outputs):
outer_iidx = outer_iidx_from_outer_oidx[outer_oidx]
inner_iidxs = self.get_inner_iidx_from_outer_iidx(outer_iidx)
inner_oidxs = self.get_inner_oidx_from_outer_oidx(outer_oidx)
for (inner_iidx, inner_oidx) in itertools.product(inner_iidxs,
inner_oidxs):
type_input = self.inputs[inner_iidx].type
type_output = self.outputs[inner_oidx].type
if (type_input != type_output):
raise TypeError("Inconsistency in the inner graph of "
"scan '%s' : an input and an output are "
"associated with the same recurrent state "
"and should have the same type but have "
"type '%s' and '%s' respectively." %
(self.name, type_input, type_output))
# If scan has the flag 'gpu' set to false (meaning that is shouldn't
# use the CUDA gpu backend ), ensure that is has no input and no
# output with type CudaNdarrayType
from theano.sandbox.cuda import CudaNdarrayType
if not self.info.get("gpu", False):
for inp in self.inputs:
if isinstance(inp.type, CudaNdarrayType):
raise TypeError("Inconsistency in the inner graph of "
"scan '%s' : one of the inputs to the "
"inner graph is of type CudaNdarray but "
"the attributes of the scan op indicate "
"that it shouldn't be the case")
for out in self.outputs:
if isinstance(out.type, CudaNdarrayType):
raise TypeError("Inconsistency in the inner graph of "
"scan '%s' : one of the outputs to the "
"inner graph is of type CudaNdarray but "
"the attributes of the scan op indicate "
"that it shouldn't be the case")
# If scan has the flag 'gpua' set to false (meaning that is shouldn't
# use the gpuarray gpu backend ), ensure that is has no input and no
# output with type GpuArrayType
from theano.sandbox.gpuarray import GpuArrayType
if not self.info.get("gpua", False):
for inp in self.inputs:
if isinstance(inp.type, GpuArrayType):
raise TypeError("Inconsistency in the inner graph of "
"scan '%s' : one of the inputs to the "
"inner graph is of type GpuArrayType but "
"the attributes of the scan op indicate "
"that it shouldn't be the case")
for out in self.outputs:
if isinstance(inp.type, GpuArrayType):
raise TypeError("Inconsistency in the inner graph of "
"scan '%s' : one of the outputs to the "
"inner graph is of type GpuArrayType but "
"the attributes of the scan op indicate "
"that it shouldn't be the case")
def __setstate__(self, d):
self.__dict__.update(d)
self.validate_inner_graph()
if "allow_gc" not in self.__dict__:
self.allow_gc = True
self.info['allow_gc'] = True
......@@ -554,6 +627,11 @@ class Scan(PureOp):
the thunk can potentially cache return values (like CLinker does),
then it must not do so for variables in the no_recycling list.
"""
# Before building the thunk, validate that the inner graph is
# coherent
self.validate_inner_graph()
# Setting up all my variables in what I believe is a more Cython
# friendly form
......@@ -1408,37 +1486,6 @@ class Scan(PureOp):
if hasattr(node.tag, 'connection_pattern'):
return node.tag.connection_pattern
# Define helper functions
def _get_inner_outs_idx(oidx):
"""Given the index of an outer output, return the indices of the
corresponding inner output(s) in a sequence.
"""
s = 0
e = 0
for p in xrange(oidx + 1):
s = e
if p < self.n_mit_mot:
e += len(self.mitmot_out_taps()[p])
else:
e += 1
return range(s, e)
def _get_inner_inps_idx(outer_iidx):
"""Given the index of an outer input, return the indices of the
corresponding inner input(s) in a sequence.
"""
outer_iidx_from_inner_iidx = self.get_outer_iidx_from_inner_iidx_seq()
# For every inner input, if the corresponding outer input is the
# desired one, store the index
inner_iidxs = []
for i in xrange(len(outer_iidx_from_inner_iidx)):
if outer_iidx_from_inner_iidx[i] == outer_iidx:
inner_iidxs.append(i)
return inner_iidxs
# Obtain the connection pattern of the inner function.
inner_connect_pattern = self.inner_connection_pattern()
......@@ -1451,10 +1498,10 @@ class Scan(PureOp):
# and inner outputs and, if one such pair of inner variables is
# connected than the pair of outer variables is connected.
for outer_oidx in range(len(node.outputs)):
inner_oidxs = _get_inner_outs_idx(outer_oidx)
inner_oidxs = self.get_inner_oidx_from_outer_oidx(outer_oidx)
for outer_iidx in range(len(node.inputs)):
inner_iidxs = _get_inner_inps_idx(outer_iidx)
inner_iidxs = self.get_inner_iidx_from_outer_iidx(outer_iidx)
for inner_oidx in inner_oidxs:
for inner_iidx in inner_iidxs:
......@@ -1490,6 +1537,36 @@ class Scan(PureOp):
node.tag.connection_pattern = connection_pattern
return connection_pattern
def get_inner_oidx_from_outer_oidx(self, outer_oidx):
"""Given the index of an outer output, return the indices of the
corresponding inner output(s) in a sequence.
"""
s = 0
e = 0
for p in xrange(outer_oidx + 1):
s = e
if p < self.n_mit_mot:
e += len(self.mitmot_out_taps()[p])
else:
e += 1
return range(s, e)
def get_inner_iidx_from_outer_iidx(self, outer_oidx):
"""Given the index of an outer input, return the indices of the
corresponding inner input(s) in a sequence.
"""
outer_iidx_from_inner_iidx = self.get_outer_iidx_from_inner_iidx_seq()
# For every inner input, if the corresponding outer input is the
# desired one, store the index
inner_iidxs = []
for i in xrange(len(outer_iidx_from_inner_iidx)):
if outer_iidx_from_inner_iidx[i] == outer_oidx:
inner_iidxs.append(i)
return inner_iidxs
def get_outer_iidx_from_outer_oidx_seq(self):
""" Return a sequence where the value at the i-th position is the
index of the outer input corresponding to the i-th outer output
......
This source diff could not be displayed because it is too large. You can view the blob instead.
......@@ -1337,6 +1337,32 @@ class T_Scan(unittest.TestCase):
theano_v = my_f()
utt.assert_allclose(theano_v, numpy_v[5:, :])
def test_inconsistent_inner_fct(self):
# Test that scan can detect inconsistencies in the inner graph and
# raises an appropriate exception.
# This test has not been extensively tested for Python 3 so it should
# be skipped if python version is >=3
version = sys.version_info
if version >= (3,):
raise SkipTest("This test relies on a pickled file produced with "
"Python 2. The current python version "
"(%i.%i.%i.%i) is >= 3 so the test will be "
"skipped." % (version.major, version.minor,
version.micro, version.serial))
# The pickled scan op used in this test requires the use of a gpu
from theano.sandbox import cuda
if not cuda.cuda_available:
raise SkipTest('Optional package cuda disabled')
# When unpickled, the scan op should perform validation on its inner
# graph, detect the inconsistencies and raise a TypeError
folder = os.path.dirname(os.path.abspath(__file__))
path = os.path.join(folder, "inconsistent_scan.pkl")
assert_raises(TypeError, cPickle.load, open(path, "r"))
def test_cuda_gibbs_chain(self):
from theano.sandbox import cuda
if not cuda.cuda_available:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论