提交 8869aeb1 authored 作者: --global's avatar --global

Define validation function for the inner graph

上级 5e11d066
......@@ -179,6 +179,36 @@ class Scan(PureOp):
self._cmodule_key = gof.CLinker().cmodule_key_(local_fgraph, [])
self._hash_inner_graph = hash(self._cmodule_key)
def validate_inner_graph(self):
""" Perform some elementary validations on the inner graph to ensure
that it is coherent.
"""
# For every recurrent output, iterate over the associated inner
# inputs and output and ensure that they have the same dtype
nb_recurr_outputs = self.n_mit_mot + self.n_mit_sot + self.n_sit_sot
outer_iidx_from_outer_oidx = self.get_outer_iidx_from_outer_oidx_seq()
for outer_oidx in range(nb_recurr_outputs):
outer_iidx = outer_iidx_from_outer_oidx[outer_oidx]
inner_iidxs = self.get_inner_iidx_from_outer_iidx(outer_iidx)
inner_oidxs = self.get_inner_oidx_from_outer_oidx(outer_oidx)
for (inner_iidx, inner_oidx) in itertools.product(inner_iidxs,
inner_oidxs):
type_input = self.inputs[inner_iidx].type
type_output = self.outputs[inner_oidx].type
if (type_input != type_output):
raise TypeError("Inconsistency in the inner graph of "
"scan '%s' : an input and an output are "
"associated with the same recurrent state "
"and should have the same type but have "
"type '%s' and '%s' respectively." %
(self.name, type_input, type_output))
def __setstate__(self, d):
self.__dict__.update(d)
if "allow_gc" not in self.__dict__:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论