提交 fe15cbbb authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Replace Text with str and fix scope of variables in ExternalCOp

上级 ce1fdfdf
......@@ -13,7 +13,6 @@ from typing import (
Optional,
Pattern,
Set,
Text,
Tuple,
Union,
cast,
......@@ -230,7 +229,7 @@ int main( int argc, const char* argv[] )
self.update_self_openmp()
def lquote_macro(txt: Text) -> Text:
def lquote_macro(txt: str) -> str:
"""Turn the last line of text into a ``\\``-commented line."""
res = []
spl = txt.split("\n")
......@@ -240,7 +239,7 @@ def lquote_macro(txt: Text) -> Text:
return "\n".join(res)
def get_sub_macros(sub: Dict[Text, Text]) -> Union[Tuple[Text], Tuple[Text, Text]]:
def get_sub_macros(sub: Dict[str, str]) -> Union[Tuple[str], Tuple[str, str]]:
define_macros = []
undef_macros = []
define_macros.append(f"#define FAIL {lquote_macro(sub['fail'])}")
......@@ -253,8 +252,8 @@ def get_sub_macros(sub: Dict[Text, Text]) -> Union[Tuple[Text], Tuple[Text, Text
def get_io_macros(
inputs: List[Text], outputs: List[Text]
) -> Union[Tuple[List[Text]], Tuple[str, str]]:
inputs: List[str], outputs: List[str]
) -> Union[Tuple[List[str]], Tuple[str, str]]:
define_macros = []
undef_macros = []
......@@ -285,7 +284,7 @@ class ExternalCOp(COp):
r"^AESARA_(APPLY|SUPPORT)_CODE_SECTION$", re.MULTILINE
)
# This is the set of allowed markers
SECTIONS: ClassVar[Set[Text]] = {
SECTIONS: ClassVar[Set[str]] = {
"init_code",
"init_code_apply",
"init_code_struct",
......@@ -296,9 +295,11 @@ class ExternalCOp(COp):
"code",
"code_cleanup",
}
_cop_num_inputs: Optional[int] = None
_cop_num_outputs: Optional[int] = None
@classmethod
def get_path(cls, f: Text) -> Text:
def get_path(cls, f: str) -> str:
"""Convert a path relative to the location of the class file into an absolute path.
Paths that are already absolute are passed through unchanged.
......@@ -311,7 +312,7 @@ class ExternalCOp(COp):
return f
def __init__(
self, func_files: Union[Text, List[Text]], func_name: Optional[Text] = None
self, func_files: Union[str, List[str]], func_name: Optional[str] = None
):
"""
Sections are loaded from files in order with sections in later
......@@ -319,36 +320,37 @@ class ExternalCOp(COp):
"""
if not isinstance(func_files, list):
func_files = [func_files]
self.func_files = [func_files]
else:
self.func_files = func_files
self.func_name = func_name
self.func_codes: List[str] = []
# Keep the original name. If we reload old pickle, we want to
# find the new path and new version of the file in Aesara.
self.func_files = func_files
self.load_c_code(func_files)
self.func_name = func_name
self.code_sections: Dict[str, str] = dict()
self.load_c_code(self.func_files)
if len(self.code_sections) == 0:
raise ValueError("No sections where defined in C files")
raise ValueError("No sections where defined in the C files")
if self.func_name is not None:
if "op_code" in self.code_sections:
# maybe a warning instead (and clearing the key)
raise ValueError(
'Cannot have an "op_code" section and ' "specify the func_name"
"Cannot have an `op_code` section and specify `func_name`"
)
if "op_code_cleanup" in self.code_sections:
# maybe a warning instead (and clearing the key)
raise ValueError(
'Cannot have an "op_code_cleanup" section '
"and specify the func_name"
"Cannot have an `op_code_cleanup` section and specify `func_name`"
)
def load_c_code(self, func_files: List[Text]) -> None:
def load_c_code(self, func_files: List[str]) -> None:
"""Loads the C code to perform the `Op`."""
func_files = [self.get_path(f) for f in func_files]
self.func_codes = []
for func_file in func_files:
# U (universal) will convert all new lines format to \n.
with open(func_file) as f:
self.func_codes.append(f.read())
......@@ -370,7 +372,6 @@ class ExternalCOp(COp):
"be used at the same time."
)
self.code_sections = dict()
for i, code in enumerate(self.func_codes):
if self.backward_re.search(code):
# This is backward compat code that will go away in a while
......@@ -502,24 +503,35 @@ class ExternalCOp(COp):
else:
return super().c_cleanup_code_struct(node, name)
def format_c_function_args(self, inp: List[Text], out: List[Text]) -> Text:
def format_c_function_args(self, inp: List[str], out: List[str]) -> str:
"""Generate a string containing the arguments sent to the external C function.
The result will have the format: ``"input0, input1, input2, &output0, &output1"``.
"""
inp = list(inp)
numi = getattr(self, "_cop_num_inputs", len(inp))
if self._cop_num_inputs is not None:
numi = self._cop_num_inputs
else:
numi = len(inp)
while len(inp) < numi:
inp.append("NULL")
out = [f"&{o}" for o in out]
numo = getattr(self, "_cop_num_outputs", len(out))
if self._cop_num_outputs is not None:
numo = self._cop_num_outputs
else:
numo = len(out)
while len(out) < numo:
out.append("NULL")
return ", ".join(inp + out)
def get_c_macros(
self, node: Apply, name: Text, check_input: Optional[bool] = None
self, node: Apply, name: str, check_input: Optional[bool] = None
) -> Union[Tuple[str], Tuple[str, str]]:
"Construct a pair of C ``#define`` and ``#undef`` code strings."
define_template = "#define %s %s"
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论