提交 26429e59 authored 作者: Arnaud Bergeron's avatar Arnaud Bergeron

Hack around in pickle so that shared variables will load correctly (in an…

Hack around in pickle so that shared variables will load correctly (in an unoptimized graph) on a GPU-less machine.
上级 9c8cf380
...@@ -275,11 +275,27 @@ def handle_shared_float32(tf): ...@@ -275,11 +275,27 @@ def handle_shared_float32(tf):
""" """
if tf: if tf:
import theano.compile import theano.compile
import copy_reg
theano.compile.shared_constructor(float32_shared_constructor) theano.compile.shared_constructor(float32_shared_constructor)
# this is a bit of hackery to make the shared variables load
# with the proper type.
copy_reg.pickle(theano.gof.graph.Apply, reduce_apply,
load_shared_pickle)
else: else:
raise NotImplementedError('removing our handler') raise NotImplementedError('removing our handler')
def reduce_apply(apply):
if isinstance(apply.op, HostFromGpu) and len(apply.inputs) == 1 and \
isinstance(apply.inputs[0], CudaNdarraySharedVariable):
return load_shared_pickle, apply.inputs[0].get_value()
else:
# this will make protocol 2 a little bit less efficient
# but there is no way around it.
return apply.__reduce__()
def load_shared_pickle(val):
return theano.tensor.as_tensor_variable(theano.shared(val))
if config.device.startswith('gpu'): if config.device.startswith('gpu'):
use(device=config.device, force=config.force_device) use(device=config.device, force=config.force_device)
elif config.init_gpu_device: elif config.init_gpu_device:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论