提交 11e7d647 authored 作者: abergeron's avatar abergeron

Merge pull request #2803 from carriepl/scan_segfault

[CRASH] Scan segfault with inconsistent inner graphs
......@@ -179,8 +179,81 @@ class Scan(PureOp):
self._cmodule_key = gof.CLinker().cmodule_key_(local_fgraph, [])
self._hash_inner_graph = hash(self._cmodule_key)
def validate_inner_graph(self):
""" Perform some elementary validations on the inner graph to ensure
that it is coherent.
"""
# For every recurrent output, iterate over the associated inner
# inputs and output and ensure that they have the same dtype
nb_recurr_outputs = self.n_mit_mot + self.n_mit_sot + self.n_sit_sot
outer_iidx_from_outer_oidx = self.get_outer_iidx_from_outer_oidx_seq()
for outer_oidx in range(nb_recurr_outputs):
outer_iidx = outer_iidx_from_outer_oidx[outer_oidx]
inner_iidxs = self.get_inner_iidx_from_outer_iidx(outer_iidx)
inner_oidxs = self.get_inner_oidx_from_outer_oidx(outer_oidx)
for (inner_iidx, inner_oidx) in itertools.product(inner_iidxs,
inner_oidxs):
type_input = self.inputs[inner_iidx].type
type_output = self.outputs[inner_oidx].type
if (type_input != type_output):
raise TypeError("Inconsistency in the inner graph of "
"scan '%s' : an input and an output are "
"associated with the same recurrent state "
"and should have the same type but have "
"type '%s' and '%s' respectively." %
(self.name, type_input, type_output))
# If scan has the flag 'gpu' set to false (meaning that is shouldn't
# use the CUDA gpu backend ), ensure that is has no input and no
# output with type CudaNdarrayType
from theano.sandbox.cuda import CudaNdarrayType
if not self.info.get("gpu", False):
for inp in self.inputs:
if isinstance(inp.type, CudaNdarrayType):
raise TypeError("Inconsistency in the inner graph of "
"scan '%s' : one of the inputs to the "
"inner graph is of type CudaNdarray but "
"the attributes of the scan op indicate "
"that it shouldn't be the case")
for out in self.outputs:
if isinstance(out.type, CudaNdarrayType):
raise TypeError("Inconsistency in the inner graph of "
"scan '%s' : one of the outputs to the "
"inner graph is of type CudaNdarray but "
"the attributes of the scan op indicate "
"that it shouldn't be the case")
# If scan has the flag 'gpua' set to false (meaning that is shouldn't
# use the gpuarray gpu backend ), ensure that is has no input and no
# output with type GpuArrayType
from theano.sandbox.gpuarray import GpuArrayType
if not self.info.get("gpua", False):
for inp in self.inputs:
if isinstance(inp.type, GpuArrayType):
raise TypeError("Inconsistency in the inner graph of "
"scan '%s' : one of the inputs to the "
"inner graph is of type GpuArrayType but "
"the attributes of the scan op indicate "
"that it shouldn't be the case")
for out in self.outputs:
if isinstance(inp.type, GpuArrayType):
raise TypeError("Inconsistency in the inner graph of "
"scan '%s' : one of the outputs to the "
"inner graph is of type GpuArrayType but "
"the attributes of the scan op indicate "
"that it shouldn't be the case")
def __setstate__(self, d):
self.__dict__.update(d)
self.validate_inner_graph()
if "allow_gc" not in self.__dict__:
self.allow_gc = True
self.info['allow_gc'] = True
......@@ -554,6 +627,11 @@ class Scan(PureOp):
the thunk can potentially cache return values (like CLinker does),
then it must not do so for variables in the no_recycling list.
"""
# Before building the thunk, validate that the inner graph is
# coherent
self.validate_inner_graph()
# Setting up all my variables in what I believe is a more Cython
# friendly form
......@@ -1408,37 +1486,6 @@ class Scan(PureOp):
if hasattr(node.tag, 'connection_pattern'):
return node.tag.connection_pattern
# Define helper functions
def _get_inner_outs_idx(oidx):
"""Given the index of an outer output, return the indices of the
corresponding inner output(s) in a sequence.
"""
s = 0
e = 0
for p in xrange(oidx + 1):
s = e
if p < self.n_mit_mot:
e += len(self.mitmot_out_taps()[p])
else:
e += 1
return range(s, e)
def _get_inner_inps_idx(outer_iidx):
"""Given the index of an outer input, return the indices of the
corresponding inner input(s) in a sequence.
"""
outer_iidx_from_inner_iidx = self.get_outer_iidx_from_inner_iidx_seq()
# For every inner input, if the corresponding outer input is the
# desired one, store the index
inner_iidxs = []
for i in xrange(len(outer_iidx_from_inner_iidx)):
if outer_iidx_from_inner_iidx[i] == outer_iidx:
inner_iidxs.append(i)
return inner_iidxs
# Obtain the connection pattern of the inner function.
inner_connect_pattern = self.inner_connection_pattern()
......@@ -1451,10 +1498,10 @@ class Scan(PureOp):
# and inner outputs and, if one such pair of inner variables is
# connected than the pair of outer variables is connected.
for outer_oidx in range(len(node.outputs)):
inner_oidxs = _get_inner_outs_idx(outer_oidx)
inner_oidxs = self.get_inner_oidx_from_outer_oidx(outer_oidx)
for outer_iidx in range(len(node.inputs)):
inner_iidxs = _get_inner_inps_idx(outer_iidx)
inner_iidxs = self.get_inner_iidx_from_outer_iidx(outer_iidx)
for inner_oidx in inner_oidxs:
for inner_iidx in inner_iidxs:
......@@ -1490,6 +1537,36 @@ class Scan(PureOp):
node.tag.connection_pattern = connection_pattern
return connection_pattern
def get_inner_oidx_from_outer_oidx(self, outer_oidx):
"""Given the index of an outer output, return the indices of the
corresponding inner output(s) in a sequence.
"""
s = 0
e = 0
for p in xrange(outer_oidx + 1):
s = e
if p < self.n_mit_mot:
e += len(self.mitmot_out_taps()[p])
else:
e += 1
return range(s, e)
def get_inner_iidx_from_outer_iidx(self, outer_oidx):
"""Given the index of an outer input, return the indices of the
corresponding inner input(s) in a sequence.
"""
outer_iidx_from_inner_iidx = self.get_outer_iidx_from_inner_iidx_seq()
# For every inner input, if the corresponding outer input is the
# desired one, store the index
inner_iidxs = []
for i in xrange(len(outer_iidx_from_inner_iidx)):
if outer_iidx_from_inner_iidx[i] == outer_oidx:
inner_iidxs.append(i)
return inner_iidxs
def get_outer_iidx_from_outer_oidx_seq(self):
""" Return a sequence where the value at the i-th position is the
index of the outer input corresponding to the i-th outer output
......
This source diff could not be displayed because it is too large. You can view the blob instead.
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论