提交 55b2f4fa authored 作者: Virgile Andreani's avatar Virgile Andreani 提交者: Virgile Andreani

Replace str.format with f-strings

上级 e2f9cb8e
...@@ -434,7 +434,7 @@ class OpFromGraph(Op, HasInnerGraph): ...@@ -434,7 +434,7 @@ class OpFromGraph(Op, HasInnerGraph):
def __str__(self): def __str__(self):
name = self.__class__.__name__ if self.name is None else self.name name = self.__class__.__name__ if self.name is None else self.name
is_inline = self.is_inline is_inline = self.is_inline
return "{name}{{inline={is_inline}}}".format(**locals()) return f"{name}{{inline={is_inline}}}"
def _combine_list_overrides(self, default_outs, custom_outs, callable_args): def _combine_list_overrides(self, default_outs, custom_outs, callable_args):
"""Combines default and custom overrides into a single list of outputs.""" """Combines default and custom overrides into a single list of outputs."""
......
...@@ -1890,11 +1890,7 @@ def get_info_on_inputs(named_inputs, n_unnamed_inputs): ...@@ -1890,11 +1890,7 @@ def get_info_on_inputs(named_inputs, n_unnamed_inputs):
) )
else: else:
if n_unnamed_inputs == 0: if n_unnamed_inputs == 0:
msg = "The function has {} named input{} ({}).".format( msg = f"The function has {n_named_inputs} named input{get_plural(n_named_inputs)} ({', '.join(named_inputs)})."
n_named_inputs,
get_plural(n_named_inputs),
", ".join(named_inputs),
)
else: else:
msg = ( msg = (
f"The function has {n_named_inputs} named input{get_plural(n_named_inputs)} ({', '.join(named_inputs)}), and {n_unnamed_inputs} unnamed " f"The function has {n_named_inputs} named input{get_plural(n_named_inputs)} ({', '.join(named_inputs)}), and {n_unnamed_inputs} unnamed "
......
...@@ -1664,15 +1664,12 @@ class PatternNodeRewriter(NodeRewriter): ...@@ -1664,15 +1664,12 @@ class PatternNodeRewriter(NodeRewriter):
def pattern_to_str(pattern): def pattern_to_str(pattern):
if isinstance(pattern, list | tuple): if isinstance(pattern, list | tuple):
return "{}({})".format( args = ", ".join(pattern_to_str(p) for p in pattern[1:])
str(pattern[0]), return f"{pattern[0]!s}({args})"
", ".join(pattern_to_str(p) for p in pattern[1:]),
)
elif isinstance(pattern, dict): elif isinstance(pattern, dict):
return "{} subject to {}".format( a = pattern_to_str(pattern["pattern"])
pattern_to_str(pattern["pattern"]), b = pattern.get("constraint", "no conditions")
str(pattern.get("constraint", "no conditions")), return f"{a} subject to {b}"
)
else: else:
return str(pattern) return str(pattern)
......
...@@ -245,10 +245,9 @@ class MetaType(ABCMeta): ...@@ -245,10 +245,9 @@ class MetaType(ABCMeta):
else: else:
def __str__(self): def __str__(self):
return "{}{{{}}}".format( classname = self.__class__.__name__
self.__class__.__name__, args = ", ".join(f"{p}={getattr(self, p)!r}" for p in props)
", ".join(f"{p}={getattr(self, p)!r}" for p in props), return f"{classname}{{{args}}}"
)
dct["__str__"] = __str__ dct["__str__"] = __str__
......
...@@ -219,7 +219,6 @@ def struct_gen(args, struct_builders, blocks, sub): ...@@ -219,7 +219,6 @@ def struct_gen(args, struct_builders, blocks, sub):
""" """
struct_decl = "" struct_decl = ""
struct_init_head = "" struct_init_head = ""
struct_init_tail = ""
struct_cleanup = "" struct_cleanup = ""
for block in struct_builders: for block in struct_builders:
...@@ -243,7 +242,6 @@ def struct_gen(args, struct_builders, blocks, sub): ...@@ -243,7 +242,6 @@ def struct_gen(args, struct_builders, blocks, sub):
# decrements the storage's refcount in the destructor # decrements the storage's refcount in the destructor
storage_decref = "\n".join(f"Py_XDECREF(this->{arg});" for arg in args) storage_decref = "\n".join(f"Py_XDECREF(this->{arg});" for arg in args)
args_names = ", ".join(args)
args_decl = ", ".join(f"PyObject* {arg}" for arg in args) args_decl = ", ".join(f"PyObject* {arg}" for arg in args)
# The following code stores the exception data in __ERROR, which # The following code stores the exception data in __ERROR, which
...@@ -251,7 +249,8 @@ def struct_gen(args, struct_builders, blocks, sub): ...@@ -251,7 +249,8 @@ def struct_gen(args, struct_builders, blocks, sub):
# that holds the type, the value and the traceback. After storing # that holds the type, the value and the traceback. After storing
# the error, we return the failure code so we know which code # the error, we return the failure code so we know which code
# block failed. # block failed.
do_return = """ failure_var = sub["failure_var"]
do_return = f"""
if ({failure_var}) {{ if ({failure_var}) {{
// When there is a failure, this code puts the exception // When there is a failure, this code puts the exception
// in __ERROR. // in __ERROR.
...@@ -274,15 +273,13 @@ def struct_gen(args, struct_builders, blocks, sub): ...@@ -274,15 +273,13 @@ def struct_gen(args, struct_builders, blocks, sub):
}} }}
// The failure code is returned to index what code block failed. // The failure code is returned to index what code block failed.
return {failure_var}; return {failure_var};
""".format(**sub) """
sub = dict(sub)
sub.update(locals())
# TODO: add some error checking to make sure storage_<x> are # TODO: add some error checking to make sure storage_<x> are
# 1-element lists and __ERROR is a 3-elements list. # 1-element lists and __ERROR is a 3-elements list.
struct_code = """ name = sub["name"]
struct_code = f"""
namespace {{ namespace {{
struct {name} {{ struct {name} {{
PyObject* __ERROR; PyObject* __ERROR;
...@@ -326,7 +323,7 @@ def struct_gen(args, struct_builders, blocks, sub): ...@@ -326,7 +323,7 @@ def struct_gen(args, struct_builders, blocks, sub):
}} }}
}}; }};
}} }}
""".format(**sub) """
return struct_code return struct_code
...@@ -370,13 +367,10 @@ def get_c_init(fgraph, r, name, sub): ...@@ -370,13 +367,10 @@ def get_c_init(fgraph, r, name, sub):
Wrapper around c_init that initializes py_name to Py_None. Wrapper around c_init that initializes py_name to Py_None.
""" """
pre = ( pre = f"""
""
"""
py_{name} = Py_None; py_{name} = Py_None;
{{Py_XINCREF(py_{name});}} {{Py_XINCREF(py_{name});}}
""".format(**locals()) """
)
return pre + r.type.c_init(name, sub) return pre + r.type.c_init(name, sub)
...@@ -410,10 +404,10 @@ def get_c_extract(fgraph, r, name, sub): ...@@ -410,10 +404,10 @@ def get_c_extract(fgraph, r, name, sub):
else: else:
c_extract = r.type.c_extract(name, sub, False) c_extract = r.type.c_extract(name, sub, False)
pre = """ pre = f"""
py_{name} = PyList_GET_ITEM(storage_{name}, 0); py_{name} = PyList_GET_ITEM(storage_{name}, 0);
{{Py_XINCREF(py_{name});}} {{Py_XINCREF(py_{name});}}
""".format(**locals()) """
return pre + c_extract return pre + c_extract
...@@ -439,10 +433,10 @@ def get_c_extract_out(fgraph, r, name, sub): ...@@ -439,10 +433,10 @@ def get_c_extract_out(fgraph, r, name, sub):
else: else:
c_extract = r.type.c_extract_out(name, sub, check_input, check_broadcast=False) c_extract = r.type.c_extract_out(name, sub, check_input, check_broadcast=False)
pre = """ pre = f"""
py_{name} = PyList_GET_ITEM(storage_{name}, 0); py_{name} = PyList_GET_ITEM(storage_{name}, 0);
{{Py_XINCREF(py_{name});}} {{Py_XINCREF(py_{name});}}
""".format(**locals()) """
return pre + c_extract return pre + c_extract
...@@ -451,9 +445,9 @@ def get_c_cleanup(fgraph, r, name, sub): ...@@ -451,9 +445,9 @@ def get_c_cleanup(fgraph, r, name, sub):
Wrapper around c_cleanup that decrefs py_name. Wrapper around c_cleanup that decrefs py_name.
""" """
post = """ post = f"""
{{Py_XDECREF(py_{name});}} {{Py_XDECREF(py_{name});}}
""".format(**locals()) """
return r.type.c_cleanup(name, sub) + post return r.type.c_cleanup(name, sub) + post
...@@ -462,7 +456,9 @@ def get_c_sync(fgraph, r, name, sub): ...@@ -462,7 +456,9 @@ def get_c_sync(fgraph, r, name, sub):
Wrapper around c_sync that syncs py_name with storage. Wrapper around c_sync that syncs py_name with storage.
""" """
return """ failure_var = sub["failure_var"]
sync = r.type.c_sync(name, sub)
return f"""
if (!{failure_var}) {{ if (!{failure_var}) {{
{sync} {sync}
PyObject* old = PyList_GET_ITEM(storage_{name}, 0); PyObject* old = PyList_GET_ITEM(storage_{name}, 0);
...@@ -470,7 +466,7 @@ def get_c_sync(fgraph, r, name, sub): ...@@ -470,7 +466,7 @@ def get_c_sync(fgraph, r, name, sub):
PyList_SET_ITEM(storage_{name}, 0, py_{name}); PyList_SET_ITEM(storage_{name}, 0, py_{name});
{{Py_XDECREF(old);}} {{Py_XDECREF(old);}}
}} }}
""".format(**dict(sync=r.type.c_sync(name, sub), name=name, **sub)) """
def apply_policy(fgraph, policy, r, name, sub): def apply_policy(fgraph, policy, r, name, sub):
......
...@@ -1966,14 +1966,14 @@ class Compiler: ...@@ -1966,14 +1966,14 @@ class Compiler:
return False return False
code = ( code = (
""" f"""
{preamble} {preamble}
int main(int argc, char** argv) int main(int argc, char** argv)
{{ {{
{body} {body}
return 0; return 0;
}} }}
""".format(**locals()) """
).encode() ).encode()
return cls._try_compile_tmp( return cls._try_compile_tmp(
code, code,
......
...@@ -558,7 +558,9 @@ class CLinkerType(CLinkerObject): ...@@ -558,7 +558,9 @@ class CLinkerType(CLinkerObject):
uninitialized. uninitialized.
""" """
return """ c_init_code = self.c_init(name, sub)
c_extract_code = self.c_extract(name, sub, check_input)
return f"""
if (py_{name} == Py_None) if (py_{name} == Py_None)
{{ {{
{c_init_code} {c_init_code}
...@@ -567,13 +569,7 @@ class CLinkerType(CLinkerObject): ...@@ -567,13 +569,7 @@ class CLinkerType(CLinkerObject):
{{ {{
{c_extract_code} {c_extract_code}
}} }}
""".format( """
**dict(
name=name,
c_init_code=self.c_init(name, sub),
c_extract_code=self.c_extract(name, sub, check_input),
)
)
def c_cleanup(self, name: str, sub: dict[str, str]) -> str: def c_cleanup(self, name: str, sub: dict[str, str]) -> str:
"""Return C code to clean up after :meth:`CLinkerType.c_extract`. """Return C code to clean up after :meth:`CLinkerType.c_extract`.
......
...@@ -587,24 +587,15 @@ class ExternalCOp(COp): ...@@ -587,24 +587,15 @@ class ExternalCOp(COp):
params = f", {sub['params']}" params = f", {sub['params']}"
# Generate the C code # Generate the C code
return """ return f"""
{define_macros} {define_macros}
{{ {{
if ({func_name}({func_args}{params}) != 0) {{ if ({self.func_name}({self.format_c_function_args(inp, out)}{params}) != 0) {{
{fail} {sub['fail']}
}} }}
}} }}
{undef_macros} {undef_macros}
""".format( """
**dict(
func_name=self.func_name,
fail=sub["fail"],
params=params,
func_args=self.format_c_function_args(inp, out),
define_macros=define_macros,
undef_macros=undef_macros,
)
)
else: else:
if "code" in self.code_sections: if "code" in self.code_sections:
op_code = self.code_sections["code"] op_code = self.code_sections["code"]
......
...@@ -262,9 +262,10 @@ class Params(dict): ...@@ -262,9 +262,10 @@ class Params(dict):
self.__dict__.update(__params_type__=params_type, __signatures__=None) self.__dict__.update(__params_type__=params_type, __signatures__=None)
def __repr__(self): def __repr__(self):
return "Params({})".format( args = ", ".join(
", ".join((f"{k}:{type(self[k]).__name__}:{self[k]}") for k in sorted(self)) (f"{k}:{type(self[k]).__name__}:{self[k]}") for k in sorted(self)
) )
return f"Params({args})"
def __getattr__(self, key): def __getattr__(self, key):
if key not in self: if key not in self:
...@@ -422,9 +423,10 @@ class ParamsType(CType): ...@@ -422,9 +423,10 @@ class ParamsType(CType):
return super().__getattr__(self, key) return super().__getattr__(self, key)
def __repr__(self): def __repr__(self):
return "ParamsType<{}>".format( args = ", ".join(
", ".join((f"{self.fields[i]}:{self.types[i]}") for i in range(self.length)) f"{self.fields[i]}:{self.types[i]}" for i in range(self.length)
) )
return f"ParamsType<{args}>"
def __eq__(self, other): def __eq__(self, other):
return ( return (
...@@ -730,11 +732,15 @@ class ParamsType(CType): ...@@ -730,11 +732,15 @@ class ParamsType(CType):
struct_init = "\n".join(c_init_list) struct_init = "\n".join(c_init_list)
struct_cleanup = "\n".join(c_cleanup_list) struct_cleanup = "\n".join(c_cleanup_list)
struct_extract = "\n\n".join(c_extract_list) struct_extract = "\n\n".join(c_extract_list)
struct_extract_method = """ args = "\n".join(
f"case {i}: extract_{self.fields[i]}(object); break;"
for i in range(self.length)
)
struct_extract_method = f"""
void extract(PyObject* object, int field_pos) {{ void extract(PyObject* object, int field_pos) {{
switch(field_pos) {{ switch(field_pos) {{
// Extraction cases. // Extraction cases.
{} {args}
// Default case. // Default case.
default: default:
PyErr_Format(PyExc_TypeError, "ParamsType: no extraction defined for a field %d.", field_pos); PyErr_Format(PyExc_TypeError, "ParamsType: no extraction defined for a field %d.", field_pos);
...@@ -742,13 +748,8 @@ class ParamsType(CType): ...@@ -742,13 +748,8 @@ class ParamsType(CType):
break; break;
}} }}
}} }}
""".format( """
"\n".join( final_struct_code = f"""
("case %d: extract_%s(object); break;" % (i, self.fields[i]))
for i in range(self.length)
)
)
final_struct_code = """
/** ParamsType {struct_name} **/ /** ParamsType {struct_name} **/
#ifndef {struct_name_defined} #ifndef {struct_name_defined}
#define {struct_name_defined} #define {struct_name_defined}
...@@ -790,17 +791,7 @@ class ParamsType(CType): ...@@ -790,17 +791,7 @@ class ParamsType(CType):
}}; }};
#endif #endif
/** End ParamsType {struct_name} **/ /** End ParamsType {struct_name} **/
""".format( """
**dict(
struct_name_defined=struct_name_defined,
struct_name=struct_name,
struct_declare=struct_declare,
struct_init=struct_init,
struct_cleanup=struct_cleanup,
struct_extract=struct_extract,
struct_extract_method=struct_extract_method,
)
)
return [*sorted(c_support_code_set), final_struct_code] return [*sorted(c_support_code_set), final_struct_code]
...@@ -813,9 +804,9 @@ class ParamsType(CType): ...@@ -813,9 +804,9 @@ class ParamsType(CType):
# pointers. # pointers.
def c_declare(self, name, sub, check_input=True): def c_declare(self, name, sub, check_input=True):
return """ return f"""
{struct_name}* {name}; {self.name}* {name};
""".format(**dict(struct_name=self.name, name=name)) """
def c_init(self, name, sub): def c_init(self, name, sub):
# NB: It seems c_init() is not called for an op param. # NB: It seems c_init() is not called for an op param.
...@@ -831,40 +822,33 @@ class ParamsType(CType): ...@@ -831,40 +822,33 @@ class ParamsType(CType):
""" """
def c_extract(self, name, sub, check_input=True, **kwargs): def c_extract(self, name, sub, check_input=True, **kwargs):
return """ fields_list = ", ".join(f'"{x}"' for x in self.fields)
return f"""
/* Seems c_init() is not called for a op param. So I call `new` here. */ /* Seems c_init() is not called for a op param. So I call `new` here. */
{name} = new {struct_name}; {name} = new {self.name};
{{ // This need a separate namespace for Clinker {{ // This need a separate namespace for Clinker
const char* fields[] = {{{fields_list}}}; const char* fields[] = {{{fields_list}}};
if (py_{name} == Py_None) {{ if (py_{name} == Py_None) {{
PyErr_SetString(PyExc_ValueError, "ParamsType: expected an object, not None."); PyErr_SetString(PyExc_ValueError, "ParamsType: expected an object, not None.");
{fail} {sub['fail']}
}} }}
for (int i = 0; i < {length}; ++i) {{ for (int i = 0; i < {self.length}; ++i) {{
PyObject* o = PyDict_GetItemString(py_{name}, fields[i]); PyObject* o = PyDict_GetItemString(py_{name}, fields[i]);
if (o == NULL) {{ if (o == NULL) {{
PyErr_Format(PyExc_TypeError, "ParamsType: missing expected attribute \\"%s\\" in object.", fields[i]); PyErr_Format(PyExc_TypeError, "ParamsType: missing expected attribute \\"%s\\" in object.", fields[i]);
{fail} {sub['fail']}
}} }}
{name}->extract(o, i); {name}->extract(o, i);
if ({name}->errorOccurred()) {{ if ({name}->errorOccurred()) {{
/* The extract code from attribute type should have already raised a Python exception, /* The extract code from attribute type should have already raised a Python exception,
* so we just print the attribute name in stderr. */ * so we just print the attribute name in stderr. */
fprintf(stderr, "\\nParamsType: error when extracting value for attribute \\"%s\\".\\n", fields[i]); fprintf(stderr, "\\nParamsType: error when extracting value for attribute \\"%s\\".\\n", fields[i]);
{fail} {sub['fail']}
}} }}
}} }}
}} }}
""".format( """
**dict(
name=name,
struct_name=self.name,
length=self.length,
fail=sub["fail"],
fields_list='"{}"'.format('", "'.join(self.fields)),
)
)
def c_sync(self, name, sub): def c_sync(self, name, sub):
# FIXME: Looks like we need to decrement a reference count our two. # FIXME: Looks like we need to decrement a reference count our two.
......
...@@ -98,12 +98,12 @@ class Generic(CType, Singleton): ...@@ -98,12 +98,12 @@ class Generic(CType, Singleton):
""" """
def c_sync(self, name, sub): def c_sync(self, name, sub):
return """ return f"""
assert(py_{name}->ob_refcnt > 1); assert(py_{name}->ob_refcnt > 1);
Py_DECREF(py_{name}); Py_DECREF(py_{name});
py_{name} = {name} ? {name} : Py_None; py_{name} = {name} ? {name} : Py_None;
Py_INCREF(py_{name}); Py_INCREF(py_{name});
""".format(**locals()) """
def c_code_cache_version(self): def c_code_cache_version(self):
return (1,) return (1,)
...@@ -190,18 +190,19 @@ class CDataType(CType[D]): ...@@ -190,18 +190,19 @@ class CDataType(CType[D]):
return data return data
def c_declare(self, name, sub, check_input=True): def c_declare(self, name, sub, check_input=True):
return """ return f"""
{ctype} {name}; {self.ctype} {name};
""".format(**dict(ctype=self.ctype, name=name)) """
def c_init(self, name, sub): def c_init(self, name, sub):
return f"{name} = NULL;" return f"{name} = NULL;"
def c_extract(self, name, sub, check_input=True, **kwargs): def c_extract(self, name, sub, check_input=True, **kwargs):
return """ fail = sub["fail"]
{name} = ({ctype})PyCapsule_GetPointer(py_{name}, NULL); return f"""
{name} = ({self.ctype})PyCapsule_GetPointer(py_{name}, NULL);
if ({name} == NULL) {fail} if ({name} == NULL) {fail}
""".format(**dict(name=name, ctype=self.ctype, fail=sub["fail"])) """
def c_sync(self, name, sub): def c_sync(self, name, sub):
freefunc = self.freefunc freefunc = self.freefunc
...@@ -478,11 +479,8 @@ class EnumType(CType, dict): ...@@ -478,11 +479,8 @@ class EnumType(CType, dict):
names_to_aliases = {constant_name: "" for constant_name in self} names_to_aliases = {constant_name: "" for constant_name in self}
for alias in self.aliases: for alias in self.aliases:
names_to_aliases[self.aliases[alias]] = f"({alias})" names_to_aliases[self.aliases[alias]] = f"({alias})"
return "{}<{}>({})".format( args = ", ".join(f"{k}{names_to_aliases[k]}:{self[k]}" for k in sorted(self))
type(self).__name__, return f"{type(self).__name__}<{self.ctype}>({args})"
self.ctype,
", ".join(f"{k}{names_to_aliases[k]}:{self[k]}" for k in sorted(self)),
)
def __getattr__(self, key): def __getattr__(self, key):
if key in self: if key in self:
...@@ -575,33 +573,27 @@ class EnumType(CType, dict): ...@@ -575,33 +573,27 @@ class EnumType(CType, dict):
This C function may be useful to retrieve some runtime information. This C function may be useful to retrieve some runtime information.
It is available in C code when pytensor flag ``config.cmodule__debug`` is set to ``True``. It is available in C code when pytensor flag ``config.cmodule__debug`` is set to ``True``.
""" """
return """ cases = "".join(
f"""
case {name}: sprintf(out, "{name}"); break;
"""
for name in self
)
return f"""
#ifdef DEBUG #ifdef DEBUG
int pytensor_enum_to_string_{cname}({ctype} in, char* out) {{ int pytensor_enum_to_string_{self.cname}({self.ctype} in, char* out) {{
int ret = 0; int ret = 0;
switch(in) {{ switch(in) {{
{cases} {cases}
default: default:
PyErr_SetString(PyExc_ValueError, "{classname}: unknown enum value."); PyErr_SetString(PyExc_ValueError, "{type(self).__name__}: unknown enum value.");
ret = -1; ret = -1;
break; break;
}} }}
return ret; return ret;
}} }}
#endif #endif
""".format( """
**dict(
cname=self.cname,
ctype=self.ctype,
classname=type(self).__name__,
cases="".join(
"""
case {name}: sprintf(out, "{name}"); break;
""".format(**dict(name=name))
for name in self
),
)
)
def c_support_code(self, **kwargs): def c_support_code(self, **kwargs):
return ( return (
...@@ -625,16 +617,17 @@ class EnumType(CType, dict): ...@@ -625,16 +617,17 @@ class EnumType(CType, dict):
return "" return ""
def c_extract(self, name, sub, check_input=True, **kwargs): def c_extract(self, name, sub, check_input=True, **kwargs):
return """ fail = sub["fail"]
return f"""
if (PyInt_Check(py_{name})) {{ if (PyInt_Check(py_{name})) {{
{name} = ({ctype})PyInt_AsLong(py_{name}); {name} = ({self.ctype})PyInt_AsLong(py_{name});
}} else {{ }} else {{
{name} = ({ctype})PyFloat_AsDouble(py_{name}); {name} = ({self.ctype})PyFloat_AsDouble(py_{name});
}} }}
if (PyErr_Occurred()) {{ if (PyErr_Occurred()) {{
{fail} {fail}
}} }}
""".format(**dict(ctype=self.ctype, name=name, fail=sub["fail"])) """
def c_code_cache_version(self): def c_code_cache_version(self):
return (2, self.ctype, self.cname, tuple(self.items())) return (2, self.ctype, self.cname, tuple(self.items()))
...@@ -754,7 +747,14 @@ class CEnumType(EnumList): ...@@ -754,7 +747,14 @@ class CEnumType(EnumList):
swapped_dict = {v: k for (k, v) in self.items()} swapped_dict = {v: k for (k, v) in self.items()}
# swapped_dict's keys are integers. # swapped_dict's keys are integers.
return """ fail = sub["fail"]
cases = "".join(
f"""
case {i}: {name} = {swapped_dict[i]}; break;
"""
for i in sorted(swapped_dict)
)
return f"""
switch(PyInt_AsLong(py_{name})) {{ switch(PyInt_AsLong(py_{name})) {{
{cases} {cases}
default: default:
...@@ -762,19 +762,7 @@ class CEnumType(EnumList): ...@@ -762,19 +762,7 @@ class CEnumType(EnumList):
{{{fail}}} {{{fail}}}
break; break;
}} }}
""".format( """
**dict(
name=name,
cases="".join(
"""
case %(i)d: %(name)s = %(constant_cname)s; break;
"""
% dict(i=i, name=name, constant_cname=swapped_dict[i])
for i in sorted(swapped_dict)
),
fail=sub["fail"],
)
)
def c_code_cache_version(self): def c_code_cache_version(self):
return (1, super().c_code_cache_version()) return (1, super().c_code_cache_version())
...@@ -136,20 +136,19 @@ def _check_scipy_linalg_matrix(a, func_name): ...@@ -136,20 +136,19 @@ def _check_scipy_linalg_matrix(a, func_name):
Adapted from https://github.com/numba/numba/blob/bd7ebcfd4b850208b627a3f75d4706000be36275/numba/np/linalg.py#L831 Adapted from https://github.com/numba/numba/blob/bd7ebcfd4b850208b627a3f75d4706000be36275/numba/np/linalg.py#L831
""" """
prefix = "scipy.linalg" prefix = "scipy.linalg"
interp = (prefix, func_name)
# Unpack optional type # Unpack optional type
if isinstance(a, types.Optional): if isinstance(a, types.Optional):
a = a.type a = a.type
if not isinstance(a, types.Array): if not isinstance(a, types.Array):
msg = "{}.{}() only supported for array types".format(*interp) msg = f"{prefix}.{func_name}() only supported for array types"
raise numba.TypingError(msg, highlighting=False) raise numba.TypingError(msg, highlighting=False)
if a.ndim not in [1, 2]: if a.ndim not in [1, 2]:
msg = "{}.{}() only supported on 1d or 2d arrays, found {}.".format( msg = (
*interp, a.ndim f"{prefix}.{func_name}() only supported on 1d or 2d arrays, found {a.ndim}."
) )
raise numba.TypingError(msg, highlighting=False) raise numba.TypingError(msg, highlighting=False)
if not isinstance(a.dtype, types.Float | types.Complex): if not isinstance(a.dtype, types.Float | types.Complex):
msg = "{}.{}() only supported on " "float and complex arrays.".format(*interp) msg = "{prefix}.{func_name}() only supported on float and complex arrays."
raise numba.TypingError(msg, highlighting=False) raise numba.TypingError(msg, highlighting=False)
......
...@@ -355,14 +355,10 @@ def raise_with_op( ...@@ -355,14 +355,10 @@ def raise_with_op(
+ f"\nInputs values: {scalar_values}" + f"\nInputs values: {scalar_values}"
) )
if verbosity == "high": if verbosity == "high":
detailed_err_msg += "\nInputs type_num: {}".format( inpts = [
str( getattr(getattr(i[0], "dtype", ""), "num", "") for i in thunk.inputs
[ ]
getattr(getattr(i[0], "dtype", ""), "num", "") detailed_err_msg += f"\nInputs type_num: {inpts}"
for i in thunk.inputs
]
)
)
detailed_err_msg += f"\nOutputs clients: {clients}\n" detailed_err_msg += f"\nOutputs clients: {clients}\n"
else: else:
......
...@@ -1048,10 +1048,8 @@ class DefaultPrinter(Printer): ...@@ -1048,10 +1048,8 @@ class DefaultPrinter(Printer):
if node is None: if node is None:
return leaf_printer.process(output, pstate) return leaf_printer.process(output, pstate)
with set_precedence(pstate): with set_precedence(pstate):
r = "{}({})".format( args = ", ".join(pprinter.process(input, pstate) for input in node.inputs)
str(node.op), r = f"{node.op}({args})"
", ".join(pprinter.process(input, pstate) for input in node.inputs),
)
pstate.memo[output] = r pstate.memo[output] = r
return r return r
......
...@@ -464,33 +464,32 @@ class ScalarType(CType, HasDataType, HasShape): ...@@ -464,33 +464,32 @@ class ScalarType(CType, HasDataType, HasShape):
raise NotImplementedError("float16") raise NotImplementedError("float16")
specs = self.dtype_specs() specs = self.dtype_specs()
if check_input: if check_input:
pre = """ fail = sub["fail"]
dtype = specs[1]
pyarr_type = f"Py{specs[2]}ArrType_Type"
pre = f"""
if (!PyObject_TypeCheck(py_{name}, &{pyarr_type})) if (!PyObject_TypeCheck(py_{name}, &{pyarr_type}))
{{ {{
PyErr_Format(PyExc_ValueError, PyErr_Format(PyExc_ValueError,
"Scalar check failed ({dtype})"); "Scalar check failed ({dtype})");
{fail} {fail}
}} }}
""".format( """
**dict(
sub,
name=name,
dtype=specs[1],
pyarr_type=f"Py{specs[2]}ArrType_Type",
)
)
else: else:
pre = "" pre = ""
return ( return (
pre pre
+ """ + f"""
PyArray_ScalarAsCtype(py_{name}, &{name}); PyArray_ScalarAsCtype(py_{name}, &{name});
""".format(**dict(sub, name=name)) """
) )
def c_sync(self, name, sub): def c_sync(self, name, sub):
specs = self.dtype_specs() specs = self.dtype_specs()
return """ fail = sub["fail"]
dtype = specs[1]
cls = specs[2]
return f"""
Py_XDECREF(py_{name}); Py_XDECREF(py_{name});
py_{name} = PyArrayScalar_New({cls}); py_{name} = PyArrayScalar_New({cls});
if (!py_{name}) if (!py_{name})
...@@ -502,7 +501,7 @@ class ScalarType(CType, HasDataType, HasShape): ...@@ -502,7 +501,7 @@ class ScalarType(CType, HasDataType, HasShape):
{fail} {fail}
}} }}
PyArrayScalar_ASSIGN(py_{name}, {cls}, {name}); PyArrayScalar_ASSIGN(py_{name}, {cls}, {name});
""".format(**dict(sub, name=name, dtype=specs[1], cls=specs[2])) """
def c_cleanup(self, name, sub): def c_cleanup(self, name, sub):
return "" return ""
...@@ -1179,10 +1178,9 @@ class ScalarOp(COp): ...@@ -1179,10 +1178,9 @@ class ScalarOp(COp):
not in ("name", "_op_use_c_code", "bool", "output_types_preference") not in ("name", "_op_use_c_code", "bool", "output_types_preference")
] ]
if param: if param:
return "{}{{{}}}".format( classname = self.__class__.__name__
self.__class__.__name__, args = ", ".join(f"{k}={v}" for k, v in param)
", ".join(f"{k}={v}" for k, v in param), return f"{classname}{{{args}}}"
)
else: else:
return self.__class__.__name__ return self.__class__.__name__
......
...@@ -615,8 +615,8 @@ class Chi2SF(BinaryScalarOp): ...@@ -615,8 +615,8 @@ class Chi2SF(BinaryScalarOp):
(z,) = out (z,) = out
if node.inputs[0].type in float_types: if node.inputs[0].type in float_types:
dtype = "npy_" + node.outputs[0].dtype dtype = "npy_" + node.outputs[0].dtype
return """{z} = return f"""{z} =
({dtype}) 1 - GammaP({k}/2., {x}/2.);""".format(**locals()) ({dtype}) 1 - GammaP({k}/2., {x}/2.);"""
raise NotImplementedError("only floatingpoint is implemented") raise NotImplementedError("only floatingpoint is implemented")
def __eq__(self, other): def __eq__(self, other):
...@@ -666,8 +666,8 @@ class GammaInc(BinaryScalarOp): ...@@ -666,8 +666,8 @@ class GammaInc(BinaryScalarOp):
(z,) = out (z,) = out
if node.inputs[0].type in float_types: if node.inputs[0].type in float_types:
dtype = "npy_" + node.outputs[0].dtype dtype = "npy_" + node.outputs[0].dtype
return """{z} = return f"""{z} =
({dtype}) GammaP({k}, {x});""".format(**locals()) ({dtype}) GammaP({k}, {x});"""
raise NotImplementedError("only floatingpoint is implemented") raise NotImplementedError("only floatingpoint is implemented")
def __eq__(self, other): def __eq__(self, other):
...@@ -717,8 +717,8 @@ class GammaIncC(BinaryScalarOp): ...@@ -717,8 +717,8 @@ class GammaIncC(BinaryScalarOp):
(z,) = out (z,) = out
if node.inputs[0].type in float_types: if node.inputs[0].type in float_types:
dtype = "npy_" + node.outputs[0].dtype dtype = "npy_" + node.outputs[0].dtype
return """{z} = return f"""{z} =
({dtype}) GammaQ({k}, {x});""".format(**locals()) ({dtype}) GammaQ({k}, {x});"""
raise NotImplementedError("only floatingpoint is implemented") raise NotImplementedError("only floatingpoint is implemented")
def __eq__(self, other): def __eq__(self, other):
...@@ -1028,8 +1028,8 @@ class GammaU(BinaryScalarOp): ...@@ -1028,8 +1028,8 @@ class GammaU(BinaryScalarOp):
(z,) = out (z,) = out
if node.inputs[0].type in float_types: if node.inputs[0].type in float_types:
dtype = "npy_" + node.outputs[0].dtype dtype = "npy_" + node.outputs[0].dtype
return """{z} = return f"""{z} =
({dtype}) upperGamma({k}, {x});""".format(**locals()) ({dtype}) upperGamma({k}, {x});"""
raise NotImplementedError("only floatingpoint is implemented") raise NotImplementedError("only floatingpoint is implemented")
def __eq__(self, other): def __eq__(self, other):
...@@ -1064,8 +1064,8 @@ class GammaL(BinaryScalarOp): ...@@ -1064,8 +1064,8 @@ class GammaL(BinaryScalarOp):
(z,) = out (z,) = out
if node.inputs[0].type in float_types: if node.inputs[0].type in float_types:
dtype = "npy_" + node.outputs[0].dtype dtype = "npy_" + node.outputs[0].dtype
return """{z} = return f"""{z} =
({dtype}) lowerGamma({k}, {x});""".format(**locals()) ({dtype}) lowerGamma({k}, {x});"""
raise NotImplementedError("only floatingpoint is implemented") raise NotImplementedError("only floatingpoint is implemented")
def __eq__(self, other): def __eq__(self, other):
......
...@@ -1072,7 +1072,8 @@ class ScanArgs: ...@@ -1072,7 +1072,8 @@ class ScanArgs:
for p in self.field_names for p in self.field_names
if p.startswith("outer_out") if p.startswith("outer_out")
] ]
res = "ScanArgs(\n{})".format(",\n".join(inner_arg_strs)) args = ",\n".join(inner_arg_strs)
res = f"ScanArgs(\n{args})"
return res return res
def __repr__(self): def __repr__(self):
......
...@@ -3617,7 +3617,8 @@ class StructuredDotGradCSC(COp): ...@@ -3617,7 +3617,8 @@ class StructuredDotGradCSC(COp):
if node.inputs[3].type.dtype in ("complex64", "complex128"): if node.inputs[3].type.dtype in ("complex64", "complex128"):
raise NotImplementedError("Complex types are not supported for g_ab") raise NotImplementedError("Complex types are not supported for g_ab")
return """ fail = sub["fail"]
return f"""
if (PyArray_NDIM({_d}) != 2) {{PyErr_SetString(PyExc_NotImplementedError, "rank(d) != 2"); {fail};}} if (PyArray_NDIM({_d}) != 2) {{PyErr_SetString(PyExc_NotImplementedError, "rank(d) != 2"); {fail};}}
if (PyArray_NDIM({_g}) != 2) {{PyErr_SetString(PyExc_NotImplementedError, "rank(g) != 2"); {fail};}} if (PyArray_NDIM({_g}) != 2) {{PyErr_SetString(PyExc_NotImplementedError, "rank(g) != 2"); {fail};}}
if (PyArray_NDIM({_indices}) != 1) {{PyErr_SetString(PyExc_NotImplementedError, "rank(indices) != 1"); {fail};}} if (PyArray_NDIM({_indices}) != 1) {{PyErr_SetString(PyExc_NotImplementedError, "rank(indices) != 1"); {fail};}}
...@@ -3689,7 +3690,7 @@ class StructuredDotGradCSC(COp): ...@@ -3689,7 +3690,7 @@ class StructuredDotGradCSC(COp):
}} }}
}} }}
""".format(**dict(locals(), **sub)) """
def infer_shape(self, fgraph, node, shapes): def infer_shape(self, fgraph, node, shapes):
return [shapes[0]] return [shapes[0]]
...@@ -3750,7 +3751,8 @@ class StructuredDotGradCSR(COp): ...@@ -3750,7 +3751,8 @@ class StructuredDotGradCSR(COp):
if node.inputs[3].type.dtype in ("complex64", "complex128"): if node.inputs[3].type.dtype in ("complex64", "complex128"):
raise NotImplementedError("Complex types are not supported for g_ab") raise NotImplementedError("Complex types are not supported for g_ab")
return """ fail = sub["fail"]
return f"""
if (PyArray_NDIM({_d}) != 2) {{PyErr_SetString(PyExc_NotImplementedError, "rank(d) != 2"); {fail};}} if (PyArray_NDIM({_d}) != 2) {{PyErr_SetString(PyExc_NotImplementedError, "rank(d) != 2"); {fail};}}
if (PyArray_NDIM({_g}) != 2) {{PyErr_SetString(PyExc_NotImplementedError, "rank(g) != 2"); {fail};}} if (PyArray_NDIM({_g}) != 2) {{PyErr_SetString(PyExc_NotImplementedError, "rank(g) != 2"); {fail};}}
if (PyArray_NDIM({_indices}) != 1) {{PyErr_SetString(PyExc_NotImplementedError, "rank(indices) != 1"); {fail};}} if (PyArray_NDIM({_indices}) != 1) {{PyErr_SetString(PyExc_NotImplementedError, "rank(indices) != 1"); {fail};}}
...@@ -3823,7 +3825,7 @@ class StructuredDotGradCSR(COp): ...@@ -3823,7 +3825,7 @@ class StructuredDotGradCSR(COp):
}} }}
}} }}
""".format(**dict(locals(), **sub)) """
def infer_shape(self, fgraph, node, shapes): def infer_shape(self, fgraph, node, shapes):
return [shapes[0]] return [shapes[0]]
......
...@@ -137,7 +137,7 @@ class AddSD_ccode(_NoPythonCOp): ...@@ -137,7 +137,7 @@ class AddSD_ccode(_NoPythonCOp):
inplace = int(self.inplace) inplace = int(self.inplace)
format = {"csc": 0, "csr": 1}[self.format] format = {"csc": 0, "csr": 1}[self.format]
out_typenum = node.outputs[0].type.dtype_specs()[2] out_typenum = node.outputs[0].type.dtype_specs()[2]
code = """ code = f"""
Py_XDECREF({z}); Py_XDECREF({z});
if (!{inplace}){{ if (!{inplace}){{
if(PyArray_TYPE({y}) != {out_typenum}){{ if(PyArray_TYPE({y}) != {out_typenum}){{
...@@ -179,7 +179,7 @@ class AddSD_ccode(_NoPythonCOp): ...@@ -179,7 +179,7 @@ class AddSD_ccode(_NoPythonCOp):
}} }}
}} }}
}} }}
""".format(**dict(locals(), **sub)) """
return code return code
def infer_shape(self, fgraph, node, shapes): def infer_shape(self, fgraph, node, shapes):
...@@ -310,7 +310,8 @@ class StructuredDotCSC(COp): ...@@ -310,7 +310,8 @@ class StructuredDotCSC(COp):
typenum_a_val = node.inputs[0].type.dtype_specs()[2] # retrieve dtype number typenum_a_val = node.inputs[0].type.dtype_specs()[2] # retrieve dtype number
typenum_b = node.inputs[4].type.dtype_specs()[2] # retrieve dtype number typenum_b = node.inputs[4].type.dtype_specs()[2] # retrieve dtype number
rval = """ fail = sub["fail"]
rval = f"""
if (PyArray_NDIM({a_val}) != 1) {{PyErr_SetString(PyExc_NotImplementedError, "rank(a_val) != 1"); {fail};}} if (PyArray_NDIM({a_val}) != 1) {{PyErr_SetString(PyExc_NotImplementedError, "rank(a_val) != 1"); {fail};}}
if (PyArray_NDIM({a_ind}) != 1) {{PyErr_SetString(PyExc_NotImplementedError, "rank(a_ind) != 1"); {fail};}} if (PyArray_NDIM({a_ind}) != 1) {{PyErr_SetString(PyExc_NotImplementedError, "rank(a_ind) != 1"); {fail};}}
...@@ -430,7 +431,7 @@ class StructuredDotCSC(COp): ...@@ -430,7 +431,7 @@ class StructuredDotCSC(COp):
}} }}
}} }}
}} }}
""".format(**dict(locals(), **sub)) """
return rval return rval
...@@ -517,7 +518,8 @@ class StructuredDotCSR(COp): ...@@ -517,7 +518,8 @@ class StructuredDotCSR(COp):
if node.inputs[3].type.dtype in ("complex64", "complex128"): if node.inputs[3].type.dtype in ("complex64", "complex128"):
raise NotImplementedError("Complex types are not supported for b") raise NotImplementedError("Complex types are not supported for b")
return """ fail = sub["fail"]
return f"""
if (PyArray_NDIM({a_val}) != 1) {{PyErr_SetString(PyExc_NotImplementedError, "rank(a_val) != 1"); {fail};}} if (PyArray_NDIM({a_val}) != 1) {{PyErr_SetString(PyExc_NotImplementedError, "rank(a_val) != 1"); {fail};}}
if (PyArray_NDIM({a_ind}) != 1) {{PyErr_SetString(PyExc_NotImplementedError, "rank(a_ind) != 1"); {fail};}} if (PyArray_NDIM({a_ind}) != 1) {{PyErr_SetString(PyExc_NotImplementedError, "rank(a_ind) != 1"); {fail};}}
if (PyArray_NDIM({a_ptr}) != 1) {{PyErr_SetString(PyExc_NotImplementedError, "rank(a_ptr) != 1"); {fail};}} if (PyArray_NDIM({a_ptr}) != 1) {{PyErr_SetString(PyExc_NotImplementedError, "rank(a_ptr) != 1"); {fail};}}
...@@ -609,7 +611,7 @@ class StructuredDotCSR(COp): ...@@ -609,7 +611,7 @@ class StructuredDotCSR(COp):
}} }}
}} }}
""".format(**dict(locals(), **sub)) """
def c_code_cache_version(self): def c_code_cache_version(self):
return (2,) return (2,)
...@@ -756,7 +758,8 @@ class UsmmCscDense(_NoPythonCOp): ...@@ -756,7 +758,8 @@ class UsmmCscDense(_NoPythonCOp):
inplace = int(self.inplace) inplace = int(self.inplace)
rval = """ fail = sub["fail"]
rval = f"""
if (PyArray_NDIM({x_val}) != 1) {{PyErr_SetString(PyExc_NotImplementedError, "rank(x_val) != 1"); {fail};}} if (PyArray_NDIM({x_val}) != 1) {{PyErr_SetString(PyExc_NotImplementedError, "rank(x_val) != 1"); {fail};}}
if (PyArray_NDIM({x_ind}) != 1) {{PyErr_SetString(PyExc_NotImplementedError, "rank(x_ind) != 1"); {fail};}} if (PyArray_NDIM({x_ind}) != 1) {{PyErr_SetString(PyExc_NotImplementedError, "rank(x_ind) != 1"); {fail};}}
...@@ -888,7 +891,7 @@ class UsmmCscDense(_NoPythonCOp): ...@@ -888,7 +891,7 @@ class UsmmCscDense(_NoPythonCOp):
}} }}
}} }}
}} }}
""".format(**dict(locals(), **sub)) """
return rval return rval
...@@ -985,7 +988,8 @@ class CSMGradC(_NoPythonCOp): ...@@ -985,7 +988,8 @@ class CSMGradC(_NoPythonCOp):
if node.inputs[3].type.dtype in ("complex64", "complex128"): if node.inputs[3].type.dtype in ("complex64", "complex128"):
raise NotImplementedError("Complex types are not supported for b_val") raise NotImplementedError("Complex types are not supported for b_val")
return """ fail = sub["fail"]
return f"""
if (PyArray_NDIM({a_val}) != 1) {{PyErr_SetString(PyExc_NotImplementedError, "rank(a_val) != 1"); {fail};}} if (PyArray_NDIM({a_val}) != 1) {{PyErr_SetString(PyExc_NotImplementedError, "rank(a_val) != 1"); {fail};}}
if (PyArray_NDIM({a_ind}) != 1) {{PyErr_SetString(PyExc_NotImplementedError, "rank(a_ind) != 1"); {fail};}} if (PyArray_NDIM({a_ind}) != 1) {{PyErr_SetString(PyExc_NotImplementedError, "rank(a_ind) != 1"); {fail};}}
if (PyArray_NDIM({a_ptr}) != 1) {{PyErr_SetString(PyExc_NotImplementedError, "rank(a_ptr) != 1"); {fail};}} if (PyArray_NDIM({a_ptr}) != 1) {{PyErr_SetString(PyExc_NotImplementedError, "rank(a_ptr) != 1"); {fail};}}
...@@ -1079,7 +1083,7 @@ class CSMGradC(_NoPythonCOp): ...@@ -1079,7 +1083,7 @@ class CSMGradC(_NoPythonCOp):
}} }}
}} }}
""".format(**dict(locals(), **sub)) """
def c_code_cache_version(self): def c_code_cache_version(self):
return (3,) return (3,)
...@@ -1165,7 +1169,8 @@ class MulSDCSC(_NoPythonCOp): ...@@ -1165,7 +1169,8 @@ class MulSDCSC(_NoPythonCOp):
if node.inputs[3].type.dtype in ("complex64", "complex128"): if node.inputs[3].type.dtype in ("complex64", "complex128"):
raise NotImplementedError("Complex types are not supported for b") raise NotImplementedError("Complex types are not supported for b")
return """ fail = sub["fail"]
return f"""
if (PyArray_NDIM({_b}) != 2) {{ if (PyArray_NDIM({_b}) != 2) {{
PyErr_SetString(PyExc_NotImplementedError, "rank(b) != 2"); PyErr_SetString(PyExc_NotImplementedError, "rank(b) != 2");
{fail};}} {fail};}}
...@@ -1231,7 +1236,7 @@ class MulSDCSC(_NoPythonCOp): ...@@ -1231,7 +1236,7 @@ class MulSDCSC(_NoPythonCOp):
}} }}
}} }}
""".format(**dict(locals(), **sub)) """
def __str__(self): def __str__(self):
return self.__class__.__name__ return self.__class__.__name__
...@@ -1302,7 +1307,8 @@ class MulSDCSR(_NoPythonCOp): ...@@ -1302,7 +1307,8 @@ class MulSDCSR(_NoPythonCOp):
if node.inputs[3].type.dtype in ("complex64", "complex128"): if node.inputs[3].type.dtype in ("complex64", "complex128"):
raise NotImplementedError("Complex types are not supported for b") raise NotImplementedError("Complex types are not supported for b")
return """ fail = sub["fail"]
return f"""
if (PyArray_NDIM({_b}) != 2) {{ if (PyArray_NDIM({_b}) != 2) {{
PyErr_SetString(PyExc_NotImplementedError, "rank(b) != 2"); PyErr_SetString(PyExc_NotImplementedError, "rank(b) != 2");
{fail};}} {fail};}}
...@@ -1368,7 +1374,7 @@ class MulSDCSR(_NoPythonCOp): ...@@ -1368,7 +1374,7 @@ class MulSDCSR(_NoPythonCOp):
}} }}
}} }}
""".format(**dict(locals(), **sub)) """
def __str__(self): def __str__(self):
return self.__class__.__name__ return self.__class__.__name__
...@@ -1491,7 +1497,8 @@ class MulSVCSR(_NoPythonCOp): ...@@ -1491,7 +1497,8 @@ class MulSVCSR(_NoPythonCOp):
if node.inputs[3].type.dtype in ("complex64", "complex128"): if node.inputs[3].type.dtype in ("complex64", "complex128"):
raise NotImplementedError("Complex types are not supported for b") raise NotImplementedError("Complex types are not supported for b")
return """ fail = sub["fail"]
return f"""
if (PyArray_NDIM({_b}) != 1) {{ if (PyArray_NDIM({_b}) != 1) {{
PyErr_SetString(PyExc_NotImplementedError, "rank(b) != 1"); PyErr_SetString(PyExc_NotImplementedError, "rank(b) != 1");
{fail}; {fail};
...@@ -1553,7 +1560,7 @@ class MulSVCSR(_NoPythonCOp): ...@@ -1553,7 +1560,7 @@ class MulSVCSR(_NoPythonCOp):
}} }}
}} }}
""".format(**dict(locals(), **sub)) """
def __str__(self): def __str__(self):
return self.__class__.__name__ return self.__class__.__name__
...@@ -1663,7 +1670,8 @@ class StructuredAddSVCSR(_NoPythonCOp): ...@@ -1663,7 +1670,8 @@ class StructuredAddSVCSR(_NoPythonCOp):
if node.inputs[3].type.dtype in ("complex64", "complex128"): if node.inputs[3].type.dtype in ("complex64", "complex128"):
raise NotImplementedError("Complex types are not supported for b") raise NotImplementedError("Complex types are not supported for b")
return """ fail = sub["fail"]
return f"""
if (PyArray_NDIM({_b}) != 1) {{ if (PyArray_NDIM({_b}) != 1) {{
PyErr_SetString(PyExc_NotImplementedError, "rank(b) != 1"); PyErr_SetString(PyExc_NotImplementedError, "rank(b) != 1");
{fail}; {fail};
...@@ -1732,7 +1740,7 @@ class StructuredAddSVCSR(_NoPythonCOp): ...@@ -1732,7 +1740,7 @@ class StructuredAddSVCSR(_NoPythonCOp):
}} }}
}} }}
""".format(**dict(locals(), **sub)) """
def __str__(self): def __str__(self):
return self.__class__.__name__ return self.__class__.__name__
...@@ -1904,7 +1912,8 @@ class SamplingDotCSR(_NoPythonCOp): ...@@ -1904,7 +1912,8 @@ class SamplingDotCSR(_NoPythonCOp):
typenum_zi = TensorType(node.outputs[1].dtype, []).dtype_specs()[2] typenum_zi = TensorType(node.outputs[1].dtype, []).dtype_specs()[2]
typenum_zp = TensorType(node.outputs[2].dtype, []).dtype_specs()[2] typenum_zp = TensorType(node.outputs[2].dtype, []).dtype_specs()[2]
rval = """ fail = sub["fail"]
rval = f"""
if (PyArray_NDIM({x}) != 2) {{ if (PyArray_NDIM({x}) != 2) {{
PyErr_SetString(PyExc_NotImplementedError, "rank(x) != 2"); {fail};}} PyErr_SetString(PyExc_NotImplementedError, "rank(x) != 2"); {fail};}}
if (PyArray_NDIM({y}) != 2) {{ if (PyArray_NDIM({y}) != 2) {{
...@@ -2024,7 +2033,7 @@ PyErr_SetString(PyExc_NotImplementedError, "rank(y) != 2"); {fail};}} ...@@ -2024,7 +2033,7 @@ PyErr_SetString(PyExc_NotImplementedError, "rank(y) != 2"); {fail};}}
}} }}
}} }}
}} }}
""".format(**dict(locals(), **sub)) """
return rval return rval
......
...@@ -597,12 +597,12 @@ class TensorFromScalar(COp): ...@@ -597,12 +597,12 @@ class TensorFromScalar(COp):
(z,) = outputs (z,) = outputs
fail = sub["fail"] fail = sub["fail"]
return """ return f"""
{z} = (PyArrayObject*)PyArray_FromScalar(py_{x}, NULL); {z} = (PyArrayObject*)PyArray_FromScalar(py_{x}, NULL);
if({z} == NULL){{ if({z} == NULL){{
{fail}; {fail};
}} }}
""".format(**locals()) """
def c_code_cache_version(self): def c_code_cache_version(self):
return (2,) return (2,)
...@@ -646,10 +646,9 @@ class ScalarFromTensor(COp): ...@@ -646,10 +646,9 @@ class ScalarFromTensor(COp):
def c_code(self, node, name, inputs, outputs, sub): def c_code(self, node, name, inputs, outputs, sub):
(x,) = inputs (x,) = inputs
(z,) = outputs (z,) = outputs
fail = sub["fail"] return f"""
return """
{z} = ((dtype_{x}*)(PyArray_DATA({x})))[0]; {z} = ((dtype_{x}*)(PyArray_DATA({x})))[0];
""".format(**locals()) """
def c_code_cache_version(self): def c_code_cache_version(self):
return (1,) return (1,)
...@@ -1836,18 +1835,18 @@ class MakeVector(COp): ...@@ -1836,18 +1835,18 @@ class MakeVector(COp):
assert self.dtype == node.inputs[0].dtype assert self.dtype == node.inputs[0].dtype
out_num = f"PyArray_TYPE({inp[0]})" out_num = f"PyArray_TYPE({inp[0]})"
ret = """ ret = f"""
npy_intp dims[1]; npy_intp dims[1];
dims[0] = {out_shape}; dims[0] = {out_shape};
if(!{out} || PyArray_DIMS({out})[0] != {out_shape}){{ if(!{out} || PyArray_DIMS({out})[0] != {out_shape}){{
Py_XDECREF({out}); Py_XDECREF({out});
{out} = (PyArrayObject*)PyArray_EMPTY(1, dims, {out_num}, 0); {out} = (PyArrayObject*)PyArray_EMPTY(1, dims, {out_num}, 0);
}} }}
""".format(**locals()) """
for idx, i in enumerate(inp): for idx, i in enumerate(inp):
ret += """ ret += f"""
*(({out_dtype} *)PyArray_GETPTR1({out}, {idx})) = *(({out_dtype} *) PyArray_DATA({i})); *(({out_dtype} *)PyArray_GETPTR1({out}, {idx})) = *(({out_dtype} *) PyArray_DATA({i}));
""".format(**locals()) """
return ret return ret
def infer_shape(self, fgraph, node, ishapes): def infer_shape(self, fgraph, node, ishapes):
...@@ -2225,7 +2224,7 @@ class Split(COp): ...@@ -2225,7 +2224,7 @@ class Split(COp):
splits_dtype = node.inputs[2].type.dtype_specs()[1] splits_dtype = node.inputs[2].type.dtype_specs()[1]
expected_splits_count = self.len_splits expected_splits_count = self.len_splits
return """ return f"""
int ndim = PyArray_NDIM({x}); int ndim = PyArray_NDIM({x});
int axis = (int)(*({axis_dtype}*)PyArray_GETPTR1({axis}, 0)); int axis = (int)(*({axis_dtype}*)PyArray_GETPTR1({axis}, 0));
int splits_count = PyArray_DIM({splits}, 0); int splits_count = PyArray_DIM({splits}, 0);
...@@ -2322,7 +2321,7 @@ class Split(COp): ...@@ -2322,7 +2321,7 @@ class Split(COp):
}} }}
free(split_dims); free(split_dims);
""".format(**locals()) """
class Join(COp): class Join(COp):
...@@ -2373,10 +2372,9 @@ class Join(COp): ...@@ -2373,10 +2372,9 @@ class Join(COp):
if self.view == -1: if self.view == -1:
return self.__class__.__name__ return self.__class__.__name__
else: else:
return "{}{{{}}}".format( classname = self.__class__.__name__
self.__class__.__name__, args = ", ".join(f"{p}={getattr(self, p)!r}" for p in self.__props__)
", ".join(f"{p}={getattr(self, p)!r}" for p in self.__props__), return f"{classname}{{{args}}}"
)
def __setstate__(self, d): def __setstate__(self, d):
self.__dict__.update(d) self.__dict__.update(d)
...@@ -2538,7 +2536,7 @@ class Join(COp): ...@@ -2538,7 +2536,7 @@ class Join(COp):
copy_inputs_to_list = "\n".join(copy_to_list) copy_inputs_to_list = "\n".join(copy_to_list)
n = len(tens) n = len(tens)
code = """ code = f"""
int axis = (({adtype} *)PyArray_DATA({axis}))[0]; int axis = (({adtype} *)PyArray_DATA({axis}))[0];
PyObject* list = PyList_New({l}); PyObject* list = PyList_New({l});
{copy_inputs_to_list} {copy_inputs_to_list}
...@@ -2570,7 +2568,7 @@ class Join(COp): ...@@ -2570,7 +2568,7 @@ class Join(COp):
{fail} {fail}
}} }}
}} }}
""".format(**locals()) """
return code return code
def R_op(self, inputs, eval_points): def R_op(self, inputs, eval_points):
...@@ -4192,18 +4190,14 @@ class AllocEmpty(COp): ...@@ -4192,18 +4190,14 @@ class AllocEmpty(COp):
params = sub["params"] params = sub["params"]
str = f"npy_intp dims[{nd}];\n" str = f"npy_intp dims[{nd}];\n"
for idx, sh in enumerate(shps): for idx, sh in enumerate(shps):
str += ( str += f"dims[{idx}] = ((npy_intp)((dtype_{sh}*) PyArray_DATA({sh}))[0]);\n"
"dims[{idx}] ="
"((npy_intp)((dtype_{sh}*)"
" PyArray_DATA({sh}))[0]);\n".format(**locals())
)
# Validate that the output storage exists # Validate that the output storage exists
str += f"if({out}==NULL\n" str += f"if({out}==NULL\n"
for idx, sh in enumerate(shps): for idx, sh in enumerate(shps):
str += f"||PyArray_DIMS({out})[{idx}]!=dims[{idx}]" str += f"||PyArray_DIMS({out})[{idx}]!=dims[{idx}]"
str += """){{ str += f"""){{
/* Reference received to invalid output variable. /* Reference received to invalid output variable.
Decrease received reference's ref count and allocate new Decrease received reference's ref count and allocate new
output variable */ output variable */
...@@ -4218,7 +4212,7 @@ class AllocEmpty(COp): ...@@ -4218,7 +4212,7 @@ class AllocEmpty(COp):
{fail}; {fail};
}} }}
}} }}
""".format(**locals()) """
return str return str
def infer_shape(self, fgraph, node, input_shapes): def infer_shape(self, fgraph, node, input_shapes):
......
...@@ -1806,19 +1806,12 @@ class BatchedDot(COp): ...@@ -1806,19 +1806,12 @@ class BatchedDot(COp):
strides = f"PyArray_STRIDES({var})" strides = f"PyArray_STRIDES({var})"
if ndim == 1: if ndim == 1:
return f"{strides}[0] == type_size" return f"{strides}[0] == type_size"
return " && ".join( ands = " && ".join(
[ f"{strides}[{i}] > 0 && {strides}[{i}] % type_size == 0"
" && ".join( for i in range(1, ndim)
f"{strides}[{i}] > 0 && {strides}[{i}] % type_size == 0"
for i in range(1, ndim)
),
"({})".format(
" || ".join(
f"{strides}[{i}] == type_size" for i in range(1, ndim)
)
),
]
) )
ors = " || ".join(f"{strides}[{i}] == type_size" for i in range(1, ndim))
return f"{ands} && ({ors})"
x_ndim, y_ndim, z_ndim = ( x_ndim, y_ndim, z_ndim = (
node.inputs[0].ndim, node.inputs[0].ndim,
...@@ -1838,7 +1831,7 @@ class BatchedDot(COp): ...@@ -1838,7 +1831,7 @@ class BatchedDot(COp):
) )
z_shape = ", ".join(z_dims) z_shape = ", ".join(z_dims)
z_contiguous = contiguous(_z, z_ndim) z_contiguous = contiguous(_z, z_ndim)
allocate = """ allocate = f"""
if (NULL == {_z} || !({z_shape_correct}) || !({z_contiguous})) if (NULL == {_z} || !({z_shape_correct}) || !({z_contiguous}))
{{ {{
npy_intp dims[{z_ndim}] = {{{z_shape}}}; npy_intp dims[{z_ndim}] = {{{z_shape}}};
...@@ -1851,14 +1844,14 @@ class BatchedDot(COp): ...@@ -1851,14 +1844,14 @@ class BatchedDot(COp):
{fail} {fail}
}} }}
}} }}
""".format(**locals()) """
# code to reallocate inputs contiguously if necessary # code to reallocate inputs contiguously if necessary
contiguate = [] contiguate = []
for var, ndim in [(_x, x_ndim), (_y, y_ndim)]: for var, ndim in [(_x, x_ndim), (_y, y_ndim)]:
_contiguous = contiguous(var, ndim) _contiguous = contiguous(var, ndim)
contiguate.append( contiguate.append(
""" f"""
if (!({_contiguous})) {{ if (!({_contiguous})) {{
PyArrayObject * _copy = (PyArrayObject *) PyArray_Copy({var}); PyArrayObject * _copy = (PyArrayObject *) PyArray_Copy({var});
if (!_copy) if (!_copy)
...@@ -1866,11 +1859,11 @@ class BatchedDot(COp): ...@@ -1866,11 +1859,11 @@ class BatchedDot(COp):
Py_XDECREF({var}); Py_XDECREF({var});
{var} = _copy; {var} = _copy;
}} }}
""".format(**locals()) """
) )
contiguate = "\n".join(contiguate) contiguate = "\n".join(contiguate)
return """ return f"""
int type_num = PyArray_DESCR({_x})->type_num; int type_num = PyArray_DESCR({_x})->type_num;
int type_size = PyArray_DESCR({_x})->elsize; // in bytes int type_size = PyArray_DESCR({_x})->elsize; // in bytes
...@@ -1927,7 +1920,7 @@ class BatchedDot(COp): ...@@ -1927,7 +1920,7 @@ class BatchedDot(COp):
}} }}
break; break;
}} }}
""".format(**locals()) """
def c_code_cache_version(self): def c_code_cache_version(self):
from pytensor.tensor.blas_headers import blas_header_version from pytensor.tensor.blas_headers import blas_header_version
......
...@@ -33,7 +33,7 @@ class BaseBLAS(COp): ...@@ -33,7 +33,7 @@ class BaseBLAS(COp):
def ger_c_code(A, a, x, y, Z, fail, params): def ger_c_code(A, a, x, y, Z, fail, params):
return """ return f"""
int elemsize ; int elemsize ;
...@@ -309,7 +309,7 @@ def ger_c_code(A, a, x, y, Z, fail, params): ...@@ -309,7 +309,7 @@ def ger_c_code(A, a, x, y, Z, fail, params):
}} }}
}} }}
""".format(**locals()) """
class CGer(BaseBLAS, Ger): class CGer(BaseBLAS, Ger):
......
...@@ -1066,7 +1066,7 @@ def blas_header_version(): ...@@ -1066,7 +1066,7 @@ def blas_header_version():
def ____gemm_code(check_ab, a_init, b_init): def ____gemm_code(check_ab, a_init, b_init):
mod = "%" mod = "%"
return """ return f"""
const char * error_string = NULL; const char * error_string = NULL;
int type_num = PyArray_DESCR(_x)->type_num; int type_num = PyArray_DESCR(_x)->type_num;
...@@ -1203,4 +1203,4 @@ def ____gemm_code(check_ab, a_init, b_init): ...@@ -1203,4 +1203,4 @@ def ____gemm_code(check_ab, a_init, b_init):
return -1; return -1;
/* v 1 */ /* v 1 */
""".format(**locals()) """
...@@ -302,10 +302,7 @@ class DimShufflePrinter(Printer): ...@@ -302,10 +302,7 @@ class DimShufflePrinter(Printer):
return pstate.pprinter.process(r) return pstate.pprinter.process(r)
if list(new_order) == list(reversed(range(r.type.ndim))): if list(new_order) == list(reversed(range(r.type.ndim))):
return f"{pstate.pprinter.process(r)}.T" return f"{pstate.pprinter.process(r)}.T"
return "DimShuffle{{{}}}({})".format( return f"DimShuffle{{{', '.join(str(o) for o in new_order)}}}({pstate.pprinter.process(r)})"
", ".join(map(str, new_order)),
pstate.pprinter.process(r),
)
def process(self, r, pstate): def process(self, r, pstate):
if r.owner is None: if r.owner is None:
...@@ -929,13 +926,13 @@ class Elemwise(OpenMPOp): ...@@ -929,13 +926,13 @@ class Elemwise(OpenMPOp):
# We make the output point to the corresponding input and # We make the output point to the corresponding input and
# decrease the reference of whatever the output contained # decrease the reference of whatever the output contained
# prior to this # prior to this
alloc += """ alloc += f"""
if ({oname}) {{ if ({oname}) {{
Py_XDECREF({oname}); Py_XDECREF({oname});
}} }}
{oname} = {iname}; {oname} = {iname};
Py_XINCREF({oname}); Py_XINCREF({oname});
""".format(**locals()) """
# We alias the scalar variables # We alias the scalar variables
defines += f"#define {oname}_i {iname}_i\n" defines += f"#define {oname}_i {iname}_i\n"
undefs += f"#undef {oname}_i\n" undefs += f"#undef {oname}_i\n"
...@@ -958,13 +955,13 @@ class Elemwise(OpenMPOp): ...@@ -958,13 +955,13 @@ class Elemwise(OpenMPOp):
[f"{s}_i" for s in onames], [f"{s}_i" for s in onames],
dict(sub, fail=fail), dict(sub, fail=fail),
) )
code = """ code = f"""
{{ {{
{defines} {defines}
{task_code} {task_code}
{undefs} {undefs}
}} }}
""".format(**locals()) """
loop_orders = orders + [list(range(nnested))] * len(real_onames) loop_orders = orders + [list(range(nnested))] * len(real_onames)
dtypes = idtypes + list(real_odtypes) dtypes = idtypes + list(real_odtypes)
...@@ -994,19 +991,17 @@ class Elemwise(OpenMPOp): ...@@ -994,19 +991,17 @@ class Elemwise(OpenMPOp):
if index != "x": if index != "x":
preloops.setdefault(j, "") preloops.setdefault(j, "")
preloops[j] += ( preloops[j] += (
"%(lv{i})s_iter = ({dtype}*)" f"%(lv{i})s_iter = ({dtype}*)(PyArray_DATA(%(lv{i})s));\n"
"(PyArray_DATA(%(lv{i})s));\n".format(**locals())
) % sub ) % sub
break break
else: # all broadcastable else: # all broadcastable
preloops.setdefault(0, "") preloops.setdefault(0, "")
preloops[0] += ( preloops[0] += (
"%(lv{i})s_iter = ({dtype}*)" f"%(lv{i})s_iter = ({dtype}*)(PyArray_DATA(%(lv{i})s));\n"
"(PyArray_DATA(%(lv{i})s));\n".format(**locals())
) % sub ) % sub
init_array = preloops.get(0, " ") init_array = preloops.get(0, " ")
loop = """ loop = f"""
{{ {{
{defines} {defines}
{init_array} {init_array}
...@@ -1014,7 +1009,7 @@ class Elemwise(OpenMPOp): ...@@ -1014,7 +1009,7 @@ class Elemwise(OpenMPOp):
{task_code} {task_code}
{undefs} {undefs}
}} }}
""".format(**locals()) """
else: else:
loop = cgen.make_loop( loop = cgen.make_loop(
loop_orders=loop_orders, loop_orders=loop_orders,
...@@ -1074,25 +1069,25 @@ class Elemwise(OpenMPOp): ...@@ -1074,25 +1069,25 @@ class Elemwise(OpenMPOp):
index = "" index = ""
for x, var in zip(inames + onames, inputs + node.outputs): for x, var in zip(inames + onames, inputs + node.outputs):
if not all(s == 1 for s in var.type.shape): if not all(s == 1 for s in var.type.shape):
contig += """ contig += f"""
dtype_{x} * {x}_ptr = (dtype_{x}*) PyArray_DATA({x}); dtype_{x} * {x}_ptr = (dtype_{x}*) PyArray_DATA({x});
""".format(**locals()) """
index += """ index += f"""
dtype_{x}& {x}_i = {x}_ptr[i]; dtype_{x}& {x}_i = {x}_ptr[i];
""".format(**locals()) """
else: else:
contig += """ contig += f"""
dtype_{x}& {x}_i = ((dtype_{x}*) PyArray_DATA({x}))[0]; dtype_{x}& {x}_i = ((dtype_{x}*) PyArray_DATA({x}))[0];
""".format(**locals()) """
if self.openmp: if self.openmp:
contig += f"""#pragma omp parallel for if(n>={int(config.openmp_elemwise_minsize)}) contig += f"""#pragma omp parallel for if(n>={int(config.openmp_elemwise_minsize)})
""" """
contig += """ contig += f"""
for(int i=0; i<n; i++){{ for(int i=0; i<n; i++){{
{index} {index}
{task_code}; {task_code};
}} }}
""".format(**locals()) """
if contig is not None: if contig is not None:
z = list(zip(inames + onames, inputs + node.outputs)) z = list(zip(inames + onames, inputs + node.outputs))
all_broadcastable = all(s == 1 for s in var.type.shape) all_broadcastable = all(s == 1 for s in var.type.shape)
...@@ -1106,13 +1101,13 @@ class Elemwise(OpenMPOp): ...@@ -1106,13 +1101,13 @@ class Elemwise(OpenMPOp):
for arr, var in z for arr, var in z
if not all_broadcastable if not all_broadcastable
) )
loop = """ loop = f"""
if(({cond1}) || ({cond2})){{ if(({cond1}) || ({cond2})){{
{contig} {contig}
}}else{{ }}else{{
{loop} {loop}
}} }}
""".format(**locals()) """
return decl, checks, alloc, loop, "" return decl, checks, alloc, loop, ""
def c_code(self, node, nodename, inames, onames, sub): def c_code(self, node, nodename, inames, onames, sub):
......
...@@ -170,12 +170,14 @@ def make_alloc(loop_orders, dtype, sub, fortran="0"): ...@@ -170,12 +170,14 @@ def make_alloc(loop_orders, dtype, sub, fortran="0"):
type = type.replace("PYTENSOR_COMPLEX", "NPY_COMPLEX") type = type.replace("PYTENSOR_COMPLEX", "NPY_COMPLEX")
nd = len(loop_orders[0]) nd = len(loop_orders[0])
init_dims = compute_output_dims_lengths("dims", loop_orders, sub) init_dims = compute_output_dims_lengths("dims", loop_orders, sub)
olv = sub["olv"]
fail = sub["fail"]
# TODO: it would be interesting to allocate the output in such a # TODO: it would be interesting to allocate the output in such a
# way that its contiguous dimensions match one of the input's # way that its contiguous dimensions match one of the input's
# contiguous dimensions, or the dimension with the smallest # contiguous dimensions, or the dimension with the smallest
# stride. Right now, it is allocated to be C_CONTIGUOUS. # stride. Right now, it is allocated to be C_CONTIGUOUS.
return """ return f"""
{{ {{
npy_intp dims[{nd}]; npy_intp dims[{nd}];
//npy_intp* dims = (npy_intp*)malloc({nd} * sizeof(npy_intp)); //npy_intp* dims = (npy_intp*)malloc({nd} * sizeof(npy_intp));
...@@ -203,7 +205,7 @@ def make_alloc(loop_orders, dtype, sub, fortran="0"): ...@@ -203,7 +205,7 @@ def make_alloc(loop_orders, dtype, sub, fortran="0"):
{fail} {fail}
}} }}
}} }}
""".format(**dict(locals(), **sub)) """
def make_loop(loop_orders, dtypes, loop_tasks, sub, openmp=None): def make_loop(loop_orders, dtypes, loop_tasks, sub, openmp=None):
......
...@@ -68,7 +68,7 @@ class CpuContiguous(COp): ...@@ -68,7 +68,7 @@ class CpuContiguous(COp):
def c_code(self, node, name, inames, onames, sub): def c_code(self, node, name, inames, onames, sub):
(x,) = inames (x,) = inames
(y,) = onames (y,) = onames
code = """ code = f"""
if (!PyArray_CHKFLAGS({x}, NPY_ARRAY_C_CONTIGUOUS)){{ if (!PyArray_CHKFLAGS({x}, NPY_ARRAY_C_CONTIGUOUS)){{
// check to see if output is contiguous first // check to see if output is contiguous first
if ({y} != NULL && if ({y} != NULL &&
...@@ -86,7 +86,7 @@ class CpuContiguous(COp): ...@@ -86,7 +86,7 @@ class CpuContiguous(COp):
Py_XDECREF({y}); Py_XDECREF({y});
{y} = {x}; {y} = {x};
}} }}
""".format(**locals()) """
return code return code
def c_code_cache_version(self): def c_code_cache_version(self):
...@@ -161,13 +161,13 @@ class SearchsortedOp(COp): ...@@ -161,13 +161,13 @@ class SearchsortedOp(COp):
def c_init_code_struct(self, node, name, sub): def c_init_code_struct(self, node, name, sub):
side = sub["params"] side = sub["params"]
fail = sub["fail"] fail = sub["fail"]
return """ return f"""
PyObject* tmp_{name} = PyUnicode_FromString("right"); PyObject* tmp_{name} = PyUnicode_FromString("right");
if (tmp_{name} == NULL) if (tmp_{name} == NULL)
{fail}; {fail};
right_{name} = PyUnicode_Compare({side}, tmp_{name}); right_{name} = PyUnicode_Compare({side}, tmp_{name});
Py_DECREF(tmp_{name}); Py_DECREF(tmp_{name});
""".format(**locals()) """
def c_code(self, node, name, inames, onames, sub): def c_code(self, node, name, inames, onames, sub):
sorter = None sorter = None
...@@ -180,7 +180,7 @@ class SearchsortedOp(COp): ...@@ -180,7 +180,7 @@ class SearchsortedOp(COp):
(z,) = onames (z,) = onames
fail = sub["fail"] fail = sub["fail"]
return """ return f"""
Py_XDECREF({z}); Py_XDECREF({z});
{z} = (PyArrayObject*) PyArray_SearchSorted({x}, (PyObject*) {v}, {z} = (PyArrayObject*) PyArray_SearchSorted({x}, (PyObject*) {v},
right_{name} ? NPY_SEARCHLEFT : NPY_SEARCHRIGHT, (PyObject*) {sorter}); right_{name} ? NPY_SEARCHLEFT : NPY_SEARCHRIGHT, (PyObject*) {sorter});
...@@ -191,7 +191,7 @@ class SearchsortedOp(COp): ...@@ -191,7 +191,7 @@ class SearchsortedOp(COp):
Py_XDECREF({z}); Py_XDECREF({z});
{z} = (PyArrayObject*) tmp; {z} = (PyArrayObject*) tmp;
}} }}
""".format(**locals()) """
def c_code_cache_version(self): def c_code_cache_version(self):
return (2,) return (2,)
...@@ -348,11 +348,10 @@ class CumOp(COp): ...@@ -348,11 +348,10 @@ class CumOp(COp):
def c_code(self, node, name, inames, onames, sub): def c_code(self, node, name, inames, onames, sub):
(x,) = inames (x,) = inames
(z,) = onames (z,) = onames
axis = self.axis
fail = sub["fail"] fail = sub["fail"]
params = sub["params"] params = sub["params"]
code = """ code = f"""
int axis = {params}->c_axis; int axis = {params}->c_axis;
if (axis == 0 && PyArray_NDIM({x}) == 1) if (axis == 0 && PyArray_NDIM({x}) == 1)
axis = NPY_MAXDIMS; axis = NPY_MAXDIMS;
...@@ -389,7 +388,7 @@ class CumOp(COp): ...@@ -389,7 +388,7 @@ class CumOp(COp):
// Because PyArray_CumSum/CumProd returns a newly created reference on t. // Because PyArray_CumSum/CumProd returns a newly created reference on t.
Py_XDECREF(t); Py_XDECREF(t);
}} }}
""".format(**locals()) """
return code return code
......
...@@ -215,14 +215,14 @@ class Argmax(COp): ...@@ -215,14 +215,14 @@ class Argmax(COp):
if len(self.axis) != 1: if len(self.axis) != 1:
raise NotImplementedError() raise NotImplementedError()
# params is only used here for now # params is only used here for now
axis_code = """ axis_code = f"""
axis = {params}->c_axis; axis = {params}->c_axis;
if(axis > PyArray_NDIM({x})-1 || axis < -PyArray_NDIM({x})){{ if(axis > PyArray_NDIM({x})-1 || axis < -PyArray_NDIM({x})){{
PyErr_SetString(PyExc_ValueError, PyErr_SetString(PyExc_ValueError,
"Argmax, bad axis argument"); "Argmax, bad axis argument");
{fail} {fail}
}} }}
""".format(**locals()) """
ret = """ ret = """
int axis; int axis;
...@@ -1314,7 +1314,8 @@ class Mean(FixedOpCAReduce): ...@@ -1314,7 +1314,8 @@ class Mean(FixedOpCAReduce):
def __str__(self): def __str__(self):
if self.axis is not None: if self.axis is not None:
return "Mean{{{}}}".format(", ".join(str(x) for x in self.axis)) args = ", ".join(str(x) for x in self.axis)
return f"Mean{{{args}}}"
else: else:
return "Mean" return "Mean"
......
...@@ -1137,7 +1137,7 @@ class Subtensor(COp): ...@@ -1137,7 +1137,7 @@ class Subtensor(COp):
""" """
rval += """ rval += f"""
// One more argument of the view // One more argument of the view
npy_intp xview_offset = 0; npy_intp xview_offset = 0;
...@@ -1264,7 +1264,7 @@ class Subtensor(COp): ...@@ -1264,7 +1264,7 @@ class Subtensor(COp):
inner_ii += 1; inner_ii += 1;
outer_ii += 1; outer_ii += 1;
}} }}
""".format(**locals()) """
# print rval # print rval
return rval return rval
...@@ -1284,19 +1284,19 @@ class Subtensor(COp): ...@@ -1284,19 +1284,19 @@ class Subtensor(COp):
decl = "PyArrayObject * xview = NULL;" decl = "PyArrayObject * xview = NULL;"
checkNDim = """ checkNDim = f"""
if (PyArray_NDIM({x}) != {ndim}){{ if (PyArray_NDIM({x}) != {ndim}){{
PyErr_SetString(PyExc_ValueError, PyErr_SetString(PyExc_ValueError,
"Expected {ndim} dimensions input" "Expected {ndim} dimensions input"
); );
{fail} {fail}
}} }}
""".format(**locals()) """
get_xview = self.helper_c_code( get_xview = self.helper_c_code(
node, name, inputs, outputs, sub, self.idx_list, view_ndim node, name, inputs, outputs, sub, self.idx_list, view_ndim
) )
build_view = """ build_view = f"""
//TODO: give this Op a second output so that this view can be cached //TODO: give this Op a second output so that this view can be cached
//TODO: alternatively, fix the memory leak on failure //TODO: alternatively, fix the memory leak on failure
Py_INCREF(PyArray_DESCR({x})); Py_INCREF(PyArray_DESCR({x}));
...@@ -1314,7 +1314,7 @@ class Subtensor(COp): ...@@ -1314,7 +1314,7 @@ class Subtensor(COp):
{{ {{
{fail}; {fail};
}} }}
""".format(**locals()) """
finish_view = f""" finish_view = f"""
Py_XDECREF({z}); Py_XDECREF({z});
...@@ -1761,7 +1761,7 @@ class IncSubtensor(COp): ...@@ -1761,7 +1761,7 @@ class IncSubtensor(COp):
copy_of_x = self.copy_of_x(x) copy_of_x = self.copy_of_x(x)
copy_input_if_necessary = """ copy_input_if_necessary = f"""
if ({inplace}) if ({inplace})
{{ {{
if ({x} != {z}) if ({x} != {z})
...@@ -1780,7 +1780,7 @@ class IncSubtensor(COp): ...@@ -1780,7 +1780,7 @@ class IncSubtensor(COp):
{fail} {fail}
}} }}
}} }}
""".format(**locals()) """
# get info needed to make zview: a view of %(z)s # get info needed to make zview: a view of %(z)s
helper_args = self.get_helper_c_code_args() helper_args = self.get_helper_c_code_args()
...@@ -1799,7 +1799,7 @@ class IncSubtensor(COp): ...@@ -1799,7 +1799,7 @@ class IncSubtensor(COp):
# Make a view on the output, as we will write into it. # Make a view on the output, as we will write into it.
alloc_zview = self.make_view_array(z, view_ndim) alloc_zview = self.make_view_array(z, view_ndim)
build_view = """ build_view = f"""
//TODO: give this Op a second output so that this view can be cached //TODO: give this Op a second output so that this view can be cached
//TODO: alternatively, fix the memory leak on failure //TODO: alternatively, fix the memory leak on failure
{alloc_zview}; {alloc_zview};
...@@ -1807,13 +1807,13 @@ class IncSubtensor(COp): ...@@ -1807,13 +1807,13 @@ class IncSubtensor(COp):
{{ {{
{fail}; {fail};
}} }}
""".format(**locals()) """
copy_into = self.copy_into("zview", y) copy_into = self.copy_into("zview", y)
add_to_zview = self.add_to_zview(name, y, fail) add_to_zview = self.add_to_zview(name, y, fail)
make_modification = """ make_modification = f"""
if ({op_is_set}) if ({op_is_set})
{{ {{
if ({copy_into}) // does broadcasting if ({copy_into}) // does broadcasting
...@@ -1826,7 +1826,7 @@ class IncSubtensor(COp): ...@@ -1826,7 +1826,7 @@ class IncSubtensor(COp):
{{ {{
{add_to_zview} {add_to_zview}
}} }}
""".format(**locals()) """
return ( return (
self.decl_view() self.decl_view()
+ copy_input_if_necessary + copy_input_if_necessary
...@@ -1896,7 +1896,7 @@ class IncSubtensor(COp): ...@@ -1896,7 +1896,7 @@ class IncSubtensor(COp):
""" """
return """Py_INCREF(PyArray_DESCR({x})); return f"""Py_INCREF(PyArray_DESCR({x}));
zview = (PyArrayObject*)PyArray_NewFromDescr( zview = (PyArrayObject*)PyArray_NewFromDescr(
&PyArray_Type, &PyArray_Type,
PyArray_DESCR({x}), PyArray_DESCR({x}),
...@@ -1906,7 +1906,7 @@ class IncSubtensor(COp): ...@@ -1906,7 +1906,7 @@ class IncSubtensor(COp):
PyArray_BYTES({x}) + xview_offset, //PyArray_DATA({x}), PyArray_BYTES({x}) + xview_offset, //PyArray_DATA({x}),
PyArray_FLAGS({x}), PyArray_FLAGS({x}),
NULL); NULL);
""".format(**locals()) """
def get_helper_c_code_args(self): def get_helper_c_code_args(self):
""" """
...@@ -1939,7 +1939,7 @@ class IncSubtensor(COp): ...@@ -1939,7 +1939,7 @@ class IncSubtensor(COp):
""" """
return """ return f"""
PyArrayObject * add_rval = (PyArrayObject*)PyNumber_InPlaceAdd( PyArrayObject * add_rval = (PyArrayObject*)PyNumber_InPlaceAdd(
(PyObject*)zview, py_{x}); (PyObject*)zview, py_{x});
if (add_rval) if (add_rval)
...@@ -1952,7 +1952,7 @@ class IncSubtensor(COp): ...@@ -1952,7 +1952,7 @@ class IncSubtensor(COp):
{{ {{
Py_DECREF(zview); Py_DECREF(zview);
{fail}; {fail};
}}""".format(**locals()) }}"""
def infer_shape(self, fgraph, node, shapes): def infer_shape(self, fgraph, node, shapes):
return [shapes[0]] return [shapes[0]]
...@@ -2162,7 +2162,7 @@ class AdvancedSubtensor1(COp): ...@@ -2162,7 +2162,7 @@ class AdvancedSubtensor1(COp):
a_name, i_name = input_names[0], input_names[1] a_name, i_name = input_names[0], input_names[1]
output_name = output_names[0] output_name = output_names[0]
fail = sub["fail"] fail = sub["fail"]
return """ return f"""
PyArrayObject *indices; PyArrayObject *indices;
int i_type = PyArray_TYPE({i_name}); int i_type = PyArray_TYPE({i_name});
if (i_type != NPY_INTP) {{ if (i_type != NPY_INTP) {{
...@@ -2237,7 +2237,7 @@ class AdvancedSubtensor1(COp): ...@@ -2237,7 +2237,7 @@ class AdvancedSubtensor1(COp):
{a_name}, (PyObject*)indices, 0, {output_name}, NPY_RAISE); {a_name}, (PyObject*)indices, 0, {output_name}, NPY_RAISE);
Py_DECREF(indices); Py_DECREF(indices);
if ({output_name} == NULL) {fail}; if ({output_name} == NULL) {fail};
""".format(**locals()) """
def c_code_cache_version(self): def c_code_cache_version(self):
return (0, 1, 2) return (0, 1, 2)
...@@ -2523,8 +2523,10 @@ class AdvancedIncSubtensor1(COp): ...@@ -2523,8 +2523,10 @@ class AdvancedIncSubtensor1(COp):
x, y, idx = input_names x, y, idx = input_names
out = output_names[0] out = output_names[0]
copy_of_x = self.copy_of_x(x) copy_of_x = self.copy_of_x(x)
params = sub["params"]
fail = sub["fail"]
return """ return f"""
PyObject* rval = NULL; PyObject* rval = NULL;
if ({params}->inplace) if ({params}->inplace)
{{ {{
...@@ -2548,17 +2550,7 @@ class AdvancedIncSubtensor1(COp): ...@@ -2548,17 +2550,7 @@ class AdvancedIncSubtensor1(COp):
{fail}; {fail};
}} }}
Py_XDECREF(rval); Py_XDECREF(rval);
""".format( """
**dict(
x=x,
y=y,
idx=idx,
out=out,
copy_of_x=copy_of_x,
params=sub["params"],
fail=sub["fail"],
)
)
def c_code_cache_version(self): def c_code_cache_version(self):
return (8,) return (8,)
......
...@@ -475,25 +475,28 @@ class TensorType(CType[np.ndarray], HasDataType, HasShape): ...@@ -475,25 +475,28 @@ class TensorType(CType[np.ndarray], HasDataType, HasShape):
def c_declare(self, name, sub, check_input=True): def c_declare(self, name, sub, check_input=True):
if check_input: if check_input:
check = """ dtype = self.dtype_specs()[1]
check = f"""
typedef {dtype} dtype_{name}; typedef {dtype} dtype_{name};
""".format(**dict(sub, name=name, dtype=self.dtype_specs()[1])) """
else: else:
check = "" check = ""
declaration = """ declaration = f"""
PyArrayObject* {name}; PyArrayObject* {name};
""".format(**dict(sub, name=name, dtype=self.dtype_specs()[1])) """
return declaration + check return declaration + check
def c_init(self, name, sub): def c_init(self, name, sub):
return """ return f"""
{name} = NULL; {name} = NULL;
""".format(**dict(sub, name=name, type_num=self.dtype_specs()[2])) """
def c_extract(self, name, sub, check_input=True, **kwargs): def c_extract(self, name, sub, check_input=True, **kwargs):
if check_input: if check_input:
check = """ fail = sub["fail"]
type_num = self.dtype_specs()[2]
check = f"""
{name} = NULL; {name} = NULL;
if (py_{name} == Py_None) {{ if (py_{name} == Py_None) {{
// We can either fail here or set {name} to NULL and rely on Ops // We can either fail here or set {name} to NULL and rely on Ops
...@@ -541,28 +544,27 @@ class TensorType(CType[np.ndarray], HasDataType, HasShape): ...@@ -541,28 +544,27 @@ class TensorType(CType[np.ndarray], HasDataType, HasShape):
{type_num}, PyArray_TYPE((PyArrayObject*) py_{name})); {type_num}, PyArray_TYPE((PyArrayObject*) py_{name}));
{fail} {fail}
}} }}
""".format(**dict(sub, name=name, type_num=self.dtype_specs()[2])) """
else: else:
check = "" check = ""
return ( return (
check check
+ """ + f"""
{name} = (PyArrayObject*)(py_{name}); {name} = (PyArrayObject*)(py_{name});
Py_XINCREF({name}); Py_XINCREF({name});
""".format(**dict(sub, name=name, type_num=self.dtype_specs()[2])) """
) )
def c_cleanup(self, name, sub): def c_cleanup(self, name, sub):
return """ return f"""
if ({name}) {{ if ({name}) {{
Py_XDECREF({name}); Py_XDECREF({name});
}} }}
""".format(**locals()) """
def c_sync(self, name, sub): def c_sync(self, name, sub):
fail = sub["fail"] fail = sub["fail"]
type_num = self.dtype_specs()[2] return f"""
return """
{{Py_XDECREF(py_{name});}} {{Py_XDECREF(py_{name});}}
if (!{name}) {{ if (!{name}) {{
Py_INCREF(Py_None); Py_INCREF(Py_None);
...@@ -597,7 +599,7 @@ class TensorType(CType[np.ndarray], HasDataType, HasShape): ...@@ -597,7 +599,7 @@ class TensorType(CType[np.ndarray], HasDataType, HasShape):
); );
{fail} {fail}
}} }}
""".format(**locals()) """
def c_headers(self, **kwargs): def c_headers(self, **kwargs):
return ps.get_scalar_type(self.dtype).c_headers(**kwargs) return ps.get_scalar_type(self.dtype).c_headers(**kwargs)
......
...@@ -102,13 +102,13 @@ class GetItem(COp): ...@@ -102,13 +102,13 @@ class GetItem(COp):
x_name, index = inp[0], inp[1] x_name, index = inp[0], inp[1]
output_name = out[0] output_name = out[0]
fail = sub["fail"] fail = sub["fail"]
return """ return f"""
{output_name} = (typeof {output_name}) PyList_GetItem( (PyObject*) {x_name}, *((npy_int64 *) PyArray_DATA({index}))); {output_name} = (typeof {output_name}) PyList_GetItem( (PyObject*) {x_name}, *((npy_int64 *) PyArray_DATA({index})));
if({output_name} == NULL){{ if({output_name} == NULL){{
{fail} {fail}
}} }}
Py_INCREF({output_name}); Py_INCREF({output_name});
""".format(**locals()) """
def c_code_cache_version(self): def c_code_cache_version(self):
return (1,) return (1,)
...@@ -169,16 +169,16 @@ class Append(COp): ...@@ -169,16 +169,16 @@ class Append(COp):
output_name = out[0] output_name = out[0]
fail = sub["fail"] fail = sub["fail"]
if not self.inplace: if not self.inplace:
init = """ init = f"""
{output_name} = (PyListObject*) PyList_GetSlice((PyObject*) {x_name}, 0, PyList_GET_SIZE((PyObject*) {x_name})) ; {output_name} = (PyListObject*) PyList_GetSlice((PyObject*) {x_name}, 0, PyList_GET_SIZE((PyObject*) {x_name})) ;
""".format(**locals()) """
else: else:
init = f""" init = f"""
{output_name} = {x_name}; {output_name} = {x_name};
""" """
return ( return (
init init
+ """ + f"""
if({output_name}==NULL){{ if({output_name}==NULL){{
{fail} {fail}
}}; }};
...@@ -186,7 +186,7 @@ class Append(COp): ...@@ -186,7 +186,7 @@ class Append(COp):
{fail} {fail}
}}; }};
Py_INCREF({output_name}); Py_INCREF({output_name});
""".format(**locals()) """
) )
def c_code_cache_version(self): def c_code_cache_version(self):
...@@ -249,16 +249,16 @@ class Extend(COp): ...@@ -249,16 +249,16 @@ class Extend(COp):
output_name = out[0] output_name = out[0]
fail = sub["fail"] fail = sub["fail"]
if not self.inplace: if not self.inplace:
init = """ init = f"""
{output_name} = (PyListObject*) PyList_GetSlice((PyObject*) {x_name}, 0, PyList_GET_SIZE((PyObject*) {x_name})) ; {output_name} = (PyListObject*) PyList_GetSlice((PyObject*) {x_name}, 0, PyList_GET_SIZE((PyObject*) {x_name})) ;
""".format(**locals()) """
else: else:
init = f""" init = f"""
{output_name} = {x_name}; {output_name} = {x_name};
""" """
return ( return (
init init
+ """ + f"""
int i =0; int i =0;
int length = PyList_GET_SIZE((PyObject*) {toAppend}); int length = PyList_GET_SIZE((PyObject*) {toAppend});
if({output_name}==NULL){{ if({output_name}==NULL){{
...@@ -270,7 +270,7 @@ class Extend(COp): ...@@ -270,7 +270,7 @@ class Extend(COp):
}}; }};
}} }}
Py_INCREF({output_name}); Py_INCREF({output_name});
""".format(**locals()) """
) )
def c_code_cache_version_(self): def c_code_cache_version_(self):
...@@ -337,16 +337,16 @@ class Insert(COp): ...@@ -337,16 +337,16 @@ class Insert(COp):
output_name = out[0] output_name = out[0]
fail = sub["fail"] fail = sub["fail"]
if not self.inplace: if not self.inplace:
init = """ init = f"""
{output_name} = (PyListObject*) PyList_GetSlice((PyObject*) {x_name}, 0, PyList_GET_SIZE((PyObject*) {x_name})) ; {output_name} = (PyListObject*) PyList_GetSlice((PyObject*) {x_name}, 0, PyList_GET_SIZE((PyObject*) {x_name})) ;
""".format(**locals()) """
else: else:
init = f""" init = f"""
{output_name} = {x_name}; {output_name} = {x_name};
""" """
return ( return (
init init
+ """ + f"""
if({output_name}==NULL){{ if({output_name}==NULL){{
{fail} {fail}
}}; }};
...@@ -354,7 +354,7 @@ class Insert(COp): ...@@ -354,7 +354,7 @@ class Insert(COp):
{fail} {fail}
}}; }};
Py_INCREF({output_name}); Py_INCREF({output_name});
""".format(**locals()) """
) )
def c_code_cache_version(self): def c_code_cache_version(self):
...@@ -465,16 +465,16 @@ class Reverse(COp): ...@@ -465,16 +465,16 @@ class Reverse(COp):
output_name = out[0] output_name = out[0]
fail = sub["fail"] fail = sub["fail"]
if not self.inplace: if not self.inplace:
init = """ init = f"""
{output_name} = (PyListObject*) PyList_GetSlice((PyObject*) {x_name}, 0, PyList_GET_SIZE((PyObject*) {x_name})) ; {output_name} = (PyListObject*) PyList_GetSlice((PyObject*) {x_name}, 0, PyList_GET_SIZE((PyObject*) {x_name})) ;
""".format(**locals()) """
else: else:
init = f""" init = f"""
{output_name} = {x_name}; {output_name} = {x_name};
""" """
return ( return (
init init
+ """ + f"""
if({output_name}==NULL){{ if({output_name}==NULL){{
{fail} {fail}
}}; }};
...@@ -482,7 +482,7 @@ class Reverse(COp): ...@@ -482,7 +482,7 @@ class Reverse(COp):
{fail} {fail}
}}; }};
Py_INCREF({output_name}); Py_INCREF({output_name});
""".format(**locals()) """
) )
def c_code_cache_version(self): def c_code_cache_version(self):
...@@ -595,13 +595,12 @@ class Length(COp): ...@@ -595,13 +595,12 @@ class Length(COp):
def c_code(self, node, name, inp, out, sub): def c_code(self, node, name, inp, out, sub):
x_name = inp[0] x_name = inp[0]
output_name = out[0] output_name = out[0]
fail = sub["fail"] return f"""
return """
if(!{output_name}) if(!{output_name})
{output_name}=(PyArrayObject*)PyArray_EMPTY(0, NULL, NPY_INT64, 0); {output_name}=(PyArrayObject*)PyArray_EMPTY(0, NULL, NPY_INT64, 0);
((npy_int64*)PyArray_DATA({output_name}))[0]=PyList_Size((PyObject*){x_name}); ((npy_int64*)PyArray_DATA({output_name}))[0]=PyList_Size((PyObject*){x_name});
Py_INCREF({output_name}); Py_INCREF({output_name});
""".format(**locals()) """
def c_code_cache_version(self): def c_code_cache_version(self):
return (1,) return (1,)
......
...@@ -110,26 +110,27 @@ class TypedListType(CType): ...@@ -110,26 +110,27 @@ class TypedListType(CType):
def c_extract(self, name, sub, check_input=True, **kwargs): def c_extract(self, name, sub, check_input=True, **kwargs):
if check_input: if check_input:
pre = """ fail = sub["fail"]
pre = f"""
if (!PyList_Check(py_{name})) {{ if (!PyList_Check(py_{name})) {{
PyErr_SetString(PyExc_TypeError, "expected a list"); PyErr_SetString(PyExc_TypeError, "expected a list");
{fail} {fail}
}}""".format(**dict(name=name, fail=sub["fail"])) }}"""
else: else:
pre = "" pre = ""
return ( return (
pre pre
+ """ + f"""
{name} = (PyListObject*) (py_{name}); {name} = (PyListObject*) (py_{name});
""".format(**dict(name=name, fail=sub["fail"])) """
) )
def c_sync(self, name, sub): def c_sync(self, name, sub):
return """ return f"""
Py_XDECREF(py_{name}); Py_XDECREF(py_{name});
py_{name} = (PyObject*)({name}); py_{name} = (PyObject*)({name});
Py_INCREF(py_{name}); Py_INCREF(py_{name});
""".format(**dict(name=name)) """
def c_cleanup(self, name, sub): def c_cleanup(self, name, sub):
return "" return ""
......
...@@ -64,7 +64,8 @@ class BROKEN_ON_PURPOSE_Add(COp): ...@@ -64,7 +64,8 @@ class BROKEN_ON_PURPOSE_Add(COp):
def c_code(self, node, name, inp, out, sub): def c_code(self, node, name, inp, out, sub):
a, b = inp a, b = inp
(z,) = out (z,) = out
return """ fail = sub["fail"]
return f"""
if (PyArray_NDIM({a}) != 1) {{PyErr_SetString(PyExc_NotImplementedError, "rank(a) != 1"); {fail};}} if (PyArray_NDIM({a}) != 1) {{PyErr_SetString(PyExc_NotImplementedError, "rank(a) != 1"); {fail};}}
if (PyArray_NDIM({b}) != 1) {{PyErr_SetString(PyExc_NotImplementedError, "rank(b) != 1"); {fail};}} if (PyArray_NDIM({b}) != 1) {{PyErr_SetString(PyExc_NotImplementedError, "rank(b) != 1"); {fail};}}
...@@ -96,7 +97,7 @@ class BROKEN_ON_PURPOSE_Add(COp): ...@@ -96,7 +97,7 @@ class BROKEN_ON_PURPOSE_Add(COp):
+ ((double*)PyArray_GETPTR1({b}, m))[0] ; + ((double*)PyArray_GETPTR1({b}, m))[0] ;
}} }}
}} }}
""".format(**dict(locals(), **sub)) """
# inconsistent is a invalid op, whose perform and c_code do not match # inconsistent is a invalid op, whose perform and c_code do not match
...@@ -632,7 +633,8 @@ class BrokenCImplementationAdd(COp): ...@@ -632,7 +633,8 @@ class BrokenCImplementationAdd(COp):
a, b = inp a, b = inp
(z,) = out (z,) = out
debug = 0 debug = 0
return """ fail = sub["fail"]
return f"""
//printf("executing c_code\\n"); //printf("executing c_code\\n");
if (PyArray_NDIM({a}) != 2) {{PyErr_SetString(PyExc_NotImplementedError, "rank(a) != 2"); {fail};}} if (PyArray_NDIM({a}) != 2) {{PyErr_SetString(PyExc_NotImplementedError, "rank(a) != 2"); {fail};}}
if (PyArray_NDIM({b}) != 2) {{PyErr_SetString(PyExc_NotImplementedError, "rank(b) != 2"); {fail};}} if (PyArray_NDIM({b}) != 2) {{PyErr_SetString(PyExc_NotImplementedError, "rank(b) != 2"); {fail};}}
...@@ -690,7 +692,7 @@ class BrokenCImplementationAdd(COp): ...@@ -690,7 +692,7 @@ class BrokenCImplementationAdd(COp):
}} }}
}} }}
}} }}
""".format(**dict(locals(), **sub)) """
class VecAsRowAndCol(Op): class VecAsRowAndCol(Op):
......
...@@ -39,7 +39,8 @@ class TDouble(CType): ...@@ -39,7 +39,8 @@ class TDouble(CType):
return str(data) return str(data)
def c_extract(self, name, sub, check_input=True, **kwargs): def c_extract(self, name, sub, check_input=True, **kwargs):
return """ fail = sub["fail"]
return f"""
if (!PyFloat_Check(py_{name})) {{ if (!PyFloat_Check(py_{name})) {{
PyErr_SetString(PyExc_TypeError, "not a double!"); PyErr_SetString(PyExc_TypeError, "not a double!");
{fail} {fail}
...@@ -47,7 +48,7 @@ class TDouble(CType): ...@@ -47,7 +48,7 @@ class TDouble(CType):
{name} = PyFloat_AsDouble(py_{name}); {name} = PyFloat_AsDouble(py_{name});
{name}_bad_thing = NULL; {name}_bad_thing = NULL;
//printf("Extracting {name}\\n"); //printf("Extracting {name}\\n");
""".format(**dict(locals(), **sub)) """
def c_sync(self, name, sub): def c_sync(self, name, sub):
return f""" return f"""
......
...@@ -189,12 +189,11 @@ def mock_system(request): ...@@ -189,12 +189,11 @@ def mock_system(request):
@pytest.fixture() @pytest.fixture()
def cxx_search_dirs(blas_libs, mock_system): def cxx_search_dirs(blas_libs, mock_system):
libext = {"Linux": "so", "Windows": "dll", "Darwin": "dylib"} libext = {"Linux": "so", "Windows": "dll", "Darwin": "dylib"}
libtemplate = f"{{lib}}.{libext[mock_system]}"
libraries = [] libraries = []
with tempfile.TemporaryDirectory() as d: with tempfile.TemporaryDirectory() as d:
flags = None flags = None
for lib in blas_libs: for lib in blas_libs:
lib_path = Path(d) / libtemplate.format(lib=lib) lib_path = Path(d) / f"{lib}.{libext[mock_system]}"
lib_path.write_bytes(b"1") lib_path.write_bytes(b"1")
libraries.append(lib_path) libraries.append(lib_path)
if flags is None: if flags is None:
...@@ -262,14 +261,13 @@ def test_default_blas_ldflags_no_cxx(): ...@@ -262,14 +261,13 @@ def test_default_blas_ldflags_no_cxx():
@pytest.fixture() @pytest.fixture()
def windows_conda_libs(blas_libs): def windows_conda_libs(blas_libs):
libtemplate = "{lib}.dll"
libraries = [] libraries = []
with tempfile.TemporaryDirectory() as d: with tempfile.TemporaryDirectory() as d:
subdir = Path(d) / "Library" / "bin" subdir = Path(d) / "Library" / "bin"
subdir.mkdir(exist_ok=True, parents=True) subdir.mkdir(exist_ok=True, parents=True)
flags = f'-L"{subdir}"' flags = f'-L"{subdir}"'
for lib in blas_libs: for lib in blas_libs:
lib_path = subdir / libtemplate.format(lib=lib) lib_path = subdir / f"{lib}.dll"
lib_path.write_bytes(b"1") lib_path.write_bytes(b"1")
libraries.append(lib_path) libraries.append(lib_path)
flags += f" -l{lib}" flags += f" -l{lib}"
......
...@@ -79,10 +79,10 @@ class StructOp(COp): ...@@ -79,10 +79,10 @@ class StructOp(COp):
return f"counter{name} = 0;" return f"counter{name} = 0;"
def c_code(self, node, name, input_names, outputs_names, sub): def c_code(self, node, name, input_names, outputs_names, sub):
return """ return f"""
{out} = counter{name}; {outputs_names[0]} = counter{name};
counter{name}++; counter{name}++;
""".format(**dict(out=outputs_names[0], name=name)) """
def c_code_cache_version(self): def c_code_cache_version(self):
return (1,) return (1,)
......
...@@ -73,30 +73,24 @@ class QuadraticOpFunc(COp): ...@@ -73,30 +73,24 @@ class QuadraticOpFunc(COp):
""" """
def c_code(self, node, name, inputs, outputs, sub): def c_code(self, node, name, inputs, outputs, sub):
return """ X = inputs[0]
{float_type} a = ({float_type}) (*(npy_float64*) PyArray_GETPTR1({coeff}->a, 0)); // 0-D TensorType. Y = outputs[0]
{float_type} b = {coeff}->b; // ScalarType. float_type = node.inputs[0].type.c_element_type()
{float_type} c = ({float_type}) PyFloat_AsDouble({coeff}->c); // Generic. return f"""
{float_type} a = ({float_type}) (*(npy_float64*) PyArray_GETPTR1({sub['params']}->a, 0)); // 0-D TensorType.
{float_type} b = {sub['params']}->b; // ScalarType.
{float_type} c = ({float_type}) PyFloat_AsDouble({sub['params']}->c); // Generic.
Py_XDECREF({Y}); Py_XDECREF({Y});
{Y} = (PyArrayObject*)PyArray_EMPTY(PyArray_NDIM({X}), PyArray_DIMS({X}), PyArray_TYPE({X}), PyArray_IS_F_CONTIGUOUS({X})); {Y} = (PyArrayObject*)PyArray_EMPTY(PyArray_NDIM({X}), PyArray_DIMS({X}), PyArray_TYPE({X}), PyArray_IS_F_CONTIGUOUS({X}));
if (PyArray_CopyInto({Y}, {X}) != 0) {{ if (PyArray_CopyInto({Y}, {X}) != 0) {{
PyErr_SetString(PyExc_RuntimeError, "Unable to copy input into output."); PyErr_SetString(PyExc_RuntimeError, "Unable to copy input into output.");
{fail} {sub['fail']}
}}; }};
if (quadratic_{name}({Y}, a, b, c) != 0) {{ if (quadratic_{name}({Y}, a, b, c) != 0) {{
PyErr_SetString(PyExc_RuntimeError, "Unable to compute quadratic function."); PyErr_SetString(PyExc_RuntimeError, "Unable to compute quadratic function.");
{fail} {sub['fail']}
}} }}
""".format( """
**dict(
name=name,
coeff=sub["params"],
fail=sub["fail"],
X=inputs[0],
Y=outputs[0],
float_type=node.inputs[0].type.c_element_type(),
)
)
# Same op as above, but implemented as a ExternalCOp (with C code in an # Same op as above, but implemented as a ExternalCOp (with C code in an
......
...@@ -25,11 +25,12 @@ Py_XDECREF((PyObject *)p); ...@@ -25,11 +25,12 @@ Py_XDECREF((PyObject *)p);
""" """
def c_code(self, node, name, inps, outs, sub): def c_code(self, node, name, inps, outs, sub):
return """ return f"""
Py_XDECREF({out}); Py_XDECREF({outs[0]});
{out} = (void *){inp}; {outs[0]} = (void *){inps[0]};
Py_INCREF({inp}); Py_INCREF({inps[0]});
""".format(**dict(out=outs[0], inp=inps[0])) """
# FIXME: should it not be outs[0]?
def c_code_cache_version(self): def c_code_cache_version(self):
return (0,) return (0,)
...@@ -52,11 +53,11 @@ Py_XDECREF((PyObject *)p); ...@@ -52,11 +53,11 @@ Py_XDECREF((PyObject *)p);
""" """
def c_code(self, node, name, inps, outs, sub): def c_code(self, node, name, inps, outs, sub):
return """ return f"""
Py_XDECREF({out}); Py_XDECREF({outs[0]});
{out} = (PyArrayObject *){inp}; {outs[0]} = (PyArrayObject *){inps[0]};
Py_INCREF({out}); Py_INCREF({outs[0]});
""".format(**dict(out=outs[0], inp=inps[0])) """
def c_code_cache_version(self): def c_code_cache_version(self):
return (0,) return (0,)
...@@ -136,7 +137,12 @@ class MyOpEnumList(COp): ...@@ -136,7 +137,12 @@ class MyOpEnumList(COp):
return (1,) return (1,)
def c_code(self, node, name, inputs, outputs, sub): def c_code(self, node, name, inputs, outputs, sub):
return """ op = sub["params"]
o = outputs[0]
a = inputs[0]
b = inputs[1]
fail = sub["fail"]
return f"""
switch({op}) {{ switch({op}) {{
case ADD: case ADD:
{o} = {a} + {b}; {o} = {a} + {b};
...@@ -154,15 +160,7 @@ class MyOpEnumList(COp): ...@@ -154,15 +160,7 @@ class MyOpEnumList(COp):
{{{fail}}} {{{fail}}}
break; break;
}} }}
""".format( """
**dict(
op=sub["params"],
o=outputs[0],
a=inputs[0],
b=inputs[1],
fail=sub["fail"],
)
)
class MyOpCEnumType(COp): class MyOpCEnumType(COp):
...@@ -201,15 +199,10 @@ class MyOpCEnumType(COp): ...@@ -201,15 +199,10 @@ class MyOpCEnumType(COp):
return (3,) return (3,)
def c_code(self, node, name, inputs, outputs, sub): def c_code(self, node, name, inputs, outputs, sub):
return """ # params in C code will already contains expected C constant value.
{o} = {val}; return f"""
""".format( {outputs[0]} = {sub['params']};
**dict( """
o=outputs[0],
# params in C code will already contains expected C constant value.
val=sub["params"],
)
)
class TestEnumTypes: class TestEnumTypes:
......
...@@ -323,7 +323,9 @@ class BaseCorr3dMM(OpenMPOp, _NoPythonOp): ...@@ -323,7 +323,9 @@ class BaseCorr3dMM(OpenMPOp, _NoPythonOp):
) )
depth = "-1" depth = "-1"
return """ fail = sub["fail"]
params = sub["params"]
return f"""
// Mandatory args // Mandatory args
int direction = {params}->direction; // forward, bprop weights, bprop inputs int direction = {params}->direction; // forward, bprop weights, bprop inputs
...@@ -568,18 +570,7 @@ class BaseCorr3dMM(OpenMPOp, _NoPythonOp): ...@@ -568,18 +570,7 @@ class BaseCorr3dMM(OpenMPOp, _NoPythonOp):
}} }}
assert (out2 == *out); assert (out2 == *out);
""".format( """
**dict(
bottom=bottom,
weights=weights,
top=top,
height=height,
width=width,
depth=depth,
fail=sub["fail"],
params=sub["params"],
)
)
class Corr3dMM(BaseCorr3dMM): class Corr3dMM(BaseCorr3dMM):
......
...@@ -322,7 +322,9 @@ class BaseCorrMM(OpenMPOp, _NoPythonOp): ...@@ -322,7 +322,9 @@ class BaseCorrMM(OpenMPOp, _NoPythonOp):
) )
width = "-1" width = "-1"
return """ fail = sub["fail"]
params = sub["params"]
return f"""
// Mandatory args // Mandatory args
int direction = {params}->direction; // forward, bprop weights, bprop inputs int direction = {params}->direction; // forward, bprop weights, bprop inputs
...@@ -614,17 +616,7 @@ class BaseCorrMM(OpenMPOp, _NoPythonOp): ...@@ -614,17 +616,7 @@ class BaseCorrMM(OpenMPOp, _NoPythonOp):
}} }}
assert (out2 == *out); assert (out2 == *out);
""".format( """
**dict(
bottom=bottom,
weights=weights,
top=top,
height=height,
width=width,
fail=sub["fail"],
params=sub["params"],
)
)
class CorrMM(BaseCorrMM): class CorrMM(BaseCorrMM):
......
...@@ -1391,9 +1391,9 @@ class TimesN(ps.basic.UnaryScalarOp): ...@@ -1391,9 +1391,9 @@ class TimesN(ps.basic.UnaryScalarOp):
def c_support_code_apply(self, node, nodename): def c_support_code_apply(self, node, nodename):
n = str(self.n) n = str(self.n)
return """ return f"""
float {nodename}_timesn(float x) {{ return x * {n}; }} float {nodename}_timesn(float x) {{ return x * {n}; }}
""".format(**locals()) """
def c_code(self, node, name, inputs, outputs, sub): def c_code(self, node, name, inputs, outputs, sub):
(x,) = inputs (x,) = inputs
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论