提交 fe15cbbb authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Replace Text with str and fix scope of variables in ExternalCOp

上级 ce1fdfdf
...@@ -13,7 +13,6 @@ from typing import ( ...@@ -13,7 +13,6 @@ from typing import (
Optional, Optional,
Pattern, Pattern,
Set, Set,
Text,
Tuple, Tuple,
Union, Union,
cast, cast,
...@@ -230,7 +229,7 @@ int main( int argc, const char* argv[] ) ...@@ -230,7 +229,7 @@ int main( int argc, const char* argv[] )
self.update_self_openmp() self.update_self_openmp()
def lquote_macro(txt: Text) -> Text: def lquote_macro(txt: str) -> str:
"""Turn the last line of text into a ``\\``-commented line.""" """Turn the last line of text into a ``\\``-commented line."""
res = [] res = []
spl = txt.split("\n") spl = txt.split("\n")
...@@ -240,7 +239,7 @@ def lquote_macro(txt: Text) -> Text: ...@@ -240,7 +239,7 @@ def lquote_macro(txt: Text) -> Text:
return "\n".join(res) return "\n".join(res)
def get_sub_macros(sub: Dict[Text, Text]) -> Union[Tuple[Text], Tuple[Text, Text]]: def get_sub_macros(sub: Dict[str, str]) -> Union[Tuple[str], Tuple[str, str]]:
define_macros = [] define_macros = []
undef_macros = [] undef_macros = []
define_macros.append(f"#define FAIL {lquote_macro(sub['fail'])}") define_macros.append(f"#define FAIL {lquote_macro(sub['fail'])}")
...@@ -253,8 +252,8 @@ def get_sub_macros(sub: Dict[Text, Text]) -> Union[Tuple[Text], Tuple[Text, Text ...@@ -253,8 +252,8 @@ def get_sub_macros(sub: Dict[Text, Text]) -> Union[Tuple[Text], Tuple[Text, Text
def get_io_macros( def get_io_macros(
inputs: List[Text], outputs: List[Text] inputs: List[str], outputs: List[str]
) -> Union[Tuple[List[Text]], Tuple[str, str]]: ) -> Union[Tuple[List[str]], Tuple[str, str]]:
define_macros = [] define_macros = []
undef_macros = [] undef_macros = []
...@@ -285,7 +284,7 @@ class ExternalCOp(COp): ...@@ -285,7 +284,7 @@ class ExternalCOp(COp):
r"^AESARA_(APPLY|SUPPORT)_CODE_SECTION$", re.MULTILINE r"^AESARA_(APPLY|SUPPORT)_CODE_SECTION$", re.MULTILINE
) )
# This is the set of allowed markers # This is the set of allowed markers
SECTIONS: ClassVar[Set[Text]] = { SECTIONS: ClassVar[Set[str]] = {
"init_code", "init_code",
"init_code_apply", "init_code_apply",
"init_code_struct", "init_code_struct",
...@@ -296,9 +295,11 @@ class ExternalCOp(COp): ...@@ -296,9 +295,11 @@ class ExternalCOp(COp):
"code", "code",
"code_cleanup", "code_cleanup",
} }
_cop_num_inputs: Optional[int] = None
_cop_num_outputs: Optional[int] = None
@classmethod @classmethod
def get_path(cls, f: Text) -> Text: def get_path(cls, f: str) -> str:
"""Convert a path relative to the location of the class file into an absolute path. """Convert a path relative to the location of the class file into an absolute path.
Paths that are already absolute are passed through unchanged. Paths that are already absolute are passed through unchanged.
...@@ -311,7 +312,7 @@ class ExternalCOp(COp): ...@@ -311,7 +312,7 @@ class ExternalCOp(COp):
return f return f
def __init__( def __init__(
self, func_files: Union[Text, List[Text]], func_name: Optional[Text] = None self, func_files: Union[str, List[str]], func_name: Optional[str] = None
): ):
""" """
Sections are loaded from files in order with sections in later Sections are loaded from files in order with sections in later
...@@ -319,36 +320,37 @@ class ExternalCOp(COp): ...@@ -319,36 +320,37 @@ class ExternalCOp(COp):
""" """
if not isinstance(func_files, list): if not isinstance(func_files, list):
func_files = [func_files] self.func_files = [func_files]
else:
self.func_files = func_files
self.func_name = func_name self.func_codes: List[str] = []
# Keep the original name. If we reload old pickle, we want to # Keep the original name. If we reload old pickle, we want to
# find the new path and new version of the file in Aesara. # find the new path and new version of the file in Aesara.
self.func_files = func_files self.func_name = func_name
self.load_c_code(func_files) self.code_sections: Dict[str, str] = dict()
self.load_c_code(self.func_files)
if len(self.code_sections) == 0: if len(self.code_sections) == 0:
raise ValueError("No sections where defined in C files") raise ValueError("No sections where defined in the C files")
if self.func_name is not None: if self.func_name is not None:
if "op_code" in self.code_sections: if "op_code" in self.code_sections:
# maybe a warning instead (and clearing the key) # maybe a warning instead (and clearing the key)
raise ValueError( raise ValueError(
'Cannot have an "op_code" section and ' "specify the func_name" "Cannot have an `op_code` section and specify `func_name`"
) )
if "op_code_cleanup" in self.code_sections: if "op_code_cleanup" in self.code_sections:
# maybe a warning instead (and clearing the key) # maybe a warning instead (and clearing the key)
raise ValueError( raise ValueError(
'Cannot have an "op_code_cleanup" section ' "Cannot have an `op_code_cleanup` section and specify `func_name`"
"and specify the func_name"
) )
def load_c_code(self, func_files: List[Text]) -> None: def load_c_code(self, func_files: List[str]) -> None:
"""Loads the C code to perform the `Op`.""" """Loads the C code to perform the `Op`."""
func_files = [self.get_path(f) for f in func_files] func_files = [self.get_path(f) for f in func_files]
self.func_codes = []
for func_file in func_files: for func_file in func_files:
# U (universal) will convert all new lines format to \n.
with open(func_file) as f: with open(func_file) as f:
self.func_codes.append(f.read()) self.func_codes.append(f.read())
...@@ -370,7 +372,6 @@ class ExternalCOp(COp): ...@@ -370,7 +372,6 @@ class ExternalCOp(COp):
"be used at the same time." "be used at the same time."
) )
self.code_sections = dict()
for i, code in enumerate(self.func_codes): for i, code in enumerate(self.func_codes):
if self.backward_re.search(code): if self.backward_re.search(code):
# This is backward compat code that will go away in a while # This is backward compat code that will go away in a while
...@@ -502,24 +503,35 @@ class ExternalCOp(COp): ...@@ -502,24 +503,35 @@ class ExternalCOp(COp):
else: else:
return super().c_cleanup_code_struct(node, name) return super().c_cleanup_code_struct(node, name)
def format_c_function_args(self, inp: List[Text], out: List[Text]) -> Text: def format_c_function_args(self, inp: List[str], out: List[str]) -> str:
"""Generate a string containing the arguments sent to the external C function. """Generate a string containing the arguments sent to the external C function.
The result will have the format: ``"input0, input1, input2, &output0, &output1"``. The result will have the format: ``"input0, input1, input2, &output0, &output1"``.
""" """
inp = list(inp) inp = list(inp)
numi = getattr(self, "_cop_num_inputs", len(inp)) if self._cop_num_inputs is not None:
numi = self._cop_num_inputs
else:
numi = len(inp)
while len(inp) < numi: while len(inp) < numi:
inp.append("NULL") inp.append("NULL")
out = [f"&{o}" for o in out] out = [f"&{o}" for o in out]
numo = getattr(self, "_cop_num_outputs", len(out))
if self._cop_num_outputs is not None:
numo = self._cop_num_outputs
else:
numo = len(out)
while len(out) < numo: while len(out) < numo:
out.append("NULL") out.append("NULL")
return ", ".join(inp + out) return ", ".join(inp + out)
def get_c_macros( def get_c_macros(
self, node: Apply, name: Text, check_input: Optional[bool] = None self, node: Apply, name: str, check_input: Optional[bool] = None
) -> Union[Tuple[str], Tuple[str, str]]: ) -> Union[Tuple[str], Tuple[str, str]]:
"Construct a pair of C ``#define`` and ``#undef`` code strings." "Construct a pair of C ``#define`` and ``#undef`` code strings."
define_template = "#define %s %s" define_template = "#define %s %s"
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论