提交 5ba45c50 authored 作者: Frederic Bastien's avatar Frederic Bastien

Allow to reload COp files after init.

上级 13d7543b
...@@ -1270,10 +1270,11 @@ class COp(Op): ...@@ -1270,10 +1270,11 @@ class COp(Op):
if not isinstance(func_files, list): if not isinstance(func_files, list):
func_files = [func_files] func_files = [func_files]
self.func_files = [self.get_path(f) for f in func_files]
self.func_name = func_name self.func_name = func_name
# Keep the original name. If we reload old pickle, we want to
self.load_c_code() # find the new path and new version of the file in Theano.
self.func_files = func_files
self.load_c_code(func_files)
if len(self.code_sections) == 0: if len(self.code_sections) == 0:
raise ValueError("No sections where defined in C files") raise ValueError("No sections where defined in C files")
...@@ -1288,12 +1289,13 @@ class COp(Op): ...@@ -1288,12 +1289,13 @@ class COp(Op):
raise ValueError('Cannot have an "op_code_cleanup" section ' raise ValueError('Cannot have an "op_code_cleanup" section '
'and specify the func_name') 'and specify the func_name')
def load_c_code(self): def load_c_code(self, func_files):
""" """
Loads the c code to perform the Op Loads the c code to perform the Op
""" """
func_files = [self.get_path(f) for f in func_files]
self.func_codes = [] self.func_codes = []
for func_file in self.func_files: for func_file in func_files:
with open(func_file, 'r') as f: with open(func_file, 'r') as f:
self.func_codes.append(f.read()) self.func_codes.append(f.read())
...@@ -1336,7 +1338,7 @@ class COp(Op): ...@@ -1336,7 +1338,7 @@ class COp(Op):
if split[0].strip() != '': if split[0].strip() != '':
raise ValueError('Stray code before first #section ' raise ValueError('Stray code before first #section '
'statement (in file %s): %s' % 'statement (in file %s): %s' %
(self.func_files[i], split[0])) (func_files[i], split[0]))
# Separate the code into the proper sections # Separate the code into the proper sections
n = 1 n = 1
...@@ -1344,7 +1346,7 @@ class COp(Op): ...@@ -1344,7 +1346,7 @@ class COp(Op):
if split[n] not in self.SECTIONS: if split[n] not in self.SECTIONS:
raise ValueError( raise ValueError(
"Unknown section type (in file %s): %s" % "Unknown section type (in file %s): %s" %
(self.func_files[i], split[n])) (func_files[i], split[n]))
if split[n] not in self.code_sections: if split[n] not in self.code_sections:
self.code_sections[split[n]] = "" self.code_sections[split[n]] = ""
self.code_sections[split[n]] += split[n + 1] self.code_sections[split[n]] += split[n + 1]
...@@ -1352,7 +1354,7 @@ class COp(Op): ...@@ -1352,7 +1354,7 @@ class COp(Op):
else: else:
raise ValueError("No valid section marker was found in file " raise ValueError("No valid section marker was found in file "
"%s" % self.func_files[i]) "%s" % func_files[i])
def get_op_params(self): def get_op_params(self):
""" """
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论