提交 fa0007df，作者：carriepl

Compute [inps,outs]_on_gpu in __setstate__, if needed

上级 4b14b3c0
@@ -314,13 +314,24 @@ class Scan(PureOp):
        # Generate the mappings between inner and outer inputs and outputs
        # if they haven't already been generated.
        self.var_mappings = self.get_oinp_iinp_iout_oout_mappings()
        if hasattr(self, 'fn'):
            if not hasattr(self, 'thunk_mit_mot_out_slices'):
                # The thunk has been compiled before mit_mot preallocation
                # feature was implemented. Mark every mit_mot output tap as
                # not having been preallocated
                self.mitmots_preallocated = [False] * self.n_mit_mot_outs
            if not hasattr(self, 'outs_on_gpu'):
                # The thunk has been compiled before the analysis, at
                # compilation time, of the location of the inputs and outputs.
                # Perform this analysis here.
                self.inps_on_gpu = [not isinstance(out,
                                                   theano.tensor.TensorVariable)
                                    for out in self.fn.maker.fgraph.inputs]
                self.outs_on_gpu = [not isinstance(out,
                                                   theano.tensor.TensorVariable)
                                    for out in self.fn.maker.fgraph.outputs]
        # Ensure that the graph associated with the inner function is valid.
        self.validate_inner_graph()
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论