提交 33247199 authored 作者: --global's avatar --global

Compute mappings in __init__ and __setstate__ if needed

上级 f52d6ee7
...@@ -229,6 +229,11 @@ class Scan(PureOp): ...@@ -229,6 +229,11 @@ class Scan(PureOp):
self._cmodule_key = gof.CLinker().cmodule_key_(local_fgraph, []) self._cmodule_key = gof.CLinker().cmodule_key_(local_fgraph, [])
self._hash_inner_graph = hash(self._cmodule_key) self._hash_inner_graph = hash(self._cmodule_key)
# Compute mappings between outer inputs, outer outputs, inner
# inputs and inner outputs to determine with variables are associated
# with the same states.
self.var_mappings = self.get_oinp_iinp_iout_oout_mappings()
def validate_inner_graph(self): def validate_inner_graph(self):
""" Perform some elementary validations on the inner graph to ensure """ Perform some elementary validations on the inner graph to ensure
that it is coherent. that it is coherent.
...@@ -303,13 +308,19 @@ class Scan(PureOp): ...@@ -303,13 +308,19 @@ class Scan(PureOp):
def __setstate__(self, d): def __setstate__(self, d):
self.__dict__.update(d) self.__dict__.update(d)
self.validate_inner_graph()
if "allow_gc" not in self.__dict__: if "allow_gc" not in self.__dict__:
self.allow_gc = True self.allow_gc = True
self.info['allow_gc'] = True self.info['allow_gc'] = True
if not hasattr(self, 'gpua'): if not hasattr(self, 'gpua'):
self.gpua = False self.gpua = False
self.info['gpua'] = False self.info['gpua'] = False
if not hasattr(self, 'var_mappings'):
# Generate the mappings between inner and outer inputs and outputs
# if they haven't already been generated.
self.var_mappings = self.get_oinp_iinp_iout_oout_mappings()
# Ensure that the graph associated with the inner function is valid.
self.validate_inner_graph()
def make_node(self, *inputs): def make_node(self, *inputs):
""" """
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论