提交 cdac0c69 作者: Frédéric Bastien 提交者: GitHub

Merge pull request #5265 from nouiz/fix_reload_pickle

Fix reload pickle
......@@ -189,7 +189,7 @@ static struct PyModuleDef moduledef = {{
def add_support_code(self, code):
assert not self.finalized
if code not in self.support_code: # TODO: KLUDGE
if code and code not in self.support_code: # TODO: KLUDGE
self.support_code.append(code)
def add_function(self, fn):
......
......@@ -1270,10 +1270,11 @@ class COp(Op):
if not isinstance(func_files, list):
func_files = [func_files]
self.func_files = [self.get_path(f) for f in func_files]
self.func_name = func_name
self.load_c_code()
# Keep the original file names. If we reload an old pickle, we want to
# find the new path and the new version of each file in Theano.
self.func_files = func_files
self.load_c_code(func_files)
if len(self.code_sections) == 0:
raise ValueError("No sections where defined in C files")
......@@ -1288,12 +1289,13 @@ class COp(Op):
raise ValueError('Cannot have an "op_code_cleanup" section '
'and specify the func_name')
def load_c_code(self):
def load_c_code(self, func_files):
"""
Loads the c code to perform the Op
"""
func_files = [self.get_path(f) for f in func_files]
self.func_codes = []
for func_file in self.func_files:
for func_file in func_files:
with open(func_file, 'r') as f:
self.func_codes.append(f.read())
......@@ -1336,7 +1338,7 @@ class COp(Op):
if split[0].strip() != '':
raise ValueError('Stray code before first #section '
'statement (in file %s): %s' %
(self.func_files[i], split[0]))
(func_files[i], split[0]))
# Separate the code into the proper sections
n = 1
......@@ -1344,7 +1346,7 @@ class COp(Op):
if split[n] not in self.SECTIONS:
raise ValueError(
"Unknown section type (in file %s): %s" %
(self.func_files[i], split[n]))
(func_files[i], split[n]))
if split[n] not in self.code_sections:
self.code_sections[split[n]] = ""
self.code_sections[split[n]] += split[n + 1]
......@@ -1352,7 +1354,7 @@ class COp(Op):
else:
raise ValueError("No valid section marker was found in file "
"%s" % self.func_files[i])
"%s" % func_files[i])
def get_op_params(self):
"""
......
......@@ -772,6 +772,15 @@ if (py_%(name)s == NULL) { %(freefunc)s(%(name)s); }
def __str__(self):
return "%s{%s}" % (self.__class__.__name__, self.ctype)
def __setstate__(self, dct):
self.__dict__.update(dct)
if not hasattr(self, 'headers'):
self.headers = ()
self.header_dirs = ()
self.libraries = ()
self.lib_dirs = ()
self.extra_support_code = ""
class CDataTypeConstant(graph.Constant):
def merge_signature(self):
......
......@@ -354,26 +354,6 @@ class GpuDnnConv(DnnBase, COp):
if self.inplace:
self.destroy_map = {0: [2]}
# In cuDNN version older than V3, the FFT implementation and the
# option to time the different implementations to get the fastest
# are both unavailable.
if version() < (3000, 3000):
if self.algo == 'fft':
raise RuntimeError("cuDNN FFT convolution requires cuDNN v3")
elif self.algo in ['guess_once', 'guess_on_shape_change']:
raise RuntimeError("cuDNN selection of convolution "
"implementation based on heuristics "
"requires cuDNN v3")
elif self.algo in ['time_once', 'time_on_shape_change']:
raise RuntimeError("cuDNN convolution timing requires cuDNN "
"v3")
# The fft_tiling implementation is only available from cuDNN V4 onward
if version() < (4000, 4000):
if self.algo == 'fft_tiling':
raise RuntimeError("cuDNN tiled-FFT convolution requires "
"cuDNN v4 or more recent")
if version() < (5000, 5000):
if self.algo == 'winograd':
raise RuntimeError("cuDNN winograd convolution requires "
......@@ -392,6 +372,9 @@ class GpuDnnConv(DnnBase, COp):
self.algo = config.dnn.conv.algo_fwd
if not hasattr(self, 'inplace'):
self.inplace = False
# Workaround for reloading old pickles:
# we need to find the new file names and reload the C code.
self.load_c_code(["dnn_base.c", "dnn_conv_base.c", "dnn_fwd.c"])
def get_op_params(self):
if self.inplace:
......@@ -656,6 +639,7 @@ class GpuDnnConvGradW(DnnBase, COp):
self.algo = config.dnn.conv.algo_bwd_filter
if not hasattr(self, 'inplace'):
self.inplace = False
self.load_c_code(["dnn_base.c", "dnn_conv_base.c", "dnn_gw.c"])
def grad(self, inp, grads):
img, top, output, desc, alpha, beta = inp
......@@ -865,13 +849,6 @@ class GpuDnnConvGradI(DnnBase, COp):
if self.inplace:
self.destroy_map = {0: [2]}
# The small-workspace implementation is only available from cuDNN V4
# onward.
if version() < (4000, 4000):
if self.algo == 'fft_tiling':
raise RuntimeError("cuDNN's tiled-FFT convolution requires "
"cuDNN v4 or more recent")
if version() < (5000, 5000):
if self.algo == 'winograd':
raise RuntimeError("cuDNN's winograd convolution requires "
......@@ -890,6 +867,7 @@ class GpuDnnConvGradI(DnnBase, COp):
self.algo = config.dnn.conv.algo_bwd_data
if not hasattr(self, 'inplace'):
self.inplace = False
self.load_c_code(["dnn_base.c", "dnn_conv_base.c", "dnn_gi.c"])
def grad(self, inp, grads):
kerns, top, output, desc, alpha, beta = inp
......@@ -1465,13 +1443,6 @@ class GpuDnnPoolDesc(GpuOp):
self.stride = stride
self.pad = pad
if self.get_ndim() == 3 and version() < (3000, 3000):
raise RuntimeError("cuDNN 3d pooling requires cuDNN v3")
if (mode == 'average_exc_pad' and max(pad) > 0 and
version() < (4004, 4004)):
raise RuntimeError(
"cuDNN pooling mode 'average_exc_pad' requires at least v4")
def get_ndim(self):
return len(self.ws)
......@@ -2110,9 +2081,6 @@ class GpuDnnSoftmaxBase(DnnBase):
DnnBase.__init__(self)
self.tensor_format = tensor_format
if algo == 'log' and version() < (3000, 3000):
raise RuntimeError("cuDNN log-softmax requires cuDNN v3")
assert(algo in ('fast', 'accurate', 'log'))
self.algo = algo
......@@ -3179,7 +3147,7 @@ if True:
@local_optimizer([GpuElemwise, LogSoftmax])
def local_log_softmax_dnn(node):
# The log-softmax implementation is only available starting at cuDNN V3
if not dnn_available() or version() < (3000, 3000):
if not dnn_available():
return
if (isinstance(node.op, GpuElemwise) and
......
......@@ -3996,6 +3996,7 @@ class Composite(ScalarOp):
def __getstate__(self):
rval = dict(self.__dict__)
rval.pop('_impls', None)
rval.pop('prepare_node_called', None)
del rval['fgraph']
return rval
......@@ -4003,6 +4004,7 @@ class Composite(ScalarOp):
self.__dict__.update(d)
# We must call init to set fgraph and _impls again, as otherwise
# self.perform will not work.
self.prepare_node_called = set()
self.init_fgraph()
self.init_py_impls()
assert self._c_code
......
Markdown 格式
0%
您即将添加 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
请 注册 或者 登录 后发表评论