提交 06c5acdf authored 作者: Virgile Andreani's avatar Virgile Andreani 提交者: Ricardo Vieira

Fix UP031 with `ruff --unsafe-fixes`

上级 cada5ad2
...@@ -130,7 +130,7 @@ disable = ["C0330", "C0326"] ...@@ -130,7 +130,7 @@ disable = ["C0330", "C0326"]
[tool.ruff] [tool.ruff]
select = ["C", "E", "F", "I", "UP", "W"] select = ["C", "E", "F", "I", "UP", "W"]
ignore = ["C408", "C901", "E501", "E741", "UP031"] ignore = ["C408", "C901", "E501", "E741"]
exclude = ["doc/", "pytensor/_version.py", "bin/pytensor_cache.py"] exclude = ["doc/", "pytensor/_version.py", "bin/pytensor_cache.py"]
......
...@@ -465,7 +465,7 @@ class OpFromGraph(Op, HasInnerGraph): ...@@ -465,7 +465,7 @@ class OpFromGraph(Op, HasInnerGraph):
def __str__(self): def __str__(self):
name = self.__class__.__name__ if self.name is None else self.name name = self.__class__.__name__ if self.name is None else self.name
is_inline = self.is_inline is_inline = self.is_inline
return "%(name)s{inline=%(is_inline)s}" % locals() return "{name}{{inline={is_inline}}}".format(**locals())
@config.change_flags(compute_test_value="off") @config.change_flags(compute_test_value="off")
def _recompute_lop_op(self): def _recompute_lop_op(self):
......
...@@ -250,33 +250,30 @@ def struct_gen(args, struct_builders, blocks, sub): ...@@ -250,33 +250,30 @@ def struct_gen(args, struct_builders, blocks, sub):
# that holds the type, the value and the traceback. After storing # that holds the type, the value and the traceback. After storing
# the error, we return the failure code so we know which code # the error, we return the failure code so we know which code
# block failed. # block failed.
do_return = ( do_return = """
""" if ({failure_var}) {{
if (%(failure_var)s) {
// When there is a failure, this code puts the exception // When there is a failure, this code puts the exception
// in __ERROR. // in __ERROR.
PyObject* err_type = NULL; PyObject* err_type = NULL;
PyObject* err_msg = NULL; PyObject* err_msg = NULL;
PyObject* err_traceback = NULL; PyObject* err_traceback = NULL;
PyErr_Fetch(&err_type, &err_msg, &err_traceback); PyErr_Fetch(&err_type, &err_msg, &err_traceback);
if (!err_type) {err_type = Py_None;Py_INCREF(Py_None);} if (!err_type) {{err_type = Py_None;Py_INCREF(Py_None);}}
if (!err_msg) {err_msg = Py_None; Py_INCREF(Py_None);} if (!err_msg) {{err_msg = Py_None; Py_INCREF(Py_None);}}
if (!err_traceback) {err_traceback = Py_None; Py_INCREF(Py_None);} if (!err_traceback) {{err_traceback = Py_None; Py_INCREF(Py_None);}}
PyObject* old_err_type = PyList_GET_ITEM(__ERROR, 0); PyObject* old_err_type = PyList_GET_ITEM(__ERROR, 0);
PyObject* old_err_msg = PyList_GET_ITEM(__ERROR, 1); PyObject* old_err_msg = PyList_GET_ITEM(__ERROR, 1);
PyObject* old_err_traceback = PyList_GET_ITEM(__ERROR, 2); PyObject* old_err_traceback = PyList_GET_ITEM(__ERROR, 2);
PyList_SET_ITEM(__ERROR, 0, err_type); PyList_SET_ITEM(__ERROR, 0, err_type);
PyList_SET_ITEM(__ERROR, 1, err_msg); PyList_SET_ITEM(__ERROR, 1, err_msg);
PyList_SET_ITEM(__ERROR, 2, err_traceback); PyList_SET_ITEM(__ERROR, 2, err_traceback);
{Py_XDECREF(old_err_type);} {{Py_XDECREF(old_err_type);}}
{Py_XDECREF(old_err_msg);} {{Py_XDECREF(old_err_msg);}}
{Py_XDECREF(old_err_traceback);} {{Py_XDECREF(old_err_traceback);}}
} }}
// The failure code is returned to index what code block failed. // The failure code is returned to index what code block failed.
return %(failure_var)s; return {failure_var};
""" """.format(**sub)
% sub
)
sub = dict(sub) sub = dict(sub)
sub.update(locals()) sub.update(locals())
...@@ -284,16 +281,15 @@ def struct_gen(args, struct_builders, blocks, sub): ...@@ -284,16 +281,15 @@ def struct_gen(args, struct_builders, blocks, sub):
# TODO: add some error checking to make sure storage_<x> are # TODO: add some error checking to make sure storage_<x> are
# 1-element lists and __ERROR is a 3-elements list. # 1-element lists and __ERROR is a 3-elements list.
struct_code = ( struct_code = """
""" namespace {{
namespace { struct {name} {{
struct %(name)s {
PyObject* __ERROR; PyObject* __ERROR;
%(storage_decl)s {storage_decl}
%(struct_decl)s {struct_decl}
%(name)s() { {name}() {{
// This is only somewhat safe because we: // This is only somewhat safe because we:
// 1) Are not a virtual class // 1) Are not a virtual class
// 2) Do not use any virtual classes in the members // 2) Do not use any virtual classes in the members
...@@ -306,32 +302,30 @@ def struct_gen(args, struct_builders, blocks, sub): ...@@ -306,32 +302,30 @@ def struct_gen(args, struct_builders, blocks, sub):
#ifndef PYTENSOR_DONT_MEMSET_STRUCT #ifndef PYTENSOR_DONT_MEMSET_STRUCT
memset(this, 0, sizeof(*this)); memset(this, 0, sizeof(*this));
#endif #endif
} }}
~%(name)s(void) { ~{name}(void) {{
cleanup(); cleanup();
} }}
int init(PyObject* __ERROR, %(args_decl)s) { int init(PyObject* __ERROR, {args_decl}) {{
%(storage_incref)s {storage_incref}
%(storage_set)s {storage_set}
%(struct_init_head)s {struct_init_head}
this->__ERROR = __ERROR; this->__ERROR = __ERROR;
return 0; return 0;
} }}
void cleanup(void) { void cleanup(void) {{
%(struct_cleanup)s {struct_cleanup}
%(storage_decref)s {storage_decref}
} }}
int run(void) { int run(void) {{
int %(failure_var)s = 0; int {failure_var} = 0;
%(behavior)s {behavior}
%(do_return)s {do_return}
} }}
}; }};
} }}
""" """.format(**sub)
% sub
)
return struct_code return struct_code
...@@ -380,9 +374,9 @@ def get_c_init(fgraph, r, name, sub): ...@@ -380,9 +374,9 @@ def get_c_init(fgraph, r, name, sub):
pre = ( pre = (
"" ""
""" """
py_%(name)s = Py_None; py_{name} = Py_None;
{Py_XINCREF(py_%(name)s);} {{Py_XINCREF(py_{name});}}
""" % locals() """.format(**locals())
) )
return pre + r.type.c_init(name, sub) return pre + r.type.c_init(name, sub)
...@@ -418,9 +412,9 @@ def get_c_extract(fgraph, r, name, sub): ...@@ -418,9 +412,9 @@ def get_c_extract(fgraph, r, name, sub):
c_extract = r.type.c_extract(name, sub, False) c_extract = r.type.c_extract(name, sub, False)
pre = """ pre = """
py_%(name)s = PyList_GET_ITEM(storage_%(name)s, 0); py_{name} = PyList_GET_ITEM(storage_{name}, 0);
{Py_XINCREF(py_%(name)s);} {{Py_XINCREF(py_{name});}}
""" % locals() """.format(**locals())
return pre + c_extract return pre + c_extract
...@@ -447,9 +441,9 @@ def get_c_extract_out(fgraph, r, name, sub): ...@@ -447,9 +441,9 @@ def get_c_extract_out(fgraph, r, name, sub):
c_extract = r.type.c_extract_out(name, sub, check_input, check_broadcast=False) c_extract = r.type.c_extract_out(name, sub, check_input, check_broadcast=False)
pre = """ pre = """
py_%(name)s = PyList_GET_ITEM(storage_%(name)s, 0); py_{name} = PyList_GET_ITEM(storage_{name}, 0);
{Py_XINCREF(py_%(name)s);} {{Py_XINCREF(py_{name});}}
""" % locals() """.format(**locals())
return pre + c_extract return pre + c_extract
...@@ -459,8 +453,8 @@ def get_c_cleanup(fgraph, r, name, sub): ...@@ -459,8 +453,8 @@ def get_c_cleanup(fgraph, r, name, sub):
""" """
post = """ post = """
{Py_XDECREF(py_%(name)s);} {{Py_XDECREF(py_{name});}}
""" % locals() """.format(**locals())
return r.type.c_cleanup(name, sub) + post return r.type.c_cleanup(name, sub) + post
...@@ -470,14 +464,14 @@ def get_c_sync(fgraph, r, name, sub): ...@@ -470,14 +464,14 @@ def get_c_sync(fgraph, r, name, sub):
""" """
return """ return """
if (!%(failure_var)s) { if (!{failure_var}) {{
%(sync)s {sync}
PyObject* old = PyList_GET_ITEM(storage_%(name)s, 0); PyObject* old = PyList_GET_ITEM(storage_{name}, 0);
{Py_XINCREF(py_%(name)s);} {{Py_XINCREF(py_{name});}}
PyList_SET_ITEM(storage_%(name)s, 0, py_%(name)s); PyList_SET_ITEM(storage_{name}, 0, py_{name});
{Py_XDECREF(old);} {{Py_XDECREF(old);}}
} }}
""" % dict(sync=r.type.c_sync(name, sub), name=name, **sub) """.format(**dict(sync=r.type.c_sync(name, sub), name=name, **sub))
def apply_policy(fgraph, policy, r, name, sub): def apply_policy(fgraph, policy, r, name, sub):
......
...@@ -1950,14 +1950,13 @@ class Compiler: ...@@ -1950,14 +1950,13 @@ class Compiler:
code = ( code = (
""" """
%(preamble)s {preamble}
int main(int argc, char** argv) int main(int argc, char** argv)
{ {{
%(body)s {body}
return 0; return 0;
} }}
""" """.format(**locals())
% locals()
).encode() ).encode()
return cls._try_compile_tmp( return cls._try_compile_tmp(
code, code,
......
...@@ -558,19 +558,21 @@ class CLinkerType(CLinkerObject): ...@@ -558,19 +558,21 @@ class CLinkerType(CLinkerObject):
""" """
return """ return """
if (py_%(name)s == Py_None) if (py_{name} == Py_None)
{ {{
%(c_init_code)s {c_init_code}
} }}
else else
{ {{
%(c_extract_code)s {c_extract_code}
} }}
""" % dict( """.format(
**dict(
name=name, name=name,
c_init_code=self.c_init(name, sub), c_init_code=self.c_init(name, sub),
c_extract_code=self.c_extract(name, sub, check_input), c_extract_code=self.c_extract(name, sub, check_input),
) )
)
def c_cleanup(self, name: str, sub: dict[str, str]) -> str: def c_cleanup(self, name: str, sub: dict[str, str]) -> str:
"""Return C code to clean up after :meth:`CLinkerType.c_extract`. """Return C code to clean up after :meth:`CLinkerType.c_extract`.
......
...@@ -596,14 +596,15 @@ class ExternalCOp(COp): ...@@ -596,14 +596,15 @@ class ExternalCOp(COp):
# Generate the C code # Generate the C code
return """ return """
%(define_macros)s {define_macros}
{ {{
if (%(func_name)s(%(func_args)s%(params)s) != 0) { if ({func_name}({func_args}{params}) != 0) {{
%(fail)s {fail}
} }}
} }}
%(undef_macros)s {undef_macros}
""" % dict( """.format(
**dict(
func_name=self.func_name, func_name=self.func_name,
fail=sub["fail"], fail=sub["fail"],
params=params, params=params,
...@@ -611,6 +612,7 @@ class ExternalCOp(COp): ...@@ -611,6 +612,7 @@ class ExternalCOp(COp):
define_macros=define_macros, define_macros=define_macros,
undef_macros=undef_macros, undef_macros=undef_macros,
) )
)
else: else:
if "code" in self.code_sections: if "code" in self.code_sections:
op_code = self.code_sections["code"] op_code = self.code_sections["code"]
......
...@@ -359,8 +359,7 @@ class ParamsType(CType): ...@@ -359,8 +359,7 @@ class ParamsType(CType):
type_name = type_instance.__class__.__name__ type_name = type_instance.__class__.__name__
if not isinstance(type_instance, CType): if not isinstance(type_instance, CType):
raise TypeError( raise TypeError(
'ParamsType: attribute "%s" should inherit from PyTensor CType, got "%s".' f'ParamsType: attribute "{attribute_name}" should inherit from PyTensor CType, got "{type_name}".'
% (attribute_name, type_name)
) )
self.length = len(kwargs) self.length = len(kwargs)
...@@ -723,15 +722,11 @@ class ParamsType(CType): ...@@ -723,15 +722,11 @@ class ParamsType(CType):
c_cleanup_list.append(type_instance.c_cleanup(attribute_name, sub)) c_cleanup_list.append(type_instance.c_cleanup(attribute_name, sub))
c_extract_list.append( c_extract_list.append(
f"""
void extract_{attribute_name}(PyObject* py_{attribute_name}) {{
{type_instance.c_extract(attribute_name, sub)}
}}
""" """
void extract_%(attribute_name)s(PyObject* py_%(attribute_name)s) {
%(extract_code)s
}
"""
% {
"attribute_name": attribute_name,
"extract_code": type_instance.c_extract(attribute_name, sub),
}
) )
struct_declare = "\n".join(c_declare_list) struct_declare = "\n".join(c_declare_list)
...@@ -759,48 +754,49 @@ class ParamsType(CType): ...@@ -759,48 +754,49 @@ class ParamsType(CType):
) )
) )
final_struct_code = """ final_struct_code = """
/** ParamsType %(struct_name)s **/ /** ParamsType {struct_name} **/
#ifndef %(struct_name_defined)s #ifndef {struct_name_defined}
#define %(struct_name_defined)s #define {struct_name_defined}
struct %(struct_name)s { struct {struct_name} {{
/* Attributes, */ /* Attributes, */
int %(struct_name)s_error; int {struct_name}_error;
%(struct_declare)s {struct_declare}
/* Constructor. */ /* Constructor. */
%(struct_name)s() { {struct_name}() {{
%(struct_name)s_error = 0; {struct_name}_error = 0;
%(struct_init)s {struct_init}
} }}
/* Destructor. */ /* Destructor. */
~%(struct_name)s() { ~{struct_name}() {{
// cleanup() is defined below. // cleanup() is defined below.
cleanup(); cleanup();
} }}
/* Cleanup method. */ /* Cleanup method. */
void cleanup() { void cleanup() {{
%(struct_cleanup)s {struct_cleanup}
} }}
/* Extraction methods. */ /* Extraction methods. */
%(struct_extract)s {struct_extract}
/* Extract method. */ /* Extract method. */
%(struct_extract_method)s {struct_extract_method}
/* Other methods. */ /* Other methods. */
void setErrorOccurred() { void setErrorOccurred() {{
++%(struct_name)s_error; ++{struct_name}_error;
} }}
int errorOccurred() { int errorOccurred() {{
return %(struct_name)s_error; return {struct_name}_error;
} }}
}; }};
#endif #endif
/** End ParamsType %(struct_name)s **/ /** End ParamsType {struct_name} **/
""" % dict( """.format(
**dict(
struct_name_defined=struct_name_defined, struct_name_defined=struct_name_defined,
struct_name=struct_name, struct_name=struct_name,
struct_declare=struct_declare, struct_declare=struct_declare,
...@@ -809,6 +805,7 @@ class ParamsType(CType): ...@@ -809,6 +805,7 @@ class ParamsType(CType):
struct_extract=struct_extract, struct_extract=struct_extract,
struct_extract_method=struct_extract_method, struct_extract_method=struct_extract_method,
) )
)
return sorted(c_support_code_set) + [final_struct_code] return sorted(c_support_code_set) + [final_struct_code]
...@@ -822,8 +819,8 @@ class ParamsType(CType): ...@@ -822,8 +819,8 @@ class ParamsType(CType):
def c_declare(self, name, sub, check_input=True): def c_declare(self, name, sub, check_input=True):
return """ return """
%(struct_name)s* %(name)s; {struct_name}* {name};
""" % dict(struct_name=self.name, name=name) """.format(**dict(struct_name=self.name, name=name))
def c_init(self, name, sub): def c_init(self, name, sub):
# NB: It seems c_init() is not called for an op param. # NB: It seems c_init() is not called for an op param.
...@@ -841,36 +838,38 @@ class ParamsType(CType): ...@@ -841,36 +838,38 @@ class ParamsType(CType):
def c_extract(self, name, sub, check_input=True, **kwargs): def c_extract(self, name, sub, check_input=True, **kwargs):
return """ return """
/* Seems c_init() is not called for a op param. So I call `new` here. */ /* Seems c_init() is not called for a op param. So I call `new` here. */
%(name)s = new %(struct_name)s; {name} = new {struct_name};
{ // This need a separate namespace for Clinker {{ // This need a separate namespace for Clinker
const char* fields[] = {%(fields_list)s}; const char* fields[] = {{{fields_list}}};
if (py_%(name)s == Py_None) { if (py_{name} == Py_None) {{
PyErr_SetString(PyExc_ValueError, "ParamsType: expected an object, not None."); PyErr_SetString(PyExc_ValueError, "ParamsType: expected an object, not None.");
%(fail)s {fail}
} }}
for (int i = 0; i < %(length)s; ++i) { for (int i = 0; i < {length}; ++i) {{
PyObject* o = PyDict_GetItemString(py_%(name)s, fields[i]); PyObject* o = PyDict_GetItemString(py_{name}, fields[i]);
if (o == NULL) { if (o == NULL) {{
PyErr_Format(PyExc_TypeError, "ParamsType: missing expected attribute \\"%%s\\" in object.", fields[i]); PyErr_Format(PyExc_TypeError, "ParamsType: missing expected attribute \\"%s\\" in object.", fields[i]);
%(fail)s {fail}
} }}
%(name)s->extract(o, i); {name}->extract(o, i);
if (%(name)s->errorOccurred()) { if ({name}->errorOccurred()) {{
/* The extract code from attribute type should have already raised a Python exception, /* The extract code from attribute type should have already raised a Python exception,
* so we just print the attribute name in stderr. */ * so we just print the attribute name in stderr. */
fprintf(stderr, "\\nParamsType: error when extracting value for attribute \\"%%s\\".\\n", fields[i]); fprintf(stderr, "\\nParamsType: error when extracting value for attribute \\"%s\\".\\n", fields[i]);
%(fail)s {fail}
} }}
} }}
} }}
""" % dict( """.format(
**dict(
name=name, name=name,
struct_name=self.name, struct_name=self.name,
length=self.length, length=self.length,
fail=sub["fail"], fail=sub["fail"],
fields_list='"%s"' % '", "'.join(self.fields), fields_list='"%s"' % '", "'.join(self.fields),
) )
)
def c_sync(self, name, sub): def c_sync(self, name, sub):
# FIXME: Looks like we need to decrement a reference count our two. # FIXME: Looks like we need to decrement a reference count our two.
......
...@@ -99,11 +99,11 @@ class Generic(CType, Singleton): ...@@ -99,11 +99,11 @@ class Generic(CType, Singleton):
def c_sync(self, name, sub): def c_sync(self, name, sub):
return """ return """
assert(py_%(name)s->ob_refcnt > 1); assert(py_{name}->ob_refcnt > 1);
Py_DECREF(py_%(name)s); Py_DECREF(py_{name});
py_%(name)s = %(name)s ? %(name)s : Py_None; py_{name} = {name} ? {name} : Py_None;
Py_INCREF(py_%(name)s); Py_INCREF(py_{name});
""" % locals() """.format(**locals())
def c_code_cache_version(self): def c_code_cache_version(self):
return (1,) return (1,)
...@@ -191,17 +191,17 @@ class CDataType(CType[D]): ...@@ -191,17 +191,17 @@ class CDataType(CType[D]):
def c_declare(self, name, sub, check_input=True): def c_declare(self, name, sub, check_input=True):
return """ return """
%(ctype)s %(name)s; {ctype} {name};
""" % dict(ctype=self.ctype, name=name) """.format(**dict(ctype=self.ctype, name=name))
def c_init(self, name, sub): def c_init(self, name, sub):
return f"{name} = NULL;" return f"{name} = NULL;"
def c_extract(self, name, sub, check_input=True, **kwargs): def c_extract(self, name, sub, check_input=True, **kwargs):
return """ return """
%(name)s = (%(ctype)s)PyCapsule_GetPointer(py_%(name)s, NULL); {name} = ({ctype})PyCapsule_GetPointer(py_{name}, NULL);
if (%(name)s == NULL) %(fail)s if ({name} == NULL) {fail}
""" % dict(name=name, ctype=self.ctype, fail=sub["fail"]) """.format(**dict(name=name, ctype=self.ctype, fail=sub["fail"]))
def c_sync(self, name, sub): def c_sync(self, name, sub):
freefunc = self.freefunc freefunc = self.freefunc
...@@ -576,39 +576,39 @@ class EnumType(CType, dict): ...@@ -576,39 +576,39 @@ class EnumType(CType, dict):
""" """
return """ return """
#ifdef DEBUG #ifdef DEBUG
int pytensor_enum_to_string_%(cname)s(%(ctype)s in, char* out) { int pytensor_enum_to_string_{cname}({ctype} in, char* out) {{
int ret = 0; int ret = 0;
switch(in) { switch(in) {{
%(cases)s {cases}
default: default:
PyErr_SetString(PyExc_ValueError, "%(classname)s: unknown enum value."); PyErr_SetString(PyExc_ValueError, "{classname}: unknown enum value.");
ret = -1; ret = -1;
break; break;
} }}
return ret; return ret;
} }}
#endif #endif
""" % dict( """.format(
**dict(
cname=self.cname, cname=self.cname,
ctype=self.ctype, ctype=self.ctype,
classname=type(self).__name__, classname=type(self).__name__,
cases="".join( cases="".join(
""" """
case %(name)s: sprintf(out, "%(name)s"); break; case {name}: sprintf(out, "{name}"); break;
""" """.format(**dict(name=name))
% dict(name=name)
for name in self for name in self
), ),
) )
)
def c_support_code(self, **kwargs): def c_support_code(self, **kwargs):
return ( return (
self.pyint_compat_code self.pyint_compat_code
+ "".join( + "".join(
f"""
#define {k} {str(self[k])}
""" """
#define %s %s
"""
% (k, str(self[k]))
for k in sorted(self.keys()) for k in sorted(self.keys())
) )
+ self.c_to_string() + self.c_to_string()
...@@ -625,15 +625,15 @@ class EnumType(CType, dict): ...@@ -625,15 +625,15 @@ class EnumType(CType, dict):
def c_extract(self, name, sub, check_input=True, **kwargs): def c_extract(self, name, sub, check_input=True, **kwargs):
return """ return """
if (PyInt_Check(py_%(name)s)) { if (PyInt_Check(py_{name})) {{
%(name)s = (%(ctype)s)PyInt_AsLong(py_%(name)s); {name} = ({ctype})PyInt_AsLong(py_{name});
} else { }} else {{
%(name)s = (%(ctype)s)PyFloat_AsDouble(py_%(name)s); {name} = ({ctype})PyFloat_AsDouble(py_{name});
} }}
if (PyErr_Occurred()) { if (PyErr_Occurred()) {{
%(fail)s {fail}
} }}
""" % dict(ctype=self.ctype, name=name, fail=sub["fail"]) """.format(**dict(ctype=self.ctype, name=name, fail=sub["fail"]))
def c_code_cache_version(self): def c_code_cache_version(self):
return (2, self.ctype, self.cname, tuple(self.items())) return (2, self.ctype, self.cname, tuple(self.items()))
...@@ -754,14 +754,15 @@ class CEnumType(EnumList): ...@@ -754,14 +754,15 @@ class CEnumType(EnumList):
# swapped_dict's keys are integers. # swapped_dict's keys are integers.
return """ return """
switch(PyInt_AsLong(py_%(name)s)) { switch(PyInt_AsLong(py_{name})) {{
%(cases)s {cases}
default: default:
PyErr_SetString(PyExc_ValueError, "CEnumType: invalid value to map to C constants."); PyErr_SetString(PyExc_ValueError, "CEnumType: invalid value to map to C constants.");
{%(fail)s} {{{fail}}}
break; break;
} }}
""" % dict( """.format(
**dict(
name=name, name=name,
cases="".join( cases="".join(
""" """
...@@ -772,6 +773,7 @@ class CEnumType(EnumList): ...@@ -772,6 +773,7 @@ class CEnumType(EnumList):
), ),
fail=sub["fail"], fail=sub["fail"],
) )
)
def c_code_cache_version(self): def c_code_cache_version(self):
return (1, super().c_code_cache_version()) return (1, super().c_code_cache_version())
...@@ -141,7 +141,7 @@ def _check_scipy_linalg_matrix(a, func_name): ...@@ -141,7 +141,7 @@ def _check_scipy_linalg_matrix(a, func_name):
if isinstance(a, types.Optional): if isinstance(a, types.Optional):
a = a.type a = a.type
if not isinstance(a, types.Array): if not isinstance(a, types.Array):
msg = "%s.%s() only supported for array types" % interp msg = "{}.{}() only supported for array types".format(*interp)
raise numba.TypingError(msg, highlighting=False) raise numba.TypingError(msg, highlighting=False)
if a.ndim not in [1, 2]: if a.ndim not in [1, 2]:
msg = "%s.%s() only supported on 1d or 2d arrays, found %s." % ( msg = "%s.%s() only supported on 1d or 2d arrays, found %s." % (
...@@ -149,7 +149,7 @@ def _check_scipy_linalg_matrix(a, func_name): ...@@ -149,7 +149,7 @@ def _check_scipy_linalg_matrix(a, func_name):
) )
raise numba.TypingError(msg, highlighting=False) raise numba.TypingError(msg, highlighting=False)
if not isinstance(a.dtype, (types.Float, types.Complex)): if not isinstance(a.dtype, (types.Float, types.Complex)):
msg = "%s.%s() only supported on " "float and complex arrays." % interp msg = "{}.{}() only supported on " "float and complex arrays.".format(*interp)
raise numba.TypingError(msg, highlighting=False) raise numba.TypingError(msg, highlighting=False)
......
...@@ -646,8 +646,7 @@ def _debugprint( ...@@ -646,8 +646,7 @@ def _debugprint(
tot_time_percent = (tot_time_dict[node] / profile.fct_call_time) * 100 tot_time_percent = (tot_time_dict[node] / profile.fct_call_time) * 100
print( print(
"%s --> %8.2es %4.1f%% %8.2es %4.1f%%" "{} --> {:8.2e}s {:4.1f}% {:8.2e}s {:4.1f}%".format(
% (
var_output, var_output,
op_time, op_time,
op_time_percent, op_time_percent,
......
...@@ -466,40 +466,44 @@ class ScalarType(CType, HasDataType, HasShape): ...@@ -466,40 +466,44 @@ class ScalarType(CType, HasDataType, HasShape):
specs = self.dtype_specs() specs = self.dtype_specs()
if check_input: if check_input:
pre = """ pre = """
if (!PyObject_TypeCheck(py_%(name)s, &%(pyarr_type)s)) if (!PyObject_TypeCheck(py_{name}, &{pyarr_type}))
{ {{
PyErr_Format(PyExc_ValueError, PyErr_Format(PyExc_ValueError,
"Scalar check failed (%(dtype)s)"); "Scalar check failed ({dtype})");
%(fail)s {fail}
} }}
""" % dict( """.format(
sub, name=name, dtype=specs[1], pyarr_type="Py%sArrType_Type" % specs[2] **dict(
sub,
name=name,
dtype=specs[1],
pyarr_type="Py%sArrType_Type" % specs[2],
)
) )
else: else:
pre = "" pre = ""
return ( return (
pre pre
+ """ + """
PyArray_ScalarAsCtype(py_%(name)s, &%(name)s); PyArray_ScalarAsCtype(py_{name}, &{name});
""" """.format(**dict(sub, name=name))
% dict(sub, name=name)
) )
def c_sync(self, name, sub): def c_sync(self, name, sub):
specs = self.dtype_specs() specs = self.dtype_specs()
return """ return """
Py_XDECREF(py_%(name)s); Py_XDECREF(py_{name});
py_%(name)s = PyArrayScalar_New(%(cls)s); py_{name} = PyArrayScalar_New({cls});
if (!py_%(name)s) if (!py_{name})
{ {{
Py_XINCREF(Py_None); Py_XINCREF(Py_None);
py_%(name)s = Py_None; py_{name} = Py_None;
PyErr_Format(PyExc_MemoryError, PyErr_Format(PyExc_MemoryError,
"Instantiation of new Python scalar failed (%(dtype)s)"); "Instantiation of new Python scalar failed ({dtype})");
%(fail)s {fail}
} }}
PyArrayScalar_ASSIGN(py_%(name)s, %(cls)s, %(name)s); PyArrayScalar_ASSIGN(py_{name}, {cls}, {name});
""" % dict(sub, name=name, dtype=specs[1], cls=specs[2]) """.format(**dict(sub, name=name, dtype=specs[1], cls=specs[2]))
def c_cleanup(self, name, sub): def c_cleanup(self, name, sub):
return "" return ""
......
...@@ -620,8 +620,8 @@ class Chi2SF(BinaryScalarOp): ...@@ -620,8 +620,8 @@ class Chi2SF(BinaryScalarOp):
(z,) = out (z,) = out
if node.inputs[0].type in float_types: if node.inputs[0].type in float_types:
dtype = "npy_" + node.outputs[0].dtype dtype = "npy_" + node.outputs[0].dtype
return """%(z)s = return """{z} =
(%(dtype)s) 1 - GammaP(%(k)s/2., %(x)s/2.);""" % locals() ({dtype}) 1 - GammaP({k}/2., {x}/2.);""".format(**locals())
raise NotImplementedError("only floatingpoint is implemented") raise NotImplementedError("only floatingpoint is implemented")
def __eq__(self, other): def __eq__(self, other):
...@@ -666,8 +666,8 @@ class GammaInc(BinaryScalarOp): ...@@ -666,8 +666,8 @@ class GammaInc(BinaryScalarOp):
(z,) = out (z,) = out
if node.inputs[0].type in float_types: if node.inputs[0].type in float_types:
dtype = "npy_" + node.outputs[0].dtype dtype = "npy_" + node.outputs[0].dtype
return """%(z)s = return """{z} =
(%(dtype)s) GammaP(%(k)s, %(x)s);""" % locals() ({dtype}) GammaP({k}, {x});""".format(**locals())
raise NotImplementedError("only floatingpoint is implemented") raise NotImplementedError("only floatingpoint is implemented")
def __eq__(self, other): def __eq__(self, other):
...@@ -712,8 +712,8 @@ class GammaIncC(BinaryScalarOp): ...@@ -712,8 +712,8 @@ class GammaIncC(BinaryScalarOp):
(z,) = out (z,) = out
if node.inputs[0].type in float_types: if node.inputs[0].type in float_types:
dtype = "npy_" + node.outputs[0].dtype dtype = "npy_" + node.outputs[0].dtype
return """%(z)s = return """{z} =
(%(dtype)s) GammaQ(%(k)s, %(x)s);""" % locals() ({dtype}) GammaQ({k}, {x});""".format(**locals())
raise NotImplementedError("only floatingpoint is implemented") raise NotImplementedError("only floatingpoint is implemented")
def __eq__(self, other): def __eq__(self, other):
...@@ -1018,8 +1018,8 @@ class GammaU(BinaryScalarOp): ...@@ -1018,8 +1018,8 @@ class GammaU(BinaryScalarOp):
(z,) = out (z,) = out
if node.inputs[0].type in float_types: if node.inputs[0].type in float_types:
dtype = "npy_" + node.outputs[0].dtype dtype = "npy_" + node.outputs[0].dtype
return """%(z)s = return """{z} =
(%(dtype)s) upperGamma(%(k)s, %(x)s);""" % locals() ({dtype}) upperGamma({k}, {x});""".format(**locals())
raise NotImplementedError("only floatingpoint is implemented") raise NotImplementedError("only floatingpoint is implemented")
def __eq__(self, other): def __eq__(self, other):
...@@ -1056,8 +1056,8 @@ class GammaL(BinaryScalarOp): ...@@ -1056,8 +1056,8 @@ class GammaL(BinaryScalarOp):
(z,) = out (z,) = out
if node.inputs[0].type in float_types: if node.inputs[0].type in float_types:
dtype = "npy_" + node.outputs[0].dtype dtype = "npy_" + node.outputs[0].dtype
return """%(z)s = return """{z} =
(%(dtype)s) lowerGamma(%(k)s, %(x)s);""" % locals() ({dtype}) lowerGamma({k}, {x});""".format(**locals())
raise NotImplementedError("only floatingpoint is implemented") raise NotImplementedError("only floatingpoint is implemented")
def __eq__(self, other): def __eq__(self, other):
......
...@@ -3364,8 +3364,7 @@ def profile_printer( ...@@ -3364,8 +3364,7 @@ def profile_printer(
total_scan_fct_time += scan_fct_time total_scan_fct_time += scan_fct_time
total_scan_op_time += scan_op_time total_scan_op_time += scan_op_time
print( print(
" %5.1fs %5.1fs %5.1fs %5.1f%% %5.1f%%" " {:5.1f}s {:5.1f}s {:5.1f}s {:5.1f}% {:5.1f}%".format(
% (
v, v,
scan_fct_time, scan_fct_time,
scan_op_time, scan_op_time,
...@@ -3385,8 +3384,7 @@ def profile_printer( ...@@ -3385,8 +3384,7 @@ def profile_printer(
print(" No scan have its inner profile enabled.", file=file) print(" No scan have its inner profile enabled.", file=file)
else: else:
print( print(
"total %5.1fs %5.1fs %5.1fs %5.1f%% %5.1f%%" "total {:5.1f}s {:5.1f}s {:5.1f}s {:5.1f}% {:5.1f}%".format(
% (
total_super_scan_time, total_super_scan_time,
total_scan_fct_time, total_scan_fct_time,
total_scan_op_time, total_scan_op_time,
......
差异被折叠。
差异被折叠。
...@@ -1840,19 +1840,19 @@ class BatchedDot(COp): ...@@ -1840,19 +1840,19 @@ class BatchedDot(COp):
z_shape = ", ".join(z_dims) z_shape = ", ".join(z_dims)
z_contiguous = contiguous(_z, z_ndim) z_contiguous = contiguous(_z, z_ndim)
allocate = """ allocate = """
if (NULL == %(_z)s || !(%(z_shape_correct)s) || !(%(z_contiguous)s)) if (NULL == {_z} || !({z_shape_correct}) || !({z_contiguous}))
{ {{
npy_intp dims[%(z_ndim)s] = {%(z_shape)s}; npy_intp dims[{z_ndim}] = {{{z_shape}}};
Py_XDECREF(%(_z)s); Py_XDECREF({_z});
%(_z)s = (PyArrayObject*)PyArray_SimpleNew( {_z} = (PyArrayObject*)PyArray_SimpleNew(
%(z_ndim)s, dims, PyArray_TYPE(%(_x)s)); {z_ndim}, dims, PyArray_TYPE({_x}));
if(!%(_z)s) { if(!{_z}) {{
PyErr_SetString(PyExc_MemoryError, PyErr_SetString(PyExc_MemoryError,
"failed to alloc BatchedDot output"); "failed to alloc BatchedDot output");
%(fail)s {fail}
} }}
} }}
""" % locals() """.format(**locals())
# code to reallocate inputs contiguously if necessary # code to reallocate inputs contiguously if necessary
contiguate = [] contiguate = []
...@@ -1860,76 +1860,75 @@ class BatchedDot(COp): ...@@ -1860,76 +1860,75 @@ class BatchedDot(COp):
_contiguous = contiguous(var, ndim) _contiguous = contiguous(var, ndim)
contiguate.append( contiguate.append(
""" """
if (!(%(_contiguous)s)) { if (!({_contiguous})) {{
PyArrayObject * _copy = (PyArrayObject *) PyArray_Copy(%(var)s); PyArrayObject * _copy = (PyArrayObject *) PyArray_Copy({var});
if (!_copy) if (!_copy)
%(fail)s {fail}
Py_XDECREF(%(var)s); Py_XDECREF({var});
%(var)s = _copy; {var} = _copy;
} }}
""" """.format(**locals())
% locals()
) )
contiguate = "\n".join(contiguate) contiguate = "\n".join(contiguate)
return """ return """
int type_num = PyArray_DESCR(%(_x)s)->type_num; int type_num = PyArray_DESCR({_x})->type_num;
int type_size = PyArray_DESCR(%(_x)s)->elsize; // in bytes int type_size = PyArray_DESCR({_x})->elsize; // in bytes
if (PyArray_NDIM(%(_x)s) != 3) { if (PyArray_NDIM({_x}) != 3) {{
PyErr_Format(PyExc_NotImplementedError, PyErr_Format(PyExc_NotImplementedError,
"rank(x) != 3. rank(x) is %%d.", "rank(x) != 3. rank(x) is %d.",
PyArray_NDIM(%(_x)s)); PyArray_NDIM({_x}));
%(fail)s; {fail};
} }}
if (PyArray_NDIM(%(_y)s) != 3) { if (PyArray_NDIM({_y}) != 3) {{
PyErr_Format(PyExc_NotImplementedError, PyErr_Format(PyExc_NotImplementedError,
"rank(y) != 3. rank(y) is %%d.", "rank(y) != 3. rank(y) is %d.",
PyArray_NDIM(%(_y)s)); PyArray_NDIM({_y}));
%(fail)s; {fail};
} }}
if (%(_z)s && PyArray_NDIM(%(_z)s) != 3) { if ({_z} && PyArray_NDIM({_z}) != 3) {{
PyErr_Format(PyExc_NotImplementedError, PyErr_Format(PyExc_NotImplementedError,
"rank(z) != 3. rank(z) is %%d.", "rank(z) != 3. rank(z) is %d.",
PyArray_NDIM(%(_z)s)); PyArray_NDIM({_z}));
%(fail)s; {fail};
} }}
// allocate output // allocate output
%(allocate)s {allocate}
// reallocate any noncontiguous arrays or arrays with invalid strides // reallocate any noncontiguous arrays or arrays with invalid strides
%(contiguate)s {contiguate}
if ((PyArray_DESCR(%(_x)s)->type_num != NPY_DOUBLE) if ((PyArray_DESCR({_x})->type_num != NPY_DOUBLE)
&& (PyArray_DESCR(%(_x)s)->type_num != NPY_FLOAT)) && (PyArray_DESCR({_x})->type_num != NPY_FLOAT))
{PyErr_SetString(PyExc_NotImplementedError, "type(x) is not double or float"); %(fail)s;} {{PyErr_SetString(PyExc_NotImplementedError, "type(x) is not double or float"); {fail};}}
if ((PyArray_DESCR(%(_y)s)->type_num != NPY_DOUBLE) if ((PyArray_DESCR({_y})->type_num != NPY_DOUBLE)
&& (PyArray_DESCR(%(_y)s)->type_num != NPY_FLOAT)) && (PyArray_DESCR({_y})->type_num != NPY_FLOAT))
{PyErr_SetString(PyExc_NotImplementedError, "type(y) is not double or float"); %(fail)s;} {{PyErr_SetString(PyExc_NotImplementedError, "type(y) is not double or float"); {fail};}}
if ((PyArray_DESCR(%(_z)s)->type_num != NPY_DOUBLE) if ((PyArray_DESCR({_z})->type_num != NPY_DOUBLE)
&& (PyArray_DESCR(%(_z)s)->type_num != NPY_FLOAT)) && (PyArray_DESCR({_z})->type_num != NPY_FLOAT))
{PyErr_SetString(PyExc_NotImplementedError, "type(z) is not double or float"); %(fail)s;} {{PyErr_SetString(PyExc_NotImplementedError, "type(z) is not double or float"); {fail};}}
if ((PyArray_DESCR(%(_x)s)->type_num != PyArray_DESCR(%(_y)s)->type_num) if ((PyArray_DESCR({_x})->type_num != PyArray_DESCR({_y})->type_num)
||(PyArray_DESCR(%(_x)s)->type_num != PyArray_DESCR(%(_z)s)->type_num)) ||(PyArray_DESCR({_x})->type_num != PyArray_DESCR({_z})->type_num))
{ PyErr_SetString(PyExc_NotImplementedError, "type(x), type(y), type(z) are not all the same"); %(fail)s; } {{ PyErr_SetString(PyExc_NotImplementedError, "type(x), type(y), type(z) are not all the same"); {fail}; }}
switch (type_num) switch (type_num)
{ {{
case NPY_FLOAT: case NPY_FLOAT:
if (batch_gemm<float>(sgemm_, type_size, %(_x)s, %(_y)s, %(_z)s)) { if (batch_gemm<float>(sgemm_, type_size, {_x}, {_y}, {_z})) {{
%(fail)s; {fail};
} }}
break; break;
case NPY_DOUBLE: case NPY_DOUBLE:
if (batch_gemm<double>(dgemm_, type_size, %(_x)s, %(_y)s, %(_z)s)) { if (batch_gemm<double>(dgemm_, type_size, {_x}, {_y}, {_z})) {{
%(fail)s; {fail};
} }}
break; break;
} }}
""" % locals() """.format(**locals())
def c_code_cache_version(self): def c_code_cache_version(self):
from pytensor.tensor.blas_headers import blas_header_version from pytensor.tensor.blas_headers import blas_header_version
......
差异被折叠。
...@@ -1097,7 +1097,7 @@ def ____gemm_code(check_ab, a_init, b_init): ...@@ -1097,7 +1097,7 @@ def ____gemm_code(check_ab, a_init, b_init):
if (PyArray_NDIM(_y) != 2) goto _dot_execute_fallback; if (PyArray_NDIM(_y) != 2) goto _dot_execute_fallback;
if (PyArray_NDIM(_z) != 2) goto _dot_execute_fallback; if (PyArray_NDIM(_z) != 2) goto _dot_execute_fallback;
%(check_ab)s {check_ab}
if ((PyArray_DESCR(_x)->type_num != NPY_DOUBLE) if ((PyArray_DESCR(_x)->type_num != NPY_DOUBLE)
&& (PyArray_DESCR(_x)->type_num != NPY_FLOAT)) && (PyArray_DESCR(_x)->type_num != NPY_FLOAT))
...@@ -1117,16 +1117,16 @@ def ____gemm_code(check_ab, a_init, b_init): ...@@ -1117,16 +1117,16 @@ def ____gemm_code(check_ab, a_init, b_init):
if ((Nx[0] != Nz[0]) || (Nx[1] != Ny[0]) || (Ny[1] != Nz[1])) if ((Nx[0] != Nz[0]) || (Nx[1] != Ny[0]) || (Ny[1] != Nz[1]))
{ {{
error_string = "Input dimensions do not agree"; error_string = "Input dimensions do not agree";
goto _dot_execute_fail; goto _dot_execute_fail;
} }}
if ((Sx[0] < 1) || (Sx[1] < 1) || (Sx[0] %(mod)s type_size) || (Sx[1] %(mod)s type_size) if ((Sx[0] < 1) || (Sx[1] < 1) || (Sx[0] {mod} type_size) || (Sx[1] {mod} type_size)
|| (Sy[0] < 1) || (Sy[1] < 1) || (Sy[0] %(mod)s type_size) || (Sy[1] %(mod)s type_size) || (Sy[0] < 1) || (Sy[1] < 1) || (Sy[0] {mod} type_size) || (Sy[1] {mod} type_size)
|| (Sz[0] < 1) || (Sz[1] < 1) || (Sz[0] %(mod)s type_size) || (Sz[1] %(mod)s type_size)) || (Sz[0] < 1) || (Sz[1] < 1) || (Sz[0] {mod} type_size) || (Sz[1] {mod} type_size))
{ {{
goto _dot_execute_fallback; goto _dot_execute_fallback;
} }}
/* /*
encode the stride structure of _x,_y,_z into a single integer encode the stride structure of _x,_y,_z into a single integer
...@@ -1146,19 +1146,19 @@ def ____gemm_code(check_ab, a_init, b_init): ...@@ -1146,19 +1146,19 @@ def ____gemm_code(check_ab, a_init, b_init):
sz_1 = (Nz[1] > 1) ? Sz[1]/type_size : Nz[0]; sz_1 = (Nz[1] > 1) ? Sz[1]/type_size : Nz[0];
switch (type_num) switch (type_num)
{ {{
case NPY_FLOAT: case NPY_FLOAT:
{ {{
#define REAL float #define REAL float
float a = %(a_init)s; float a = {a_init};
float b = %(b_init)s; float b = {b_init};
float* x = (float*)PyArray_DATA(_x); float* x = (float*)PyArray_DATA(_x);
float* y = (float*)PyArray_DATA(_y); float* y = (float*)PyArray_DATA(_y);
float* z = (float*)PyArray_DATA(_z); float* z = (float*)PyArray_DATA(_z);
switch(unit) switch(unit)
{ {{
case 0x000: cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasNoTrans, Nz[0], Nz[1], Nx[1], a, x, sx_0, y, sy_0, b, z, sz_0); break; case 0x000: cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasNoTrans, Nz[0], Nz[1], Nx[1], a, x, sx_0, y, sy_0, b, z, sz_0); break;
case 0x001: cblas_sgemm(CblasRowMajor, CblasTrans, CblasNoTrans, Nz[0], Nz[1], Nx[1], a, x, sx_1, y, sy_0, b, z, sz_0); break; case 0x001: cblas_sgemm(CblasRowMajor, CblasTrans, CblasNoTrans, Nz[0], Nz[1], Nx[1], a, x, sx_1, y, sy_0, b, z, sz_0); break;
case 0x010: cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans, Nz[0], Nz[1], Nx[1], a, x, sx_0, y, sy_1, b, z, sz_0); break; case 0x010: cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans, Nz[0], Nz[1], Nx[1], a, x, sx_0, y, sy_1, b, z, sz_0); break;
...@@ -1168,21 +1168,21 @@ def ____gemm_code(check_ab, a_init, b_init): ...@@ -1168,21 +1168,21 @@ def ____gemm_code(check_ab, a_init, b_init):
case 0x110: cblas_sgemm(CblasColMajor, CblasTrans, CblasNoTrans, Nz[0], Nz[1], Nx[1], a, x, sx_0, y, sy_1, b, z, sz_1); break; case 0x110: cblas_sgemm(CblasColMajor, CblasTrans, CblasNoTrans, Nz[0], Nz[1], Nx[1], a, x, sx_0, y, sy_1, b, z, sz_1); break;
case 0x111: cblas_sgemm(CblasColMajor, CblasNoTrans, CblasNoTrans, Nz[0], Nz[1], Nx[1], a, x, sx_1, y, sy_1, b, z, sz_1); break; case 0x111: cblas_sgemm(CblasColMajor, CblasNoTrans, CblasNoTrans, Nz[0], Nz[1], Nx[1], a, x, sx_1, y, sy_1, b, z, sz_1); break;
default: goto _dot_execute_fallback; default: goto _dot_execute_fallback;
}; }};
#undef REAL #undef REAL
} }}
break; break;
case NPY_DOUBLE: case NPY_DOUBLE:
{ {{
#define REAL double #define REAL double
double a = %(a_init)s; double a = {a_init};
double b = %(b_init)s; double b = {b_init};
double* x = (double*)PyArray_DATA(_x); double* x = (double*)PyArray_DATA(_x);
double* y = (double*)PyArray_DATA(_y); double* y = (double*)PyArray_DATA(_y);
double* z = (double*)PyArray_DATA(_z); double* z = (double*)PyArray_DATA(_z);
switch(unit) switch(unit)
{ {{
case 0x000: cblas_dgemm(CblasRowMajor, CblasNoTrans, CblasNoTrans, Nz[0], Nz[1], Nx[1], a, x, sx_0, y, sy_0, b, z, sz_0); break; case 0x000: cblas_dgemm(CblasRowMajor, CblasNoTrans, CblasNoTrans, Nz[0], Nz[1], Nx[1], a, x, sx_0, y, sy_0, b, z, sz_0); break;
case 0x001: cblas_dgemm(CblasRowMajor, CblasTrans, CblasNoTrans, Nz[0], Nz[1], Nx[1], a, x, sx_1, y, sy_0, b, z, sz_0); break; case 0x001: cblas_dgemm(CblasRowMajor, CblasTrans, CblasNoTrans, Nz[0], Nz[1], Nx[1], a, x, sx_1, y, sy_0, b, z, sz_0); break;
case 0x010: cblas_dgemm(CblasRowMajor, CblasNoTrans, CblasTrans, Nz[0], Nz[1], Nx[1], a, x, sx_0, y, sy_1, b, z, sz_0); break; case 0x010: cblas_dgemm(CblasRowMajor, CblasNoTrans, CblasTrans, Nz[0], Nz[1], Nx[1], a, x, sx_0, y, sy_1, b, z, sz_0); break;
...@@ -1192,11 +1192,11 @@ def ____gemm_code(check_ab, a_init, b_init): ...@@ -1192,11 +1192,11 @@ def ____gemm_code(check_ab, a_init, b_init):
case 0x110: cblas_dgemm(CblasColMajor, CblasTrans, CblasNoTrans, Nz[0], Nz[1], Nx[1], a, x, sx_0, y, sy_1, b, z, sz_1); break; case 0x110: cblas_dgemm(CblasColMajor, CblasTrans, CblasNoTrans, Nz[0], Nz[1], Nx[1], a, x, sx_0, y, sy_1, b, z, sz_1); break;
case 0x111: cblas_dgemm(CblasColMajor, CblasNoTrans, CblasNoTrans, Nz[0], Nz[1], Nx[1], a, x, sx_1, y, sy_1, b, z, sz_1); break; case 0x111: cblas_dgemm(CblasColMajor, CblasNoTrans, CblasNoTrans, Nz[0], Nz[1], Nx[1], a, x, sx_1, y, sy_1, b, z, sz_1); break;
default: goto _dot_execute_fallback; default: goto _dot_execute_fallback;
}; }};
#undef REAL #undef REAL
} }}
break; break;
} }}
return 0; //success! return 0; //success!
...@@ -1212,4 +1212,4 @@ def ____gemm_code(check_ab, a_init, b_init): ...@@ -1212,4 +1212,4 @@ def ____gemm_code(check_ab, a_init, b_init):
return -1; return -1;
/* v 1 */ /* v 1 */
""" % locals() """.format(**locals())
...@@ -929,12 +929,12 @@ class Elemwise(OpenMPOp): ...@@ -929,12 +929,12 @@ class Elemwise(OpenMPOp):
# decrease the reference of whatever the output contained # decrease the reference of whatever the output contained
# prior to this # prior to this
alloc += """ alloc += """
if (%(oname)s) { if ({oname}) {{
Py_XDECREF(%(oname)s); Py_XDECREF({oname});
} }}
%(oname)s = %(iname)s; {oname} = {iname};
Py_XINCREF(%(oname)s); Py_XINCREF({oname});
""" % locals() """.format(**locals())
# We alias the scalar variables # We alias the scalar variables
defines += f"#define {oname}_i {iname}_i\n" defines += f"#define {oname}_i {iname}_i\n"
undefs += f"#undef {oname}_i\n" undefs += f"#undef {oname}_i\n"
...@@ -958,12 +958,12 @@ class Elemwise(OpenMPOp): ...@@ -958,12 +958,12 @@ class Elemwise(OpenMPOp):
dict(sub, fail=fail), dict(sub, fail=fail),
) )
code = """ code = """
{ {{
%(defines)s {defines}
%(task_code)s {task_code}
%(undefs)s {undefs}
} }}
""" % locals() """.format(**locals())
loop_orders = orders + [list(range(nnested))] * len(real_onames) loop_orders = orders + [list(range(nnested))] * len(real_onames)
dtypes = idtypes + list(real_odtypes) dtypes = idtypes + list(real_odtypes)
...@@ -995,27 +995,27 @@ class Elemwise(OpenMPOp): ...@@ -995,27 +995,27 @@ class Elemwise(OpenMPOp):
if index != "x": if index != "x":
preloops.setdefault(j, "") preloops.setdefault(j, "")
preloops[j] += ( preloops[j] += (
"%%(lv%(i)s)s_iter = (%(dtype)s*)" "%(lv{i})s_iter = ({dtype}*)"
"(PyArray_DATA(%%(lv%(i)s)s));\n" % locals() "(PyArray_DATA(%(lv{i})s));\n".format(**locals())
) % sub ) % sub
break break
else: # all broadcastable else: # all broadcastable
preloops.setdefault(0, "") preloops.setdefault(0, "")
preloops[0] += ( preloops[0] += (
"%%(lv%(i)s)s_iter = (%(dtype)s*)" "%(lv{i})s_iter = ({dtype}*)"
"(PyArray_DATA(%%(lv%(i)s)s));\n" % locals() "(PyArray_DATA(%(lv{i})s));\n".format(**locals())
) % sub ) % sub
init_array = preloops.get(0, " ") init_array = preloops.get(0, " ")
loop = """ loop = """
{ {{
%(defines)s {defines}
%(init_array)s {init_array}
%(task_decl)s {task_decl}
%(task_code)s {task_code}
%(undefs)s {undefs}
} }}
""" % locals() """.format(**locals())
else: else:
loop = cgen.make_loop( loop = cgen.make_loop(
loop_orders=loop_orders, loop_orders=loop_orders,
...@@ -1076,24 +1076,24 @@ class Elemwise(OpenMPOp): ...@@ -1076,24 +1076,24 @@ class Elemwise(OpenMPOp):
for x, var in zip(inames + onames, inputs + node.outputs): for x, var in zip(inames + onames, inputs + node.outputs):
if not all(s == 1 for s in var.type.shape): if not all(s == 1 for s in var.type.shape):
contig += """ contig += """
dtype_%(x)s * %(x)s_ptr = (dtype_%(x)s*) PyArray_DATA(%(x)s); dtype_{x} * {x}_ptr = (dtype_{x}*) PyArray_DATA({x});
""" % locals() """.format(**locals())
index += """ index += """
dtype_%(x)s& %(x)s_i = %(x)s_ptr[i]; dtype_{x}& {x}_i = {x}_ptr[i];
""" % locals() """.format(**locals())
else: else:
contig += """ contig += """
dtype_%(x)s& %(x)s_i = ((dtype_%(x)s*) PyArray_DATA(%(x)s))[0]; dtype_{x}& {x}_i = ((dtype_{x}*) PyArray_DATA({x}))[0];
""" % locals() """.format(**locals())
if self.openmp: if self.openmp:
contig += f"""#pragma omp parallel for if(n>={int(config.openmp_elemwise_minsize)}) contig += f"""#pragma omp parallel for if(n>={int(config.openmp_elemwise_minsize)})
""" """
contig += """ contig += """
for(int i=0; i<n; i++){ for(int i=0; i<n; i++){{
%(index)s {index}
%(task_code)s; {task_code};
} }}
""" % locals() """.format(**locals())
if contig is not None: if contig is not None:
z = list(zip(inames + onames, inputs + node.outputs)) z = list(zip(inames + onames, inputs + node.outputs))
all_broadcastable = all(s == 1 for s in var.type.shape) all_broadcastable = all(s == 1 for s in var.type.shape)
...@@ -1112,12 +1112,12 @@ class Elemwise(OpenMPOp): ...@@ -1112,12 +1112,12 @@ class Elemwise(OpenMPOp):
] ]
) )
loop = """ loop = """
if((%(cond1)s) || (%(cond2)s)){ if(({cond1}) || ({cond2})){{
%(contig)s {contig}
}else{ }}else{{
%(loop)s {loop}
} }}
""" % locals() """.format(**locals())
return decl, checks, alloc, loop, "" return decl, checks, alloc, loop, ""
def c_code(self, node, nodename, inames, onames, sub): def c_code(self, node, nodename, inames, onames, sub):
......
...@@ -176,34 +176,34 @@ def make_alloc(loop_orders, dtype, sub, fortran="0"): ...@@ -176,34 +176,34 @@ def make_alloc(loop_orders, dtype, sub, fortran="0"):
# contiguous dimensions, or the dimension with the smallest # contiguous dimensions, or the dimension with the smallest
# stride. Right now, it is allocated to be C_CONTIGUOUS. # stride. Right now, it is allocated to be C_CONTIGUOUS.
return """ return """
{ {{
npy_intp dims[%(nd)s]; npy_intp dims[{nd}];
//npy_intp* dims = (npy_intp*)malloc(%(nd)s * sizeof(npy_intp)); //npy_intp* dims = (npy_intp*)malloc({nd} * sizeof(npy_intp));
%(init_dims)s {init_dims}
if (!%(olv)s) { if (!{olv}) {{
%(olv)s = (PyArrayObject*)PyArray_EMPTY(%(nd)s, dims, {olv} = (PyArrayObject*)PyArray_EMPTY({nd}, dims,
%(type)s, {type},
%(fortran)s); {fortran});
} }}
else { else {{
PyArray_Dims new_dims; PyArray_Dims new_dims;
new_dims.len = %(nd)s; new_dims.len = {nd};
new_dims.ptr = dims; new_dims.ptr = dims;
PyObject* success = PyArray_Resize(%(olv)s, &new_dims, 0, NPY_CORDER); PyObject* success = PyArray_Resize({olv}, &new_dims, 0, NPY_CORDER);
if (!success) { if (!success) {{
// If we can't resize the ndarray we have we can allocate a new one. // If we can't resize the ndarray we have we can allocate a new one.
PyErr_Clear(); PyErr_Clear();
Py_XDECREF(%(olv)s); Py_XDECREF({olv});
%(olv)s = (PyArrayObject*)PyArray_EMPTY(%(nd)s, dims, %(type)s, 0); {olv} = (PyArrayObject*)PyArray_EMPTY({nd}, dims, {type}, 0);
} else { }} else {{
Py_DECREF(success); Py_DECREF(success);
} }}
} }}
if (!%(olv)s) { if (!{olv}) {{
%(fail)s {fail}
} }}
} }}
""" % dict(locals(), **sub) """.format(**dict(locals(), **sub))
def make_loop(loop_orders, dtypes, loop_tasks, sub, openmp=None): def make_loop(loop_orders, dtypes, loop_tasks, sub, openmp=None):
......
...@@ -69,24 +69,24 @@ class CpuContiguous(COp): ...@@ -69,24 +69,24 @@ class CpuContiguous(COp):
(x,) = inames (x,) = inames
(y,) = onames (y,) = onames
code = """ code = """
if (!PyArray_CHKFLAGS(%(x)s, NPY_ARRAY_C_CONTIGUOUS)){ if (!PyArray_CHKFLAGS({x}, NPY_ARRAY_C_CONTIGUOUS)){{
// check to see if output is contiguous first // check to see if output is contiguous first
if (%(y)s != NULL && if ({y} != NULL &&
PyArray_CompareLists(PyArray_DIMS(%(y)s), PyArray_DIMS(%(x)s), PyArray_NDIM(%(x)s)) && PyArray_CompareLists(PyArray_DIMS({y}), PyArray_DIMS({x}), PyArray_NDIM({x})) &&
PyArray_CHKFLAGS(%(y)s, NPY_ARRAY_C_CONTIGUOUS)){ PyArray_CHKFLAGS({y}, NPY_ARRAY_C_CONTIGUOUS)){{
PyArray_CopyInto(%(y)s, %(x)s); PyArray_CopyInto({y}, {x});
} }}
else{ else{{
Py_XDECREF(%(y)s); Py_XDECREF({y});
%(y)s = PyArray_GETCONTIGUOUS(%(x)s); {y} = PyArray_GETCONTIGUOUS({x});
} }}
} }}
else{ else{{
Py_XINCREF(%(x)s); Py_XINCREF({x});
Py_XDECREF(%(y)s); Py_XDECREF({y});
%(y)s = %(x)s; {y} = {x};
} }}
""" % locals() """.format(**locals())
return code return code
def c_code_cache_version(self): def c_code_cache_version(self):
...@@ -162,12 +162,12 @@ class SearchsortedOp(COp): ...@@ -162,12 +162,12 @@ class SearchsortedOp(COp):
side = sub["params"] side = sub["params"]
fail = sub["fail"] fail = sub["fail"]
return """ return """
PyObject* tmp_%(name)s = PyUnicode_FromString("right"); PyObject* tmp_{name} = PyUnicode_FromString("right");
if (tmp_%(name)s == NULL) if (tmp_{name} == NULL)
%(fail)s; {fail};
right_%(name)s = PyUnicode_Compare(%(side)s, tmp_%(name)s); right_{name} = PyUnicode_Compare({side}, tmp_{name});
Py_DECREF(tmp_%(name)s); Py_DECREF(tmp_{name});
""" % locals() """.format(**locals())
def c_code(self, node, name, inames, onames, sub): def c_code(self, node, name, inames, onames, sub):
sorter = None sorter = None
...@@ -181,17 +181,17 @@ class SearchsortedOp(COp): ...@@ -181,17 +181,17 @@ class SearchsortedOp(COp):
fail = sub["fail"] fail = sub["fail"]
return """ return """
Py_XDECREF(%(z)s); Py_XDECREF({z});
%(z)s = (PyArrayObject*) PyArray_SearchSorted(%(x)s, (PyObject*) %(v)s, {z} = (PyArrayObject*) PyArray_SearchSorted({x}, (PyObject*) {v},
right_%(name)s ? NPY_SEARCHLEFT : NPY_SEARCHRIGHT, (PyObject*) %(sorter)s); right_{name} ? NPY_SEARCHLEFT : NPY_SEARCHRIGHT, (PyObject*) {sorter});
if (!%(z)s) if (!{z})
%(fail)s; {fail};
if (PyArray_TYPE(%(z)s) != NPY_INT64){ if (PyArray_TYPE({z}) != NPY_INT64){{
PyObject * tmp = PyArray_Cast(%(z)s, NPY_INT64); PyObject * tmp = PyArray_Cast({z}, NPY_INT64);
Py_XDECREF(%(z)s); Py_XDECREF({z});
%(z)s = (PyArrayObject*) tmp; {z} = (PyArrayObject*) tmp;
} }}
""" % locals() """.format(**locals())
def c_code_cache_version(self): def c_code_cache_version(self):
return (2,) return (2,)
...@@ -351,43 +351,43 @@ class CumOp(COp): ...@@ -351,43 +351,43 @@ class CumOp(COp):
params = sub["params"] params = sub["params"]
code = """ code = """
int axis = %(params)s->c_axis; int axis = {params}->c_axis;
if (axis == 0 && PyArray_NDIM(%(x)s) == 1) if (axis == 0 && PyArray_NDIM({x}) == 1)
axis = NPY_MAXDIMS; axis = NPY_MAXDIMS;
npy_intp shape[1] = { PyArray_SIZE(%(x)s) }; npy_intp shape[1] = {{ PyArray_SIZE({x}) }};
if(axis == NPY_MAXDIMS && !(%(z)s && PyArray_DIMS(%(z)s)[0] == shape[0])) if(axis == NPY_MAXDIMS && !({z} && PyArray_DIMS({z})[0] == shape[0]))
{ {{
Py_XDECREF(%(z)s); Py_XDECREF({z});
%(z)s = (PyArrayObject*) PyArray_SimpleNew(1, shape, PyArray_TYPE((PyArrayObject*) py_%(x)s)); {z} = (PyArrayObject*) PyArray_SimpleNew(1, shape, PyArray_TYPE((PyArrayObject*) py_{x}));
} }}
else if(axis != NPY_MAXDIMS && !(%(z)s && PyArray_CompareLists(PyArray_DIMS(%(z)s), PyArray_DIMS(%(x)s), PyArray_NDIM(%(x)s)))) else if(axis != NPY_MAXDIMS && !({z} && PyArray_CompareLists(PyArray_DIMS({z}), PyArray_DIMS({x}), PyArray_NDIM({x}))))
{ {{
Py_XDECREF(%(z)s); Py_XDECREF({z});
%(z)s = (PyArrayObject*) PyArray_SimpleNew(PyArray_NDIM(%(x)s), PyArray_DIMS(%(x)s), PyArray_TYPE(%(x)s)); {z} = (PyArrayObject*) PyArray_SimpleNew(PyArray_NDIM({x}), PyArray_DIMS({x}), PyArray_TYPE({x}));
} }}
if (!%(z)s) if (!{z})
%(fail)s; {fail};
{ {{
PyObject * t = NULL; PyObject * t = NULL;
if(%(params)s->mode == MODE_ADD) if({params}->mode == MODE_ADD)
t = PyArray_CumSum( t = PyArray_CumSum(
%(x)s, axis, {x}, axis,
PyArray_TYPE(%(x)s), %(z)s); PyArray_TYPE({x}), {z});
else if(%(params)s->mode == MODE_MUL) else if({params}->mode == MODE_MUL)
t = PyArray_CumProd( t = PyArray_CumProd(
%(x)s, axis, {x}, axis,
PyArray_TYPE(%(x)s), %(z)s); PyArray_TYPE({x}), {z});
if (!t){ if (!t){{
%(fail)s; {fail};
} }}
// Because PyArray_CumSum/CumProd returns a newly created reference on t. // Because PyArray_CumSum/CumProd returns a newly created reference on t.
Py_XDECREF(t); Py_XDECREF(t);
} }}
""" % locals() """.format(**locals())
return code return code
......
...@@ -420,13 +420,13 @@ class Argmax(COp): ...@@ -420,13 +420,13 @@ class Argmax(COp):
raise NotImplementedError() raise NotImplementedError()
# params is only used here for now # params is only used here for now
axis_code = """ axis_code = """
axis = %(params)s->c_axis; axis = {params}->c_axis;
if(axis > PyArray_NDIM(%(x)s)-1 || axis < -PyArray_NDIM(%(x)s)){ if(axis > PyArray_NDIM({x})-1 || axis < -PyArray_NDIM({x})){{
PyErr_SetString(PyExc_ValueError, PyErr_SetString(PyExc_ValueError,
"Argmax, bad axis argument"); "Argmax, bad axis argument");
%(fail)s {fail}
} }}
""" % locals() """.format(**locals())
ret = """ ret = """
int axis; int axis;
......
...@@ -476,47 +476,47 @@ class TensorType(CType[np.ndarray], HasDataType, HasShape): ...@@ -476,47 +476,47 @@ class TensorType(CType[np.ndarray], HasDataType, HasShape):
def c_declare(self, name, sub, check_input=True): def c_declare(self, name, sub, check_input=True):
if check_input: if check_input:
check = """ check = """
typedef %(dtype)s dtype_%(name)s; typedef {dtype} dtype_{name};
""" % dict(sub, name=name, dtype=self.dtype_specs()[1]) """.format(**dict(sub, name=name, dtype=self.dtype_specs()[1]))
else: else:
check = "" check = ""
declaration = """ declaration = """
PyArrayObject* %(name)s; PyArrayObject* {name};
""" % dict(sub, name=name, dtype=self.dtype_specs()[1]) """.format(**dict(sub, name=name, dtype=self.dtype_specs()[1]))
return declaration + check return declaration + check
def c_init(self, name, sub): def c_init(self, name, sub):
return """ return """
%(name)s = NULL; {name} = NULL;
""" % dict(sub, name=name, type_num=self.dtype_specs()[2]) """.format(**dict(sub, name=name, type_num=self.dtype_specs()[2]))
def c_extract(self, name, sub, check_input=True, **kwargs): def c_extract(self, name, sub, check_input=True, **kwargs):
if check_input: if check_input:
check = """ check = """
%(name)s = NULL; {name} = NULL;
if (py_%(name)s == Py_None) { if (py_{name} == Py_None) {{
// We can either fail here or set %(name)s to NULL and rely on Ops // We can either fail here or set {name} to NULL and rely on Ops
// using tensors to handle the NULL case, but if they fail to do so // using tensors to handle the NULL case, but if they fail to do so
// they'll end up with nasty segfaults, so this is public service. // they'll end up with nasty segfaults, so this is public service.
PyErr_SetString(PyExc_ValueError, "expected an ndarray, not None"); PyErr_SetString(PyExc_ValueError, "expected an ndarray, not None");
%(fail)s {fail}
} }}
if (!PyArray_Check(py_%(name)s)) { if (!PyArray_Check(py_{name})) {{
PyErr_SetString(PyExc_ValueError, "expected an ndarray"); PyErr_SetString(PyExc_ValueError, "expected an ndarray");
%(fail)s {fail}
} }}
// We expect %(type_num)s // We expect {type_num}
if (!PyArray_ISALIGNED((PyArrayObject*) py_%(name)s)) { if (!PyArray_ISALIGNED((PyArrayObject*) py_{name})) {{
PyArrayObject * tmp = (PyArrayObject*) py_%(name)s; PyArrayObject * tmp = (PyArrayObject*) py_{name};
PyErr_Format(PyExc_NotImplementedError, PyErr_Format(PyExc_NotImplementedError,
"expected an aligned array of type %%ld " "expected an aligned array of type %ld "
"(%(type_num)s), got non-aligned array of type %%ld" "({type_num}), got non-aligned array of type %ld"
" with %%ld dimensions, with 3 last dims " " with %ld dimensions, with 3 last dims "
"%%ld, %%ld, %%ld" "%ld, %ld, %ld"
" and 3 last strides %%ld %%ld, %%ld.", " and 3 last strides %ld %ld, %ld.",
(long int) %(type_num)s, (long int) {type_num},
(long int) PyArray_TYPE((PyArrayObject*) py_%(name)s), (long int) PyArray_TYPE((PyArrayObject*) py_{name}),
(long int) PyArray_NDIM(tmp), (long int) PyArray_NDIM(tmp),
(long int) (PyArray_NDIM(tmp) >= 3 ? (long int) (PyArray_NDIM(tmp) >= 3 ?
PyArray_DIMS(tmp)[PyArray_NDIM(tmp)-3] : -1), PyArray_DIMS(tmp)[PyArray_NDIM(tmp)-3] : -1),
...@@ -531,74 +531,73 @@ class TensorType(CType[np.ndarray], HasDataType, HasShape): ...@@ -531,74 +531,73 @@ class TensorType(CType[np.ndarray], HasDataType, HasShape):
(long int) (PyArray_NDIM(tmp) >= 1 ? (long int) (PyArray_NDIM(tmp) >= 1 ?
PyArray_STRIDES(tmp)[PyArray_NDIM(tmp)-1] : -1) PyArray_STRIDES(tmp)[PyArray_NDIM(tmp)-1] : -1)
); );
%(fail)s {fail}
} }}
// This is a TypeError to be consistent with DEBUG_MODE // This is a TypeError to be consistent with DEBUG_MODE
// Note: DEBUG_MODE also tells the name of the container // Note: DEBUG_MODE also tells the name of the container
if (PyArray_TYPE((PyArrayObject*) py_%(name)s) != %(type_num)s) { if (PyArray_TYPE((PyArrayObject*) py_{name}) != {type_num}) {{
PyErr_Format(PyExc_TypeError, PyErr_Format(PyExc_TypeError,
"expected type_num %%d (%(type_num)s) got %%d", "expected type_num %d ({type_num}) got %d",
%(type_num)s, PyArray_TYPE((PyArrayObject*) py_%(name)s)); {type_num}, PyArray_TYPE((PyArrayObject*) py_{name}));
%(fail)s {fail}
} }}
""" % dict(sub, name=name, type_num=self.dtype_specs()[2]) """.format(**dict(sub, name=name, type_num=self.dtype_specs()[2]))
else: else:
check = "" check = ""
return ( return (
check check
+ """ + """
%(name)s = (PyArrayObject*)(py_%(name)s); {name} = (PyArrayObject*)(py_{name});
Py_XINCREF(%(name)s); Py_XINCREF({name});
""" """.format(**dict(sub, name=name, type_num=self.dtype_specs()[2]))
% dict(sub, name=name, type_num=self.dtype_specs()[2])
) )
def c_cleanup(self, name, sub): def c_cleanup(self, name, sub):
return """ return """
if (%(name)s) { if ({name}) {{
Py_XDECREF(%(name)s); Py_XDECREF({name});
} }}
""" % locals() """.format(**locals())
def c_sync(self, name, sub): def c_sync(self, name, sub):
fail = sub["fail"] fail = sub["fail"]
type_num = self.dtype_specs()[2] type_num = self.dtype_specs()[2]
return """ return """
{Py_XDECREF(py_%(name)s);} {{Py_XDECREF(py_{name});}}
if (!%(name)s) { if (!{name}) {{
Py_INCREF(Py_None); Py_INCREF(Py_None);
py_%(name)s = Py_None; py_{name} = Py_None;
} }}
else if ((void*)py_%(name)s != (void*)%(name)s) { else if ((void*)py_{name} != (void*){name}) {{
py_%(name)s = (PyObject*)%(name)s; py_{name} = (PyObject*){name};
} }}
{Py_XINCREF(py_%(name)s);} {{Py_XINCREF(py_{name});}}
if (%(name)s && !PyArray_ISALIGNED((PyArrayObject*) py_%(name)s)) { if ({name} && !PyArray_ISALIGNED((PyArrayObject*) py_{name})) {{
PyErr_Format(PyExc_NotImplementedError, PyErr_Format(PyExc_NotImplementedError,
"c_sync: expected an aligned array, got non-aligned array of type %%ld" "c_sync: expected an aligned array, got non-aligned array of type %ld"
" with %%ld dimensions, with 3 last dims " " with %ld dimensions, with 3 last dims "
"%%ld, %%ld, %%ld" "%ld, %ld, %ld"
" and 3 last strides %%ld %%ld, %%ld.", " and 3 last strides %ld %ld, %ld.",
(long int) PyArray_TYPE((PyArrayObject*) py_%(name)s), (long int) PyArray_TYPE((PyArrayObject*) py_{name}),
(long int) PyArray_NDIM(%(name)s), (long int) PyArray_NDIM({name}),
(long int) (PyArray_NDIM(%(name)s) >= 3 ? (long int) (PyArray_NDIM({name}) >= 3 ?
PyArray_DIMS(%(name)s)[PyArray_NDIM(%(name)s)-3] : -1), PyArray_DIMS({name})[PyArray_NDIM({name})-3] : -1),
(long int) (PyArray_NDIM(%(name)s) >= 2 ? (long int) (PyArray_NDIM({name}) >= 2 ?
PyArray_DIMS(%(name)s)[PyArray_NDIM(%(name)s)-2] : -1), PyArray_DIMS({name})[PyArray_NDIM({name})-2] : -1),
(long int) (PyArray_NDIM(%(name)s) >= 1 ? (long int) (PyArray_NDIM({name}) >= 1 ?
PyArray_DIMS(%(name)s)[PyArray_NDIM(%(name)s)-1] : -1), PyArray_DIMS({name})[PyArray_NDIM({name})-1] : -1),
(long int) (PyArray_NDIM(%(name)s) >= 3 ? (long int) (PyArray_NDIM({name}) >= 3 ?
PyArray_STRIDES(%(name)s)[PyArray_NDIM(%(name)s)-3] : -1), PyArray_STRIDES({name})[PyArray_NDIM({name})-3] : -1),
(long int) (PyArray_NDIM(%(name)s) >= 2 ? (long int) (PyArray_NDIM({name}) >= 2 ?
PyArray_STRIDES(%(name)s)[PyArray_NDIM(%(name)s)-2] : -1), PyArray_STRIDES({name})[PyArray_NDIM({name})-2] : -1),
(long int) (PyArray_NDIM(%(name)s) >= 1 ? (long int) (PyArray_NDIM({name}) >= 1 ?
PyArray_STRIDES(%(name)s)[PyArray_NDIM(%(name)s)-1] : -1) PyArray_STRIDES({name})[PyArray_NDIM({name})-1] : -1)
); );
%(fail)s {fail}
} }}
""" % locals() """.format(**locals())
def c_headers(self, **kwargs): def c_headers(self, **kwargs):
return ps.get_scalar_type(self.dtype).c_headers(**kwargs) return ps.get_scalar_type(self.dtype).c_headers(**kwargs)
......
...@@ -103,12 +103,12 @@ class GetItem(COp): ...@@ -103,12 +103,12 @@ class GetItem(COp):
output_name = out[0] output_name = out[0]
fail = sub["fail"] fail = sub["fail"]
return """ return """
%(output_name)s = (typeof %(output_name)s) PyList_GetItem( (PyObject*) %(x_name)s, *((npy_int64 *) PyArray_DATA(%(index)s))); {output_name} = (typeof {output_name}) PyList_GetItem( (PyObject*) {x_name}, *((npy_int64 *) PyArray_DATA({index})));
if(%(output_name)s == NULL){ if({output_name} == NULL){{
%(fail)s {fail}
} }}
Py_INCREF(%(output_name)s); Py_INCREF({output_name});
""" % locals() """.format(**locals())
def c_code_cache_version(self): def c_code_cache_version(self):
return (1,) return (1,)
...@@ -170,8 +170,8 @@ class Append(COp): ...@@ -170,8 +170,8 @@ class Append(COp):
fail = sub["fail"] fail = sub["fail"]
if not self.inplace: if not self.inplace:
init = """ init = """
%(output_name)s = (PyListObject*) PyList_GetSlice((PyObject*) %(x_name)s, 0, PyList_GET_SIZE((PyObject*) %(x_name)s)) ; {output_name} = (PyListObject*) PyList_GetSlice((PyObject*) {x_name}, 0, PyList_GET_SIZE((PyObject*) {x_name})) ;
""" % locals() """.format(**locals())
else: else:
init = f""" init = f"""
{output_name} = {x_name}; {output_name} = {x_name};
...@@ -179,15 +179,14 @@ class Append(COp): ...@@ -179,15 +179,14 @@ class Append(COp):
return ( return (
init init
+ """ + """
if(%(output_name)s==NULL){ if({output_name}==NULL){{
%(fail)s {fail}
}; }};
if(PyList_Append( (PyObject*) %(output_name)s,(PyObject*) %(toAppend)s)){ if(PyList_Append( (PyObject*) {output_name},(PyObject*) {toAppend})){{
%(fail)s {fail}
}; }};
Py_INCREF(%(output_name)s); Py_INCREF({output_name});
""" """.format(**locals())
% locals()
) )
def c_code_cache_version(self): def c_code_cache_version(self):
...@@ -252,8 +251,8 @@ class Extend(COp): ...@@ -252,8 +251,8 @@ class Extend(COp):
fail = sub["fail"] fail = sub["fail"]
if not self.inplace: if not self.inplace:
init = """ init = """
%(output_name)s = (PyListObject*) PyList_GetSlice((PyObject*) %(x_name)s, 0, PyList_GET_SIZE((PyObject*) %(x_name)s)) ; {output_name} = (PyListObject*) PyList_GetSlice((PyObject*) {x_name}, 0, PyList_GET_SIZE((PyObject*) {x_name})) ;
""" % locals() """.format(**locals())
else: else:
init = f""" init = f"""
{output_name} = {x_name}; {output_name} = {x_name};
...@@ -262,18 +261,17 @@ class Extend(COp): ...@@ -262,18 +261,17 @@ class Extend(COp):
init init
+ """ + """
int i =0; int i =0;
int length = PyList_GET_SIZE((PyObject*) %(toAppend)s); int length = PyList_GET_SIZE((PyObject*) {toAppend});
if(%(output_name)s==NULL){ if({output_name}==NULL){{
%(fail)s {fail}
}; }};
for(i; i < length; i++){ for(i; i < length; i++){{
if(PyList_Append( (PyObject*) %(output_name)s,(PyObject*) PyList_GetItem((PyObject*) %(toAppend)s,i))==-1){ if(PyList_Append( (PyObject*) {output_name},(PyObject*) PyList_GetItem((PyObject*) {toAppend},i))==-1){{
%(fail)s {fail}
}; }};
} }}
Py_INCREF(%(output_name)s); Py_INCREF({output_name});
""" """.format(**locals())
% locals()
) )
def c_code_cache_version_(self): def c_code_cache_version_(self):
...@@ -341,8 +339,8 @@ class Insert(COp): ...@@ -341,8 +339,8 @@ class Insert(COp):
fail = sub["fail"] fail = sub["fail"]
if not self.inplace: if not self.inplace:
init = """ init = """
%(output_name)s = (PyListObject*) PyList_GetSlice((PyObject*) %(x_name)s, 0, PyList_GET_SIZE((PyObject*) %(x_name)s)) ; {output_name} = (PyListObject*) PyList_GetSlice((PyObject*) {x_name}, 0, PyList_GET_SIZE((PyObject*) {x_name})) ;
""" % locals() """.format(**locals())
else: else:
init = f""" init = f"""
{output_name} = {x_name}; {output_name} = {x_name};
...@@ -350,15 +348,14 @@ class Insert(COp): ...@@ -350,15 +348,14 @@ class Insert(COp):
return ( return (
init init
+ """ + """
if(%(output_name)s==NULL){ if({output_name}==NULL){{
%(fail)s {fail}
}; }};
if(PyList_Insert((PyObject*) %(output_name)s, *((npy_int64 *) PyArray_DATA(%(index)s)), (PyObject*) %(toInsert)s)==-1){ if(PyList_Insert((PyObject*) {output_name}, *((npy_int64 *) PyArray_DATA({index})), (PyObject*) {toInsert})==-1){{
%(fail)s {fail}
}; }};
Py_INCREF(%(output_name)s); Py_INCREF({output_name});
""" """.format(**locals())
% locals()
) )
def c_code_cache_version(self): def c_code_cache_version(self):
...@@ -470,8 +467,8 @@ class Reverse(COp): ...@@ -470,8 +467,8 @@ class Reverse(COp):
fail = sub["fail"] fail = sub["fail"]
if not self.inplace: if not self.inplace:
init = """ init = """
%(output_name)s = (PyListObject*) PyList_GetSlice((PyObject*) %(x_name)s, 0, PyList_GET_SIZE((PyObject*) %(x_name)s)) ; {output_name} = (PyListObject*) PyList_GetSlice((PyObject*) {x_name}, 0, PyList_GET_SIZE((PyObject*) {x_name})) ;
""" % locals() """.format(**locals())
else: else:
init = f""" init = f"""
{output_name} = {x_name}; {output_name} = {x_name};
...@@ -479,15 +476,14 @@ class Reverse(COp): ...@@ -479,15 +476,14 @@ class Reverse(COp):
return ( return (
init init
+ """ + """
if(%(output_name)s==NULL){ if({output_name}==NULL){{
%(fail)s {fail}
}; }};
if(PyList_Reverse((PyObject*) %(output_name)s)==-1){ if(PyList_Reverse((PyObject*) {output_name})==-1){{
%(fail)s {fail}
}; }};
Py_INCREF(%(output_name)s); Py_INCREF({output_name});
""" """.format(**locals())
% locals()
) )
def c_code_cache_version(self): def c_code_cache_version(self):
...@@ -602,11 +598,11 @@ class Length(COp): ...@@ -602,11 +598,11 @@ class Length(COp):
output_name = out[0] output_name = out[0]
fail = sub["fail"] fail = sub["fail"]
return """ return """
if(!%(output_name)s) if(!{output_name})
%(output_name)s=(PyArrayObject*)PyArray_EMPTY(0, NULL, NPY_INT64, 0); {output_name}=(PyArrayObject*)PyArray_EMPTY(0, NULL, NPY_INT64, 0);
((npy_int64*)PyArray_DATA(%(output_name)s))[0]=PyList_Size((PyObject*)%(x_name)s); ((npy_int64*)PyArray_DATA({output_name}))[0]=PyList_Size((PyObject*){x_name});
Py_INCREF(%(output_name)s); Py_INCREF({output_name});
""" % locals() """.format(**locals())
def c_code_cache_version(self): def c_code_cache_version(self):
return (1,) return (1,)
......
...@@ -111,26 +111,25 @@ class TypedListType(CType): ...@@ -111,26 +111,25 @@ class TypedListType(CType):
def c_extract(self, name, sub, check_input=True, **kwargs): def c_extract(self, name, sub, check_input=True, **kwargs):
if check_input: if check_input:
pre = """ pre = """
if (!PyList_Check(py_%(name)s)) { if (!PyList_Check(py_{name})) {{
PyErr_SetString(PyExc_TypeError, "expected a list"); PyErr_SetString(PyExc_TypeError, "expected a list");
%(fail)s {fail}
}""" % dict(name=name, fail=sub["fail"]) }}""".format(**dict(name=name, fail=sub["fail"]))
else: else:
pre = "" pre = ""
return ( return (
pre pre
+ """ + """
%(name)s = (PyListObject*) (py_%(name)s); {name} = (PyListObject*) (py_{name});
""" """.format(**dict(name=name, fail=sub["fail"]))
% dict(name=name, fail=sub["fail"])
) )
def c_sync(self, name, sub): def c_sync(self, name, sub):
return """ return """
Py_XDECREF(py_%(name)s); Py_XDECREF(py_{name});
py_%(name)s = (PyObject*)(%(name)s); py_{name} = (PyObject*)({name});
Py_INCREF(py_%(name)s); Py_INCREF(py_{name});
""" % dict(name=name) """.format(**dict(name=name))
def c_cleanup(self, name, sub): def c_cleanup(self, name, sub):
return "" return ""
......
...@@ -65,38 +65,38 @@ class BROKEN_ON_PURPOSE_Add(COp): ...@@ -65,38 +65,38 @@ class BROKEN_ON_PURPOSE_Add(COp):
a, b = inp a, b = inp
(z,) = out (z,) = out
return """ return """
if (PyArray_NDIM(%(a)s) != 1) {PyErr_SetString(PyExc_NotImplementedError, "rank(a) != 1"); %(fail)s;} if (PyArray_NDIM({a}) != 1) {{PyErr_SetString(PyExc_NotImplementedError, "rank(a) != 1"); {fail};}}
if (PyArray_NDIM(%(b)s) != 1) {PyErr_SetString(PyExc_NotImplementedError, "rank(b) != 1"); %(fail)s;} if (PyArray_NDIM({b}) != 1) {{PyErr_SetString(PyExc_NotImplementedError, "rank(b) != 1"); {fail};}}
if (PyArray_DESCR(%(a)s)->type_num != NPY_DOUBLE) if (PyArray_DESCR({a})->type_num != NPY_DOUBLE)
{PyErr_SetString(PyExc_NotImplementedError, "a dtype not NPY_DOUBLE"); %(fail)s;} {{PyErr_SetString(PyExc_NotImplementedError, "a dtype not NPY_DOUBLE"); {fail};}}
if (PyArray_DESCR(%(b)s)->type_num != NPY_DOUBLE) if (PyArray_DESCR({b})->type_num != NPY_DOUBLE)
{PyErr_SetString(PyExc_NotImplementedError, "b's dtype not NPY_DOUBLE"); %(fail)s;} {{PyErr_SetString(PyExc_NotImplementedError, "b's dtype not NPY_DOUBLE"); {fail};}}
if (PyArray_DIMS(%(a)s)[0] != PyArray_DIMS(%(b)s)[0]) if (PyArray_DIMS({a})[0] != PyArray_DIMS({b})[0])
{PyErr_SetString(PyExc_NotImplementedError, "a and b have different lengths"); %(fail)s;} {{PyErr_SetString(PyExc_NotImplementedError, "a and b have different lengths"); {fail};}}
if ((!%(z)s) if ((!{z})
|| (PyArray_DIMS(%(z)s)[0] != PyArray_DIMS(%(b)s)[0]) || (PyArray_DIMS({z})[0] != PyArray_DIMS({b})[0])
) )
{ {{
{Py_XDECREF(%(z)s);} {{Py_XDECREF({z});}}
npy_intp dims[] = {0}; npy_intp dims[] = {{0}};
dims[0] = PyArray_DIMS(%(b)s)[0]; dims[0] = PyArray_DIMS({b})[0];
%(z)s = (PyArrayObject*) PyArray_SimpleNew(1, dims, PyArray_DESCR(%(b)s)->type_num); {z} = (PyArrayObject*) PyArray_SimpleNew(1, dims, PyArray_DESCR({b})->type_num);
} }}
{ {{
for (npy_intp m = 0; m < PyArray_DIMS(%(z)s)[0]; ++m) for (npy_intp m = 0; m < PyArray_DIMS({z})[0]; ++m)
{ {{
((double*)PyArray_GETPTR1(%(z)s, m))[0] ((double*)PyArray_GETPTR1({z}, m))[0]
= 0.5 = 0.5
+ ((double*)PyArray_GETPTR1(%(a)s, m))[0] + ((double*)PyArray_GETPTR1({a}, m))[0]
+ ((double*)PyArray_GETPTR1(%(b)s, m))[0] ; + ((double*)PyArray_GETPTR1({b}, m))[0] ;
} }}
} }}
""" % dict(locals(), **sub) """.format(**dict(locals(), **sub))
# inconsistent is a invalid op, whose perform and c_code do not match # inconsistent is a invalid op, whose perform and c_code do not match
...@@ -634,63 +634,63 @@ class BrokenCImplementationAdd(COp): ...@@ -634,63 +634,63 @@ class BrokenCImplementationAdd(COp):
debug = 0 debug = 0
return """ return """
//printf("executing c_code\\n"); //printf("executing c_code\\n");
if (PyArray_NDIM(%(a)s) != 2) {PyErr_SetString(PyExc_NotImplementedError, "rank(a) != 2"); %(fail)s;} if (PyArray_NDIM({a}) != 2) {{PyErr_SetString(PyExc_NotImplementedError, "rank(a) != 2"); {fail};}}
if (PyArray_NDIM(%(b)s) != 2) {PyErr_SetString(PyExc_NotImplementedError, "rank(b) != 2"); %(fail)s;} if (PyArray_NDIM({b}) != 2) {{PyErr_SetString(PyExc_NotImplementedError, "rank(b) != 2"); {fail};}}
if (PyArray_DESCR(%(a)s)->type_num != NPY_FLOAT) if (PyArray_DESCR({a})->type_num != NPY_FLOAT)
{PyErr_SetString(PyExc_NotImplementedError, "a dtype not NPY_FLOAT"); %(fail)s;} {{PyErr_SetString(PyExc_NotImplementedError, "a dtype not NPY_FLOAT"); {fail};}}
if (PyArray_DESCR(%(b)s)->type_num != NPY_FLOAT) if (PyArray_DESCR({b})->type_num != NPY_FLOAT)
{PyErr_SetString(PyExc_NotImplementedError, "b's dtype not NPY_FLOAT"); %(fail)s;} {{PyErr_SetString(PyExc_NotImplementedError, "b's dtype not NPY_FLOAT"); {fail};}}
if (PyArray_DIMS(%(a)s)[0] != PyArray_DIMS(%(a)s)[1]) if (PyArray_DIMS({a})[0] != PyArray_DIMS({a})[1])
{PyErr_SetString(PyExc_NotImplementedError, "a is not square"); %(fail)s;} {{PyErr_SetString(PyExc_NotImplementedError, "a is not square"); {fail};}}
if (PyArray_DIMS(%(b)s)[0] != PyArray_DIMS(%(b)s)[1]) if (PyArray_DIMS({b})[0] != PyArray_DIMS({b})[1])
{PyErr_SetString(PyExc_NotImplementedError, "b is not square"); %(fail)s;} {{PyErr_SetString(PyExc_NotImplementedError, "b is not square"); {fail};}}
if (PyArray_DIMS(%(a)s)[0] != PyArray_DIMS(%(b)s)[0]) if (PyArray_DIMS({a})[0] != PyArray_DIMS({b})[0])
{PyErr_SetString(PyExc_NotImplementedError, "a and b have different dimensions"); %(fail)s;} {{PyErr_SetString(PyExc_NotImplementedError, "a and b have different dimensions"); {fail};}}
// We do not check for c_contiguous property here // We do not check for c_contiguous property here
if (%(debug)s) if ({debug})
{ {{
if (!%(z)s) if (!{z})
printf("%(z)s is not there, %%p \\n", %(z)s); printf("{z} is not there, %p \\n", {z});
else if (PyArray_DIMS(%(z)s)[0] != PyArray_DIMS(%(b)s)[0]) else if (PyArray_DIMS({z})[0] != PyArray_DIMS({b})[0])
printf("Dimension 0 mismatch for %(z)s and %(b)s\\n"); printf("Dimension 0 mismatch for {z} and {b}\\n");
else if (PyArray_DIMS(%(z)s)[1] != PyArray_DIMS(%(b)s)[1]) else if (PyArray_DIMS({z})[1] != PyArray_DIMS({b})[1])
printf("Dimension 1 mismatch for %(z)s and %(b)s\\n"); printf("Dimension 1 mismatch for {z} and {b}\\n");
else else
printf("Reusing %(z)s\\n"); printf("Reusing {z}\\n");
} }}
if ((!%(z)s) if ((!{z})
|| (PyArray_DIMS(%(z)s)[0] != PyArray_DIMS(%(b)s)[0]) || (PyArray_DIMS({z})[0] != PyArray_DIMS({b})[0])
|| (PyArray_DIMS(%(z)s)[1] != PyArray_DIMS(%(b)s)[1]) || (PyArray_DIMS({z})[1] != PyArray_DIMS({b})[1])
) )
{ {{
Py_XDECREF(%(z)s); Py_XDECREF({z});
npy_intp dims[] = {0, 0}; npy_intp dims[] = {{0, 0}};
dims[0] = PyArray_DIMS(%(b)s)[0]; dims[0] = PyArray_DIMS({b})[0];
dims[1] = PyArray_DIMS(%(b)s)[1]; dims[1] = PyArray_DIMS({b})[1];
%(z)s = (PyArrayObject*) PyArray_SimpleNew(2, dims, PyArray_DESCR(%(b)s)->type_num); {z} = (PyArrayObject*) PyArray_SimpleNew(2, dims, PyArray_DESCR({b})->type_num);
} }}
// Let us assume that %(z)s is c_contiguous // Let us assume that {z} is c_contiguous
{ {{
dtype_%(z)s * z = ((dtype_%(z)s*)(PyArray_GETPTR2(%(z)s,0,0))); dtype_{z} * z = ((dtype_{z}*)(PyArray_GETPTR2({z},0,0)));
for (int i=0; i<PyArray_DIMS(%(b)s)[0]; i++) for (int i=0; i<PyArray_DIMS({b})[0]; i++)
{ {{
for (int j=0; j<PyArray_DIMS(%(b)s)[1]; j++) for (int j=0; j<PyArray_DIMS({b})[1]; j++)
{ {{
*z = ((float*)PyArray_GETPTR2(%(a)s, i, j))[0] + *z = ((float*)PyArray_GETPTR2({a}, i, j))[0] +
((float*)PyArray_GETPTR2(%(b)s, i, j))[0] ; ((float*)PyArray_GETPTR2({b}, i, j))[0] ;
z++; z++;
} }}
} }}
} }}
""" % dict(locals(), **sub) """.format(**dict(locals(), **sub))
class VecAsRowAndCol(Op): class VecAsRowAndCol(Op):
......
差异被折叠。
差异被折叠。
差异被折叠。
差异被折叠。
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论