提交 06c5acdf authored 作者: Virgile Andreani's avatar Virgile Andreani 提交者: Ricardo Vieira

Fix UP031 with `ruff --unsafe-fixes`

上级 cada5ad2
......@@ -130,7 +130,7 @@ disable = ["C0330", "C0326"]
[tool.ruff]
select = ["C", "E", "F", "I", "UP", "W"]
ignore = ["C408", "C901", "E501", "E741", "UP031"]
ignore = ["C408", "C901", "E501", "E741"]
exclude = ["doc/", "pytensor/_version.py", "bin/pytensor_cache.py"]
......
......@@ -465,7 +465,7 @@ class OpFromGraph(Op, HasInnerGraph):
def __str__(self):
name = self.__class__.__name__ if self.name is None else self.name
is_inline = self.is_inline
return "%(name)s{inline=%(is_inline)s}" % locals()
return "{name}{{inline={is_inline}}}".format(**locals())
@config.change_flags(compute_test_value="off")
def _recompute_lop_op(self):
......
......@@ -250,33 +250,30 @@ def struct_gen(args, struct_builders, blocks, sub):
# that holds the type, the value and the traceback. After storing
# the error, we return the failure code so we know which code
# block failed.
do_return = (
"""
if (%(failure_var)s) {
do_return = """
if ({failure_var}) {{
// When there is a failure, this code puts the exception
// in __ERROR.
PyObject* err_type = NULL;
PyObject* err_msg = NULL;
PyObject* err_traceback = NULL;
PyErr_Fetch(&err_type, &err_msg, &err_traceback);
if (!err_type) {err_type = Py_None;Py_INCREF(Py_None);}
if (!err_msg) {err_msg = Py_None; Py_INCREF(Py_None);}
if (!err_traceback) {err_traceback = Py_None; Py_INCREF(Py_None);}
if (!err_type) {{err_type = Py_None;Py_INCREF(Py_None);}}
if (!err_msg) {{err_msg = Py_None; Py_INCREF(Py_None);}}
if (!err_traceback) {{err_traceback = Py_None; Py_INCREF(Py_None);}}
PyObject* old_err_type = PyList_GET_ITEM(__ERROR, 0);
PyObject* old_err_msg = PyList_GET_ITEM(__ERROR, 1);
PyObject* old_err_traceback = PyList_GET_ITEM(__ERROR, 2);
PyList_SET_ITEM(__ERROR, 0, err_type);
PyList_SET_ITEM(__ERROR, 1, err_msg);
PyList_SET_ITEM(__ERROR, 2, err_traceback);
{Py_XDECREF(old_err_type);}
{Py_XDECREF(old_err_msg);}
{Py_XDECREF(old_err_traceback);}
}
{{Py_XDECREF(old_err_type);}}
{{Py_XDECREF(old_err_msg);}}
{{Py_XDECREF(old_err_traceback);}}
}}
// The failure code is returned to index what code block failed.
return %(failure_var)s;
"""
% sub
)
return {failure_var};
""".format(**sub)
sub = dict(sub)
sub.update(locals())
......@@ -284,16 +281,15 @@ def struct_gen(args, struct_builders, blocks, sub):
# TODO: add some error checking to make sure storage_<x> are
# 1-element lists and __ERROR is a 3-elements list.
struct_code = (
"""
namespace {
struct %(name)s {
struct_code = """
namespace {{
struct {name} {{
PyObject* __ERROR;
%(storage_decl)s
%(struct_decl)s
{storage_decl}
{struct_decl}
%(name)s() {
{name}() {{
// This is only somewhat safe because we:
// 1) Are not a virtual class
// 2) Do not use any virtual classes in the members
......@@ -306,32 +302,30 @@ def struct_gen(args, struct_builders, blocks, sub):
#ifndef PYTENSOR_DONT_MEMSET_STRUCT
memset(this, 0, sizeof(*this));
#endif
}
~%(name)s(void) {
}}
~{name}(void) {{
cleanup();
}
}}
int init(PyObject* __ERROR, %(args_decl)s) {
%(storage_incref)s
%(storage_set)s
%(struct_init_head)s
int init(PyObject* __ERROR, {args_decl}) {{
{storage_incref}
{storage_set}
{struct_init_head}
this->__ERROR = __ERROR;
return 0;
}
void cleanup(void) {
%(struct_cleanup)s
%(storage_decref)s
}
int run(void) {
int %(failure_var)s = 0;
%(behavior)s
%(do_return)s
}
};
}
"""
% sub
)
}}
void cleanup(void) {{
{struct_cleanup}
{storage_decref}
}}
int run(void) {{
int {failure_var} = 0;
{behavior}
{do_return}
}}
}};
}}
""".format(**sub)
return struct_code
......@@ -380,9 +374,9 @@ def get_c_init(fgraph, r, name, sub):
pre = (
""
"""
py_%(name)s = Py_None;
{Py_XINCREF(py_%(name)s);}
""" % locals()
py_{name} = Py_None;
{{Py_XINCREF(py_{name});}}
""".format(**locals())
)
return pre + r.type.c_init(name, sub)
......@@ -418,9 +412,9 @@ def get_c_extract(fgraph, r, name, sub):
c_extract = r.type.c_extract(name, sub, False)
pre = """
py_%(name)s = PyList_GET_ITEM(storage_%(name)s, 0);
{Py_XINCREF(py_%(name)s);}
""" % locals()
py_{name} = PyList_GET_ITEM(storage_{name}, 0);
{{Py_XINCREF(py_{name});}}
""".format(**locals())
return pre + c_extract
......@@ -447,9 +441,9 @@ def get_c_extract_out(fgraph, r, name, sub):
c_extract = r.type.c_extract_out(name, sub, check_input, check_broadcast=False)
pre = """
py_%(name)s = PyList_GET_ITEM(storage_%(name)s, 0);
{Py_XINCREF(py_%(name)s);}
""" % locals()
py_{name} = PyList_GET_ITEM(storage_{name}, 0);
{{Py_XINCREF(py_{name});}}
""".format(**locals())
return pre + c_extract
......@@ -459,8 +453,8 @@ def get_c_cleanup(fgraph, r, name, sub):
"""
post = """
{Py_XDECREF(py_%(name)s);}
""" % locals()
{{Py_XDECREF(py_{name});}}
""".format(**locals())
return r.type.c_cleanup(name, sub) + post
......@@ -470,14 +464,14 @@ def get_c_sync(fgraph, r, name, sub):
"""
return """
if (!%(failure_var)s) {
%(sync)s
PyObject* old = PyList_GET_ITEM(storage_%(name)s, 0);
{Py_XINCREF(py_%(name)s);}
PyList_SET_ITEM(storage_%(name)s, 0, py_%(name)s);
{Py_XDECREF(old);}
}
""" % dict(sync=r.type.c_sync(name, sub), name=name, **sub)
if (!{failure_var}) {{
{sync}
PyObject* old = PyList_GET_ITEM(storage_{name}, 0);
{{Py_XINCREF(py_{name});}}
PyList_SET_ITEM(storage_{name}, 0, py_{name});
{{Py_XDECREF(old);}}
}}
""".format(**dict(sync=r.type.c_sync(name, sub), name=name, **sub))
def apply_policy(fgraph, policy, r, name, sub):
......
......@@ -1950,14 +1950,13 @@ class Compiler:
code = (
"""
%(preamble)s
{preamble}
int main(int argc, char** argv)
{
%(body)s
{{
{body}
return 0;
}
"""
% locals()
}}
""".format(**locals())
).encode()
return cls._try_compile_tmp(
code,
......
......@@ -558,18 +558,20 @@ class CLinkerType(CLinkerObject):
"""
return """
if (py_%(name)s == Py_None)
{
%(c_init_code)s
}
if (py_{name} == Py_None)
{{
{c_init_code}
}}
else
{
%(c_extract_code)s
}
""" % dict(
name=name,
c_init_code=self.c_init(name, sub),
c_extract_code=self.c_extract(name, sub, check_input),
{{
{c_extract_code}
}}
""".format(
**dict(
name=name,
c_init_code=self.c_init(name, sub),
c_extract_code=self.c_extract(name, sub, check_input),
)
)
def c_cleanup(self, name: str, sub: dict[str, str]) -> str:
......
......@@ -596,20 +596,22 @@ class ExternalCOp(COp):
# Generate the C code
return """
%(define_macros)s
{
if (%(func_name)s(%(func_args)s%(params)s) != 0) {
%(fail)s
}
}
%(undef_macros)s
""" % dict(
func_name=self.func_name,
fail=sub["fail"],
params=params,
func_args=self.format_c_function_args(inp, out),
define_macros=define_macros,
undef_macros=undef_macros,
{define_macros}
{{
if ({func_name}({func_args}{params}) != 0) {{
{fail}
}}
}}
{undef_macros}
""".format(
**dict(
func_name=self.func_name,
fail=sub["fail"],
params=params,
func_args=self.format_c_function_args(inp, out),
define_macros=define_macros,
undef_macros=undef_macros,
)
)
else:
if "code" in self.code_sections:
......
......@@ -359,8 +359,7 @@ class ParamsType(CType):
type_name = type_instance.__class__.__name__
if not isinstance(type_instance, CType):
raise TypeError(
'ParamsType: attribute "%s" should inherit from PyTensor CType, got "%s".'
% (attribute_name, type_name)
f'ParamsType: attribute "{attribute_name}" should inherit from PyTensor CType, got "{type_name}".'
)
self.length = len(kwargs)
......@@ -723,15 +722,11 @@ class ParamsType(CType):
c_cleanup_list.append(type_instance.c_cleanup(attribute_name, sub))
c_extract_list.append(
"""
void extract_%(attribute_name)s(PyObject* py_%(attribute_name)s) {
%(extract_code)s
}
f"""
void extract_{attribute_name}(PyObject* py_{attribute_name}) {{
{type_instance.c_extract(attribute_name, sub)}
}}
"""
% {
"attribute_name": attribute_name,
"extract_code": type_instance.c_extract(attribute_name, sub),
}
)
struct_declare = "\n".join(c_declare_list)
......@@ -759,55 +754,57 @@ class ParamsType(CType):
)
)
final_struct_code = """
/** ParamsType %(struct_name)s **/
#ifndef %(struct_name_defined)s
#define %(struct_name_defined)s
struct %(struct_name)s {
/** ParamsType {struct_name} **/
#ifndef {struct_name_defined}
#define {struct_name_defined}
struct {struct_name} {{
/* Attributes, */
int %(struct_name)s_error;
%(struct_declare)s
int {struct_name}_error;
{struct_declare}
/* Constructor. */
%(struct_name)s() {
%(struct_name)s_error = 0;
%(struct_init)s
}
{struct_name}() {{
{struct_name}_error = 0;
{struct_init}
}}
/* Destructor. */
~%(struct_name)s() {
~{struct_name}() {{
// cleanup() is defined below.
cleanup();
}
}}
/* Cleanup method. */
void cleanup() {
%(struct_cleanup)s
}
void cleanup() {{
{struct_cleanup}
}}
/* Extraction methods. */
%(struct_extract)s
{struct_extract}
/* Extract method. */
%(struct_extract_method)s
{struct_extract_method}
/* Other methods. */
void setErrorOccurred() {
++%(struct_name)s_error;
}
int errorOccurred() {
return %(struct_name)s_error;
}
};
void setErrorOccurred() {{
++{struct_name}_error;
}}
int errorOccurred() {{
return {struct_name}_error;
}}
}};
#endif
/** End ParamsType %(struct_name)s **/
""" % dict(
struct_name_defined=struct_name_defined,
struct_name=struct_name,
struct_declare=struct_declare,
struct_init=struct_init,
struct_cleanup=struct_cleanup,
struct_extract=struct_extract,
struct_extract_method=struct_extract_method,
/** End ParamsType {struct_name} **/
""".format(
**dict(
struct_name_defined=struct_name_defined,
struct_name=struct_name,
struct_declare=struct_declare,
struct_init=struct_init,
struct_cleanup=struct_cleanup,
struct_extract=struct_extract,
struct_extract_method=struct_extract_method,
)
)
return sorted(c_support_code_set) + [final_struct_code]
......@@ -822,8 +819,8 @@ class ParamsType(CType):
def c_declare(self, name, sub, check_input=True):
return """
%(struct_name)s* %(name)s;
""" % dict(struct_name=self.name, name=name)
{struct_name}* {name};
""".format(**dict(struct_name=self.name, name=name))
def c_init(self, name, sub):
# NB: It seems c_init() is not called for an op param.
......@@ -841,35 +838,37 @@ class ParamsType(CType):
def c_extract(self, name, sub, check_input=True, **kwargs):
return """
/* Seems c_init() is not called for a op param. So I call `new` here. */
%(name)s = new %(struct_name)s;
{name} = new {struct_name};
{ // This need a separate namespace for Clinker
const char* fields[] = {%(fields_list)s};
if (py_%(name)s == Py_None) {
{{ // This need a separate namespace for Clinker
const char* fields[] = {{{fields_list}}};
if (py_{name} == Py_None) {{
PyErr_SetString(PyExc_ValueError, "ParamsType: expected an object, not None.");
%(fail)s
}
for (int i = 0; i < %(length)s; ++i) {
PyObject* o = PyDict_GetItemString(py_%(name)s, fields[i]);
if (o == NULL) {
PyErr_Format(PyExc_TypeError, "ParamsType: missing expected attribute \\"%%s\\" in object.", fields[i]);
%(fail)s
}
%(name)s->extract(o, i);
if (%(name)s->errorOccurred()) {
{fail}
}}
for (int i = 0; i < {length}; ++i) {{
PyObject* o = PyDict_GetItemString(py_{name}, fields[i]);
if (o == NULL) {{
PyErr_Format(PyExc_TypeError, "ParamsType: missing expected attribute \\"%s\\" in object.", fields[i]);
{fail}
}}
{name}->extract(o, i);
if ({name}->errorOccurred()) {{
/* The extract code from attribute type should have already raised a Python exception,
* so we just print the attribute name in stderr. */
fprintf(stderr, "\\nParamsType: error when extracting value for attribute \\"%%s\\".\\n", fields[i]);
%(fail)s
}
}
}
""" % dict(
name=name,
struct_name=self.name,
length=self.length,
fail=sub["fail"],
fields_list='"%s"' % '", "'.join(self.fields),
fprintf(stderr, "\\nParamsType: error when extracting value for attribute \\"%s\\".\\n", fields[i]);
{fail}
}}
}}
}}
""".format(
**dict(
name=name,
struct_name=self.name,
length=self.length,
fail=sub["fail"],
fields_list='"%s"' % '", "'.join(self.fields),
)
)
def c_sync(self, name, sub):
......
......@@ -99,11 +99,11 @@ class Generic(CType, Singleton):
def c_sync(self, name, sub):
return """
assert(py_%(name)s->ob_refcnt > 1);
Py_DECREF(py_%(name)s);
py_%(name)s = %(name)s ? %(name)s : Py_None;
Py_INCREF(py_%(name)s);
""" % locals()
assert(py_{name}->ob_refcnt > 1);
Py_DECREF(py_{name});
py_{name} = {name} ? {name} : Py_None;
Py_INCREF(py_{name});
""".format(**locals())
def c_code_cache_version(self):
return (1,)
......@@ -191,17 +191,17 @@ class CDataType(CType[D]):
def c_declare(self, name, sub, check_input=True):
return """
%(ctype)s %(name)s;
""" % dict(ctype=self.ctype, name=name)
{ctype} {name};
""".format(**dict(ctype=self.ctype, name=name))
def c_init(self, name, sub):
return f"{name} = NULL;"
def c_extract(self, name, sub, check_input=True, **kwargs):
return """
%(name)s = (%(ctype)s)PyCapsule_GetPointer(py_%(name)s, NULL);
if (%(name)s == NULL) %(fail)s
""" % dict(name=name, ctype=self.ctype, fail=sub["fail"])
{name} = ({ctype})PyCapsule_GetPointer(py_{name}, NULL);
if ({name} == NULL) {fail}
""".format(**dict(name=name, ctype=self.ctype, fail=sub["fail"]))
def c_sync(self, name, sub):
freefunc = self.freefunc
......@@ -576,39 +576,39 @@ class EnumType(CType, dict):
"""
return """
#ifdef DEBUG
int pytensor_enum_to_string_%(cname)s(%(ctype)s in, char* out) {
int pytensor_enum_to_string_{cname}({ctype} in, char* out) {{
int ret = 0;
switch(in) {
%(cases)s
switch(in) {{
{cases}
default:
PyErr_SetString(PyExc_ValueError, "%(classname)s: unknown enum value.");
PyErr_SetString(PyExc_ValueError, "{classname}: unknown enum value.");
ret = -1;
break;
}
}}
return ret;
}
}}
#endif
""" % dict(
cname=self.cname,
ctype=self.ctype,
classname=type(self).__name__,
cases="".join(
"""
case %(name)s: sprintf(out, "%(name)s"); break;
"""
% dict(name=name)
for name in self
),
""".format(
**dict(
cname=self.cname,
ctype=self.ctype,
classname=type(self).__name__,
cases="".join(
"""
case {name}: sprintf(out, "{name}"); break;
""".format(**dict(name=name))
for name in self
),
)
)
def c_support_code(self, **kwargs):
return (
self.pyint_compat_code
+ "".join(
"""
#define %s %s
f"""
#define {k} {str(self[k])}
"""
% (k, str(self[k]))
for k in sorted(self.keys())
)
+ self.c_to_string()
......@@ -625,15 +625,15 @@ class EnumType(CType, dict):
def c_extract(self, name, sub, check_input=True, **kwargs):
return """
if (PyInt_Check(py_%(name)s)) {
%(name)s = (%(ctype)s)PyInt_AsLong(py_%(name)s);
} else {
%(name)s = (%(ctype)s)PyFloat_AsDouble(py_%(name)s);
}
if (PyErr_Occurred()) {
%(fail)s
}
""" % dict(ctype=self.ctype, name=name, fail=sub["fail"])
if (PyInt_Check(py_{name})) {{
{name} = ({ctype})PyInt_AsLong(py_{name});
}} else {{
{name} = ({ctype})PyFloat_AsDouble(py_{name});
}}
if (PyErr_Occurred()) {{
{fail}
}}
""".format(**dict(ctype=self.ctype, name=name, fail=sub["fail"]))
def c_code_cache_version(self):
return (2, self.ctype, self.cname, tuple(self.items()))
......@@ -754,23 +754,25 @@ class CEnumType(EnumList):
# swapped_dict's keys are integers.
return """
switch(PyInt_AsLong(py_%(name)s)) {
%(cases)s
switch(PyInt_AsLong(py_{name})) {{
{cases}
default:
PyErr_SetString(PyExc_ValueError, "CEnumType: invalid value to map to C constants.");
{%(fail)s}
{{{fail}}}
break;
}
""" % dict(
name=name,
cases="".join(
"""
}}
""".format(
**dict(
name=name,
cases="".join(
"""
case %(i)d: %(name)s = %(constant_cname)s; break;
"""
% dict(i=i, name=name, constant_cname=swapped_dict[i])
for i in sorted(swapped_dict.keys())
),
fail=sub["fail"],
% dict(i=i, name=name, constant_cname=swapped_dict[i])
for i in sorted(swapped_dict.keys())
),
fail=sub["fail"],
)
)
def c_code_cache_version(self):
......
......@@ -141,7 +141,7 @@ def _check_scipy_linalg_matrix(a, func_name):
if isinstance(a, types.Optional):
a = a.type
if not isinstance(a, types.Array):
msg = "%s.%s() only supported for array types" % interp
msg = "{}.{}() only supported for array types".format(*interp)
raise numba.TypingError(msg, highlighting=False)
if a.ndim not in [1, 2]:
msg = "%s.%s() only supported on 1d or 2d arrays, found %s." % (
......@@ -149,7 +149,7 @@ def _check_scipy_linalg_matrix(a, func_name):
)
raise numba.TypingError(msg, highlighting=False)
if not isinstance(a.dtype, (types.Float, types.Complex)):
msg = "%s.%s() only supported on " "float and complex arrays." % interp
msg = "{}.{}() only supported on " "float and complex arrays.".format(*interp)
raise numba.TypingError(msg, highlighting=False)
......
......@@ -646,8 +646,7 @@ def _debugprint(
tot_time_percent = (tot_time_dict[node] / profile.fct_call_time) * 100
print(
"%s --> %8.2es %4.1f%% %8.2es %4.1f%%"
% (
"{} --> {:8.2e}s {:4.1f}% {:8.2e}s {:4.1f}%".format(
var_output,
op_time,
op_time_percent,
......
......@@ -466,40 +466,44 @@ class ScalarType(CType, HasDataType, HasShape):
specs = self.dtype_specs()
if check_input:
pre = """
if (!PyObject_TypeCheck(py_%(name)s, &%(pyarr_type)s))
{
if (!PyObject_TypeCheck(py_{name}, &{pyarr_type}))
{{
PyErr_Format(PyExc_ValueError,
"Scalar check failed (%(dtype)s)");
%(fail)s
}
""" % dict(
sub, name=name, dtype=specs[1], pyarr_type="Py%sArrType_Type" % specs[2]
"Scalar check failed ({dtype})");
{fail}
}}
""".format(
**dict(
sub,
name=name,
dtype=specs[1],
pyarr_type="Py%sArrType_Type" % specs[2],
)
)
else:
pre = ""
return (
pre
+ """
PyArray_ScalarAsCtype(py_%(name)s, &%(name)s);
"""
% dict(sub, name=name)
PyArray_ScalarAsCtype(py_{name}, &{name});
""".format(**dict(sub, name=name))
)
def c_sync(self, name, sub):
specs = self.dtype_specs()
return """
Py_XDECREF(py_%(name)s);
py_%(name)s = PyArrayScalar_New(%(cls)s);
if (!py_%(name)s)
{
Py_XDECREF(py_{name});
py_{name} = PyArrayScalar_New({cls});
if (!py_{name})
{{
Py_XINCREF(Py_None);
py_%(name)s = Py_None;
py_{name} = Py_None;
PyErr_Format(PyExc_MemoryError,
"Instantiation of new Python scalar failed (%(dtype)s)");
%(fail)s
}
PyArrayScalar_ASSIGN(py_%(name)s, %(cls)s, %(name)s);
""" % dict(sub, name=name, dtype=specs[1], cls=specs[2])
"Instantiation of new Python scalar failed ({dtype})");
{fail}
}}
PyArrayScalar_ASSIGN(py_{name}, {cls}, {name});
""".format(**dict(sub, name=name, dtype=specs[1], cls=specs[2]))
def c_cleanup(self, name, sub):
return ""
......
......@@ -620,8 +620,8 @@ class Chi2SF(BinaryScalarOp):
(z,) = out
if node.inputs[0].type in float_types:
dtype = "npy_" + node.outputs[0].dtype
return """%(z)s =
(%(dtype)s) 1 - GammaP(%(k)s/2., %(x)s/2.);""" % locals()
return """{z} =
({dtype}) 1 - GammaP({k}/2., {x}/2.);""".format(**locals())
raise NotImplementedError("only floatingpoint is implemented")
def __eq__(self, other):
......@@ -666,8 +666,8 @@ class GammaInc(BinaryScalarOp):
(z,) = out
if node.inputs[0].type in float_types:
dtype = "npy_" + node.outputs[0].dtype
return """%(z)s =
(%(dtype)s) GammaP(%(k)s, %(x)s);""" % locals()
return """{z} =
({dtype}) GammaP({k}, {x});""".format(**locals())
raise NotImplementedError("only floatingpoint is implemented")
def __eq__(self, other):
......@@ -712,8 +712,8 @@ class GammaIncC(BinaryScalarOp):
(z,) = out
if node.inputs[0].type in float_types:
dtype = "npy_" + node.outputs[0].dtype
return """%(z)s =
(%(dtype)s) GammaQ(%(k)s, %(x)s);""" % locals()
return """{z} =
({dtype}) GammaQ({k}, {x});""".format(**locals())
raise NotImplementedError("only floatingpoint is implemented")
def __eq__(self, other):
......@@ -1018,8 +1018,8 @@ class GammaU(BinaryScalarOp):
(z,) = out
if node.inputs[0].type in float_types:
dtype = "npy_" + node.outputs[0].dtype
return """%(z)s =
(%(dtype)s) upperGamma(%(k)s, %(x)s);""" % locals()
return """{z} =
({dtype}) upperGamma({k}, {x});""".format(**locals())
raise NotImplementedError("only floatingpoint is implemented")
def __eq__(self, other):
......@@ -1056,8 +1056,8 @@ class GammaL(BinaryScalarOp):
(z,) = out
if node.inputs[0].type in float_types:
dtype = "npy_" + node.outputs[0].dtype
return """%(z)s =
(%(dtype)s) lowerGamma(%(k)s, %(x)s);""" % locals()
return """{z} =
({dtype}) lowerGamma({k}, {x});""".format(**locals())
raise NotImplementedError("only floatingpoint is implemented")
def __eq__(self, other):
......
......@@ -3364,8 +3364,7 @@ def profile_printer(
total_scan_fct_time += scan_fct_time
total_scan_op_time += scan_op_time
print(
" %5.1fs %5.1fs %5.1fs %5.1f%% %5.1f%%"
% (
" {:5.1f}s {:5.1f}s {:5.1f}s {:5.1f}% {:5.1f}%".format(
v,
scan_fct_time,
scan_op_time,
......@@ -3385,8 +3384,7 @@ def profile_printer(
print(" No scan have its inner profile enabled.", file=file)
else:
print(
"total %5.1fs %5.1fs %5.1fs %5.1f%% %5.1f%%"
% (
"total {:5.1f}s {:5.1f}s {:5.1f}s {:5.1f}% {:5.1f}%".format(
total_super_scan_time,
total_scan_fct_time,
total_scan_op_time,
......
差异被折叠。
差异被折叠。
......@@ -1840,19 +1840,19 @@ class BatchedDot(COp):
z_shape = ", ".join(z_dims)
z_contiguous = contiguous(_z, z_ndim)
allocate = """
if (NULL == %(_z)s || !(%(z_shape_correct)s) || !(%(z_contiguous)s))
{
npy_intp dims[%(z_ndim)s] = {%(z_shape)s};
Py_XDECREF(%(_z)s);
%(_z)s = (PyArrayObject*)PyArray_SimpleNew(
%(z_ndim)s, dims, PyArray_TYPE(%(_x)s));
if(!%(_z)s) {
if (NULL == {_z} || !({z_shape_correct}) || !({z_contiguous}))
{{
npy_intp dims[{z_ndim}] = {{{z_shape}}};
Py_XDECREF({_z});
{_z} = (PyArrayObject*)PyArray_SimpleNew(
{z_ndim}, dims, PyArray_TYPE({_x}));
if(!{_z}) {{
PyErr_SetString(PyExc_MemoryError,
"failed to alloc BatchedDot output");
%(fail)s
}
}
""" % locals()
{fail}
}}
}}
""".format(**locals())
# code to reallocate inputs contiguously if necessary
contiguate = []
......@@ -1860,76 +1860,75 @@ class BatchedDot(COp):
_contiguous = contiguous(var, ndim)
contiguate.append(
"""
if (!(%(_contiguous)s)) {
PyArrayObject * _copy = (PyArrayObject *) PyArray_Copy(%(var)s);
if (!({_contiguous})) {{
PyArrayObject * _copy = (PyArrayObject *) PyArray_Copy({var});
if (!_copy)
%(fail)s
Py_XDECREF(%(var)s);
%(var)s = _copy;
}
"""
% locals()
{fail}
Py_XDECREF({var});
{var} = _copy;
}}
""".format(**locals())
)
contiguate = "\n".join(contiguate)
return """
int type_num = PyArray_DESCR(%(_x)s)->type_num;
int type_size = PyArray_DESCR(%(_x)s)->elsize; // in bytes
int type_num = PyArray_DESCR({_x})->type_num;
int type_size = PyArray_DESCR({_x})->elsize; // in bytes
if (PyArray_NDIM(%(_x)s) != 3) {
if (PyArray_NDIM({_x}) != 3) {{
PyErr_Format(PyExc_NotImplementedError,
"rank(x) != 3. rank(x) is %%d.",
PyArray_NDIM(%(_x)s));
%(fail)s;
}
if (PyArray_NDIM(%(_y)s) != 3) {
"rank(x) != 3. rank(x) is %d.",
PyArray_NDIM({_x}));
{fail};
}}
if (PyArray_NDIM({_y}) != 3) {{
PyErr_Format(PyExc_NotImplementedError,
"rank(y) != 3. rank(y) is %%d.",
PyArray_NDIM(%(_y)s));
%(fail)s;
}
if (%(_z)s && PyArray_NDIM(%(_z)s) != 3) {
"rank(y) != 3. rank(y) is %d.",
PyArray_NDIM({_y}));
{fail};
}}
if ({_z} && PyArray_NDIM({_z}) != 3) {{
PyErr_Format(PyExc_NotImplementedError,
"rank(z) != 3. rank(z) is %%d.",
PyArray_NDIM(%(_z)s));
%(fail)s;
}
"rank(z) != 3. rank(z) is %d.",
PyArray_NDIM({_z}));
{fail};
}}
// allocate output
%(allocate)s
{allocate}
// reallocate any noncontiguous arrays or arrays with invalid strides
%(contiguate)s
{contiguate}
if ((PyArray_DESCR(%(_x)s)->type_num != NPY_DOUBLE)
&& (PyArray_DESCR(%(_x)s)->type_num != NPY_FLOAT))
{PyErr_SetString(PyExc_NotImplementedError, "type(x) is not double or float"); %(fail)s;}
if ((PyArray_DESCR({_x})->type_num != NPY_DOUBLE)
&& (PyArray_DESCR({_x})->type_num != NPY_FLOAT))
{{PyErr_SetString(PyExc_NotImplementedError, "type(x) is not double or float"); {fail};}}
if ((PyArray_DESCR(%(_y)s)->type_num != NPY_DOUBLE)
&& (PyArray_DESCR(%(_y)s)->type_num != NPY_FLOAT))
{PyErr_SetString(PyExc_NotImplementedError, "type(y) is not double or float"); %(fail)s;}
if ((PyArray_DESCR({_y})->type_num != NPY_DOUBLE)
&& (PyArray_DESCR({_y})->type_num != NPY_FLOAT))
{{PyErr_SetString(PyExc_NotImplementedError, "type(y) is not double or float"); {fail};}}
if ((PyArray_DESCR(%(_z)s)->type_num != NPY_DOUBLE)
&& (PyArray_DESCR(%(_z)s)->type_num != NPY_FLOAT))
{PyErr_SetString(PyExc_NotImplementedError, "type(z) is not double or float"); %(fail)s;}
if ((PyArray_DESCR({_z})->type_num != NPY_DOUBLE)
&& (PyArray_DESCR({_z})->type_num != NPY_FLOAT))
{{PyErr_SetString(PyExc_NotImplementedError, "type(z) is not double or float"); {fail};}}
if ((PyArray_DESCR(%(_x)s)->type_num != PyArray_DESCR(%(_y)s)->type_num)
||(PyArray_DESCR(%(_x)s)->type_num != PyArray_DESCR(%(_z)s)->type_num))
{ PyErr_SetString(PyExc_NotImplementedError, "type(x), type(y), type(z) are not all the same"); %(fail)s; }
if ((PyArray_DESCR({_x})->type_num != PyArray_DESCR({_y})->type_num)
||(PyArray_DESCR({_x})->type_num != PyArray_DESCR({_z})->type_num))
{{ PyErr_SetString(PyExc_NotImplementedError, "type(x), type(y), type(z) are not all the same"); {fail}; }}
switch (type_num)
{
{{
case NPY_FLOAT:
if (batch_gemm<float>(sgemm_, type_size, %(_x)s, %(_y)s, %(_z)s)) {
%(fail)s;
}
if (batch_gemm<float>(sgemm_, type_size, {_x}, {_y}, {_z})) {{
{fail};
}}
break;
case NPY_DOUBLE:
if (batch_gemm<double>(dgemm_, type_size, %(_x)s, %(_y)s, %(_z)s)) {
%(fail)s;
}
if (batch_gemm<double>(dgemm_, type_size, {_x}, {_y}, {_z})) {{
{fail};
}}
break;
}
""" % locals()
}}
""".format(**locals())
def c_code_cache_version(self):
from pytensor.tensor.blas_headers import blas_header_version
......
差异被折叠。
......@@ -1097,7 +1097,7 @@ def ____gemm_code(check_ab, a_init, b_init):
if (PyArray_NDIM(_y) != 2) goto _dot_execute_fallback;
if (PyArray_NDIM(_z) != 2) goto _dot_execute_fallback;
%(check_ab)s
{check_ab}
if ((PyArray_DESCR(_x)->type_num != NPY_DOUBLE)
&& (PyArray_DESCR(_x)->type_num != NPY_FLOAT))
......@@ -1117,16 +1117,16 @@ def ____gemm_code(check_ab, a_init, b_init):
if ((Nx[0] != Nz[0]) || (Nx[1] != Ny[0]) || (Ny[1] != Nz[1]))
{
{{
error_string = "Input dimensions do not agree";
goto _dot_execute_fail;
}
if ((Sx[0] < 1) || (Sx[1] < 1) || (Sx[0] %(mod)s type_size) || (Sx[1] %(mod)s type_size)
|| (Sy[0] < 1) || (Sy[1] < 1) || (Sy[0] %(mod)s type_size) || (Sy[1] %(mod)s type_size)
|| (Sz[0] < 1) || (Sz[1] < 1) || (Sz[0] %(mod)s type_size) || (Sz[1] %(mod)s type_size))
{
}}
if ((Sx[0] < 1) || (Sx[1] < 1) || (Sx[0] {mod} type_size) || (Sx[1] {mod} type_size)
|| (Sy[0] < 1) || (Sy[1] < 1) || (Sy[0] {mod} type_size) || (Sy[1] {mod} type_size)
|| (Sz[0] < 1) || (Sz[1] < 1) || (Sz[0] {mod} type_size) || (Sz[1] {mod} type_size))
{{
goto _dot_execute_fallback;
}
}}
/*
encode the stride structure of _x,_y,_z into a single integer
......@@ -1146,19 +1146,19 @@ def ____gemm_code(check_ab, a_init, b_init):
sz_1 = (Nz[1] > 1) ? Sz[1]/type_size : Nz[0];
switch (type_num)
{
{{
case NPY_FLOAT:
{
{{
#define REAL float
float a = %(a_init)s;
float b = %(b_init)s;
float a = {a_init};
float b = {b_init};
float* x = (float*)PyArray_DATA(_x);
float* y = (float*)PyArray_DATA(_y);
float* z = (float*)PyArray_DATA(_z);
switch(unit)
{
{{
case 0x000: cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasNoTrans, Nz[0], Nz[1], Nx[1], a, x, sx_0, y, sy_0, b, z, sz_0); break;
case 0x001: cblas_sgemm(CblasRowMajor, CblasTrans, CblasNoTrans, Nz[0], Nz[1], Nx[1], a, x, sx_1, y, sy_0, b, z, sz_0); break;
case 0x010: cblas_sgemm(CblasRowMajor, CblasNoTrans, CblasTrans, Nz[0], Nz[1], Nx[1], a, x, sx_0, y, sy_1, b, z, sz_0); break;
......@@ -1168,21 +1168,21 @@ def ____gemm_code(check_ab, a_init, b_init):
case 0x110: cblas_sgemm(CblasColMajor, CblasTrans, CblasNoTrans, Nz[0], Nz[1], Nx[1], a, x, sx_0, y, sy_1, b, z, sz_1); break;
case 0x111: cblas_sgemm(CblasColMajor, CblasNoTrans, CblasNoTrans, Nz[0], Nz[1], Nx[1], a, x, sx_1, y, sy_1, b, z, sz_1); break;
default: goto _dot_execute_fallback;
};
}};
#undef REAL
}
}}
break;
case NPY_DOUBLE:
{
{{
#define REAL double
double a = %(a_init)s;
double b = %(b_init)s;
double a = {a_init};
double b = {b_init};
double* x = (double*)PyArray_DATA(_x);
double* y = (double*)PyArray_DATA(_y);
double* z = (double*)PyArray_DATA(_z);
switch(unit)
{
{{
case 0x000: cblas_dgemm(CblasRowMajor, CblasNoTrans, CblasNoTrans, Nz[0], Nz[1], Nx[1], a, x, sx_0, y, sy_0, b, z, sz_0); break;
case 0x001: cblas_dgemm(CblasRowMajor, CblasTrans, CblasNoTrans, Nz[0], Nz[1], Nx[1], a, x, sx_1, y, sy_0, b, z, sz_0); break;
case 0x010: cblas_dgemm(CblasRowMajor, CblasNoTrans, CblasTrans, Nz[0], Nz[1], Nx[1], a, x, sx_0, y, sy_1, b, z, sz_0); break;
......@@ -1192,11 +1192,11 @@ def ____gemm_code(check_ab, a_init, b_init):
case 0x110: cblas_dgemm(CblasColMajor, CblasTrans, CblasNoTrans, Nz[0], Nz[1], Nx[1], a, x, sx_0, y, sy_1, b, z, sz_1); break;
case 0x111: cblas_dgemm(CblasColMajor, CblasNoTrans, CblasNoTrans, Nz[0], Nz[1], Nx[1], a, x, sx_1, y, sy_1, b, z, sz_1); break;
default: goto _dot_execute_fallback;
};
}};
#undef REAL
}
}}
break;
}
}}
return 0; //success!
......@@ -1212,4 +1212,4 @@ def ____gemm_code(check_ab, a_init, b_init):
return -1;
/* v 1 */
""" % locals()
""".format(**locals())
......@@ -929,12 +929,12 @@ class Elemwise(OpenMPOp):
# decrease the reference of whatever the output contained
# prior to this
alloc += """
if (%(oname)s) {
Py_XDECREF(%(oname)s);
}
%(oname)s = %(iname)s;
Py_XINCREF(%(oname)s);
""" % locals()
if ({oname}) {{
Py_XDECREF({oname});
}}
{oname} = {iname};
Py_XINCREF({oname});
""".format(**locals())
# We alias the scalar variables
defines += f"#define {oname}_i {iname}_i\n"
undefs += f"#undef {oname}_i\n"
......@@ -958,12 +958,12 @@ class Elemwise(OpenMPOp):
dict(sub, fail=fail),
)
code = """
{
%(defines)s
%(task_code)s
%(undefs)s
}
""" % locals()
{{
{defines}
{task_code}
{undefs}
}}
""".format(**locals())
loop_orders = orders + [list(range(nnested))] * len(real_onames)
dtypes = idtypes + list(real_odtypes)
......@@ -995,27 +995,27 @@ class Elemwise(OpenMPOp):
if index != "x":
preloops.setdefault(j, "")
preloops[j] += (
"%%(lv%(i)s)s_iter = (%(dtype)s*)"
"(PyArray_DATA(%%(lv%(i)s)s));\n" % locals()
"%(lv{i})s_iter = ({dtype}*)"
"(PyArray_DATA(%(lv{i})s));\n".format(**locals())
) % sub
break
else: # all broadcastable
preloops.setdefault(0, "")
preloops[0] += (
"%%(lv%(i)s)s_iter = (%(dtype)s*)"
"(PyArray_DATA(%%(lv%(i)s)s));\n" % locals()
"%(lv{i})s_iter = ({dtype}*)"
"(PyArray_DATA(%(lv{i})s));\n".format(**locals())
) % sub
init_array = preloops.get(0, " ")
loop = """
{
%(defines)s
%(init_array)s
%(task_decl)s
%(task_code)s
%(undefs)s
}
""" % locals()
{{
{defines}
{init_array}
{task_decl}
{task_code}
{undefs}
}}
""".format(**locals())
else:
loop = cgen.make_loop(
loop_orders=loop_orders,
......@@ -1076,24 +1076,24 @@ class Elemwise(OpenMPOp):
for x, var in zip(inames + onames, inputs + node.outputs):
if not all(s == 1 for s in var.type.shape):
contig += """
dtype_%(x)s * %(x)s_ptr = (dtype_%(x)s*) PyArray_DATA(%(x)s);
""" % locals()
dtype_{x} * {x}_ptr = (dtype_{x}*) PyArray_DATA({x});
""".format(**locals())
index += """
dtype_%(x)s& %(x)s_i = %(x)s_ptr[i];
""" % locals()
dtype_{x}& {x}_i = {x}_ptr[i];
""".format(**locals())
else:
contig += """
dtype_%(x)s& %(x)s_i = ((dtype_%(x)s*) PyArray_DATA(%(x)s))[0];
""" % locals()
dtype_{x}& {x}_i = ((dtype_{x}*) PyArray_DATA({x}))[0];
""".format(**locals())
if self.openmp:
contig += f"""#pragma omp parallel for if(n>={int(config.openmp_elemwise_minsize)})
"""
contig += """
for(int i=0; i<n; i++){
%(index)s
%(task_code)s;
}
""" % locals()
for(int i=0; i<n; i++){{
{index}
{task_code};
}}
""".format(**locals())
if contig is not None:
z = list(zip(inames + onames, inputs + node.outputs))
all_broadcastable = all(s == 1 for s in var.type.shape)
......@@ -1112,12 +1112,12 @@ class Elemwise(OpenMPOp):
]
)
loop = """
if((%(cond1)s) || (%(cond2)s)){
%(contig)s
}else{
%(loop)s
}
""" % locals()
if(({cond1}) || ({cond2})){{
{contig}
}}else{{
{loop}
}}
""".format(**locals())
return decl, checks, alloc, loop, ""
def c_code(self, node, nodename, inames, onames, sub):
......
......@@ -176,34 +176,34 @@ def make_alloc(loop_orders, dtype, sub, fortran="0"):
# contiguous dimensions, or the dimension with the smallest
# stride. Right now, it is allocated to be C_CONTIGUOUS.
return """
{
npy_intp dims[%(nd)s];
//npy_intp* dims = (npy_intp*)malloc(%(nd)s * sizeof(npy_intp));
%(init_dims)s
if (!%(olv)s) {
%(olv)s = (PyArrayObject*)PyArray_EMPTY(%(nd)s, dims,
%(type)s,
%(fortran)s);
}
else {
{{
npy_intp dims[{nd}];
//npy_intp* dims = (npy_intp*)malloc({nd} * sizeof(npy_intp));
{init_dims}
if (!{olv}) {{
{olv} = (PyArrayObject*)PyArray_EMPTY({nd}, dims,
{type},
{fortran});
}}
else {{
PyArray_Dims new_dims;
new_dims.len = %(nd)s;
new_dims.len = {nd};
new_dims.ptr = dims;
PyObject* success = PyArray_Resize(%(olv)s, &new_dims, 0, NPY_CORDER);
if (!success) {
PyObject* success = PyArray_Resize({olv}, &new_dims, 0, NPY_CORDER);
if (!success) {{
// If we can't resize the ndarray we have we can allocate a new one.
PyErr_Clear();
Py_XDECREF(%(olv)s);
%(olv)s = (PyArrayObject*)PyArray_EMPTY(%(nd)s, dims, %(type)s, 0);
} else {
Py_XDECREF({olv});
{olv} = (PyArrayObject*)PyArray_EMPTY({nd}, dims, {type}, 0);
}} else {{
Py_DECREF(success);
}
}
if (!%(olv)s) {
%(fail)s
}
}
""" % dict(locals(), **sub)
}}
}}
if (!{olv}) {{
{fail}
}}
}}
""".format(**dict(locals(), **sub))
def make_loop(loop_orders, dtypes, loop_tasks, sub, openmp=None):
......
......@@ -69,24 +69,24 @@ class CpuContiguous(COp):
(x,) = inames
(y,) = onames
code = """
if (!PyArray_CHKFLAGS(%(x)s, NPY_ARRAY_C_CONTIGUOUS)){
if (!PyArray_CHKFLAGS({x}, NPY_ARRAY_C_CONTIGUOUS)){{
// check to see if output is contiguous first
if (%(y)s != NULL &&
PyArray_CompareLists(PyArray_DIMS(%(y)s), PyArray_DIMS(%(x)s), PyArray_NDIM(%(x)s)) &&
PyArray_CHKFLAGS(%(y)s, NPY_ARRAY_C_CONTIGUOUS)){
PyArray_CopyInto(%(y)s, %(x)s);
}
else{
Py_XDECREF(%(y)s);
%(y)s = PyArray_GETCONTIGUOUS(%(x)s);
}
}
else{
Py_XINCREF(%(x)s);
Py_XDECREF(%(y)s);
%(y)s = %(x)s;
}
""" % locals()
if ({y} != NULL &&
PyArray_CompareLists(PyArray_DIMS({y}), PyArray_DIMS({x}), PyArray_NDIM({x})) &&
PyArray_CHKFLAGS({y}, NPY_ARRAY_C_CONTIGUOUS)){{
PyArray_CopyInto({y}, {x});
}}
else{{
Py_XDECREF({y});
{y} = PyArray_GETCONTIGUOUS({x});
}}
}}
else{{
Py_XINCREF({x});
Py_XDECREF({y});
{y} = {x};
}}
""".format(**locals())
return code
def c_code_cache_version(self):
......@@ -162,12 +162,12 @@ class SearchsortedOp(COp):
side = sub["params"]
fail = sub["fail"]
return """
PyObject* tmp_%(name)s = PyUnicode_FromString("right");
if (tmp_%(name)s == NULL)
%(fail)s;
right_%(name)s = PyUnicode_Compare(%(side)s, tmp_%(name)s);
Py_DECREF(tmp_%(name)s);
""" % locals()
PyObject* tmp_{name} = PyUnicode_FromString("right");
if (tmp_{name} == NULL)
{fail};
right_{name} = PyUnicode_Compare({side}, tmp_{name});
Py_DECREF(tmp_{name});
""".format(**locals())
def c_code(self, node, name, inames, onames, sub):
sorter = None
......@@ -181,17 +181,17 @@ class SearchsortedOp(COp):
fail = sub["fail"]
return """
Py_XDECREF(%(z)s);
%(z)s = (PyArrayObject*) PyArray_SearchSorted(%(x)s, (PyObject*) %(v)s,
right_%(name)s ? NPY_SEARCHLEFT : NPY_SEARCHRIGHT, (PyObject*) %(sorter)s);
if (!%(z)s)
%(fail)s;
if (PyArray_TYPE(%(z)s) != NPY_INT64){
PyObject * tmp = PyArray_Cast(%(z)s, NPY_INT64);
Py_XDECREF(%(z)s);
%(z)s = (PyArrayObject*) tmp;
}
""" % locals()
Py_XDECREF({z});
{z} = (PyArrayObject*) PyArray_SearchSorted({x}, (PyObject*) {v},
right_{name} ? NPY_SEARCHLEFT : NPY_SEARCHRIGHT, (PyObject*) {sorter});
if (!{z})
{fail};
if (PyArray_TYPE({z}) != NPY_INT64){{
PyObject * tmp = PyArray_Cast({z}, NPY_INT64);
Py_XDECREF({z});
{z} = (PyArrayObject*) tmp;
}}
""".format(**locals())
def c_code_cache_version(self):
return (2,)
......@@ -351,43 +351,43 @@ class CumOp(COp):
params = sub["params"]
code = """
int axis = %(params)s->c_axis;
if (axis == 0 && PyArray_NDIM(%(x)s) == 1)
int axis = {params}->c_axis;
if (axis == 0 && PyArray_NDIM({x}) == 1)
axis = NPY_MAXDIMS;
npy_intp shape[1] = { PyArray_SIZE(%(x)s) };
if(axis == NPY_MAXDIMS && !(%(z)s && PyArray_DIMS(%(z)s)[0] == shape[0]))
{
Py_XDECREF(%(z)s);
%(z)s = (PyArrayObject*) PyArray_SimpleNew(1, shape, PyArray_TYPE((PyArrayObject*) py_%(x)s));
}
else if(axis != NPY_MAXDIMS && !(%(z)s && PyArray_CompareLists(PyArray_DIMS(%(z)s), PyArray_DIMS(%(x)s), PyArray_NDIM(%(x)s))))
{
Py_XDECREF(%(z)s);
%(z)s = (PyArrayObject*) PyArray_SimpleNew(PyArray_NDIM(%(x)s), PyArray_DIMS(%(x)s), PyArray_TYPE(%(x)s));
}
if (!%(z)s)
%(fail)s;
{
npy_intp shape[1] = {{ PyArray_SIZE({x}) }};
if(axis == NPY_MAXDIMS && !({z} && PyArray_DIMS({z})[0] == shape[0]))
{{
Py_XDECREF({z});
{z} = (PyArrayObject*) PyArray_SimpleNew(1, shape, PyArray_TYPE((PyArrayObject*) py_{x}));
}}
else if(axis != NPY_MAXDIMS && !({z} && PyArray_CompareLists(PyArray_DIMS({z}), PyArray_DIMS({x}), PyArray_NDIM({x}))))
{{
Py_XDECREF({z});
{z} = (PyArrayObject*) PyArray_SimpleNew(PyArray_NDIM({x}), PyArray_DIMS({x}), PyArray_TYPE({x}));
}}
if (!{z})
{fail};
{{
PyObject * t = NULL;
if(%(params)s->mode == MODE_ADD)
if({params}->mode == MODE_ADD)
t = PyArray_CumSum(
%(x)s, axis,
PyArray_TYPE(%(x)s), %(z)s);
else if(%(params)s->mode == MODE_MUL)
{x}, axis,
PyArray_TYPE({x}), {z});
else if({params}->mode == MODE_MUL)
t = PyArray_CumProd(
%(x)s, axis,
PyArray_TYPE(%(x)s), %(z)s);
{x}, axis,
PyArray_TYPE({x}), {z});
if (!t){
%(fail)s;
}
if (!t){{
{fail};
}}
// Because PyArray_CumSum/CumProd returns a newly created reference on t.
Py_XDECREF(t);
}
""" % locals()
}}
""".format(**locals())
return code
......
......@@ -420,13 +420,13 @@ class Argmax(COp):
raise NotImplementedError()
# params is only used here for now
axis_code = """
axis = %(params)s->c_axis;
if(axis > PyArray_NDIM(%(x)s)-1 || axis < -PyArray_NDIM(%(x)s)){
axis = {params}->c_axis;
if(axis > PyArray_NDIM({x})-1 || axis < -PyArray_NDIM({x})){{
PyErr_SetString(PyExc_ValueError,
"Argmax, bad axis argument");
%(fail)s
}
""" % locals()
{fail}
}}
""".format(**locals())
ret = """
int axis;
......
......@@ -476,47 +476,47 @@ class TensorType(CType[np.ndarray], HasDataType, HasShape):
def c_declare(self, name, sub, check_input=True):
if check_input:
check = """
typedef %(dtype)s dtype_%(name)s;
""" % dict(sub, name=name, dtype=self.dtype_specs()[1])
typedef {dtype} dtype_{name};
""".format(**dict(sub, name=name, dtype=self.dtype_specs()[1]))
else:
check = ""
declaration = """
PyArrayObject* %(name)s;
""" % dict(sub, name=name, dtype=self.dtype_specs()[1])
PyArrayObject* {name};
""".format(**dict(sub, name=name, dtype=self.dtype_specs()[1]))
return declaration + check
def c_init(self, name, sub):
return """
%(name)s = NULL;
""" % dict(sub, name=name, type_num=self.dtype_specs()[2])
{name} = NULL;
""".format(**dict(sub, name=name, type_num=self.dtype_specs()[2]))
def c_extract(self, name, sub, check_input=True, **kwargs):
if check_input:
check = """
%(name)s = NULL;
if (py_%(name)s == Py_None) {
// We can either fail here or set %(name)s to NULL and rely on Ops
{name} = NULL;
if (py_{name} == Py_None) {{
// We can either fail here or set {name} to NULL and rely on Ops
// using tensors to handle the NULL case, but if they fail to do so
// they'll end up with nasty segfaults, so this is public service.
PyErr_SetString(PyExc_ValueError, "expected an ndarray, not None");
%(fail)s
}
if (!PyArray_Check(py_%(name)s)) {
{fail}
}}
if (!PyArray_Check(py_{name})) {{
PyErr_SetString(PyExc_ValueError, "expected an ndarray");
%(fail)s
}
// We expect %(type_num)s
if (!PyArray_ISALIGNED((PyArrayObject*) py_%(name)s)) {
PyArrayObject * tmp = (PyArrayObject*) py_%(name)s;
{fail}
}}
// We expect {type_num}
if (!PyArray_ISALIGNED((PyArrayObject*) py_{name})) {{
PyArrayObject * tmp = (PyArrayObject*) py_{name};
PyErr_Format(PyExc_NotImplementedError,
"expected an aligned array of type %%ld "
"(%(type_num)s), got non-aligned array of type %%ld"
" with %%ld dimensions, with 3 last dims "
"%%ld, %%ld, %%ld"
" and 3 last strides %%ld %%ld, %%ld.",
(long int) %(type_num)s,
(long int) PyArray_TYPE((PyArrayObject*) py_%(name)s),
"expected an aligned array of type %ld "
"({type_num}), got non-aligned array of type %ld"
" with %ld dimensions, with 3 last dims "
"%ld, %ld, %ld"
" and 3 last strides %ld %ld, %ld.",
(long int) {type_num},
(long int) PyArray_TYPE((PyArrayObject*) py_{name}),
(long int) PyArray_NDIM(tmp),
(long int) (PyArray_NDIM(tmp) >= 3 ?
PyArray_DIMS(tmp)[PyArray_NDIM(tmp)-3] : -1),
......@@ -531,74 +531,73 @@ class TensorType(CType[np.ndarray], HasDataType, HasShape):
(long int) (PyArray_NDIM(tmp) >= 1 ?
PyArray_STRIDES(tmp)[PyArray_NDIM(tmp)-1] : -1)
);
%(fail)s
}
{fail}
}}
// This is a TypeError to be consistent with DEBUG_MODE
// Note: DEBUG_MODE also tells the name of the container
if (PyArray_TYPE((PyArrayObject*) py_%(name)s) != %(type_num)s) {
if (PyArray_TYPE((PyArrayObject*) py_{name}) != {type_num}) {{
PyErr_Format(PyExc_TypeError,
"expected type_num %%d (%(type_num)s) got %%d",
%(type_num)s, PyArray_TYPE((PyArrayObject*) py_%(name)s));
%(fail)s
}
""" % dict(sub, name=name, type_num=self.dtype_specs()[2])
"expected type_num %d ({type_num}) got %d",
{type_num}, PyArray_TYPE((PyArrayObject*) py_{name}));
{fail}
}}
""".format(**dict(sub, name=name, type_num=self.dtype_specs()[2]))
else:
check = ""
return (
check
+ """
%(name)s = (PyArrayObject*)(py_%(name)s);
Py_XINCREF(%(name)s);
"""
% dict(sub, name=name, type_num=self.dtype_specs()[2])
{name} = (PyArrayObject*)(py_{name});
Py_XINCREF({name});
""".format(**dict(sub, name=name, type_num=self.dtype_specs()[2]))
)
def c_cleanup(self, name, sub):
return """
if (%(name)s) {
Py_XDECREF(%(name)s);
}
""" % locals()
if ({name}) {{
Py_XDECREF({name});
}}
""".format(**locals())
def c_sync(self, name, sub):
fail = sub["fail"]
type_num = self.dtype_specs()[2]
return """
{Py_XDECREF(py_%(name)s);}
if (!%(name)s) {
{{Py_XDECREF(py_{name});}}
if (!{name}) {{
Py_INCREF(Py_None);
py_%(name)s = Py_None;
}
else if ((void*)py_%(name)s != (void*)%(name)s) {
py_%(name)s = (PyObject*)%(name)s;
}
py_{name} = Py_None;
}}
else if ((void*)py_{name} != (void*){name}) {{
py_{name} = (PyObject*){name};
}}
{Py_XINCREF(py_%(name)s);}
{{Py_XINCREF(py_{name});}}
if (%(name)s && !PyArray_ISALIGNED((PyArrayObject*) py_%(name)s)) {
if ({name} && !PyArray_ISALIGNED((PyArrayObject*) py_{name})) {{
PyErr_Format(PyExc_NotImplementedError,
"c_sync: expected an aligned array, got non-aligned array of type %%ld"
" with %%ld dimensions, with 3 last dims "
"%%ld, %%ld, %%ld"
" and 3 last strides %%ld %%ld, %%ld.",
(long int) PyArray_TYPE((PyArrayObject*) py_%(name)s),
(long int) PyArray_NDIM(%(name)s),
(long int) (PyArray_NDIM(%(name)s) >= 3 ?
PyArray_DIMS(%(name)s)[PyArray_NDIM(%(name)s)-3] : -1),
(long int) (PyArray_NDIM(%(name)s) >= 2 ?
PyArray_DIMS(%(name)s)[PyArray_NDIM(%(name)s)-2] : -1),
(long int) (PyArray_NDIM(%(name)s) >= 1 ?
PyArray_DIMS(%(name)s)[PyArray_NDIM(%(name)s)-1] : -1),
(long int) (PyArray_NDIM(%(name)s) >= 3 ?
PyArray_STRIDES(%(name)s)[PyArray_NDIM(%(name)s)-3] : -1),
(long int) (PyArray_NDIM(%(name)s) >= 2 ?
PyArray_STRIDES(%(name)s)[PyArray_NDIM(%(name)s)-2] : -1),
(long int) (PyArray_NDIM(%(name)s) >= 1 ?
PyArray_STRIDES(%(name)s)[PyArray_NDIM(%(name)s)-1] : -1)
"c_sync: expected an aligned array, got non-aligned array of type %ld"
" with %ld dimensions, with 3 last dims "
"%ld, %ld, %ld"
" and 3 last strides %ld %ld, %ld.",
(long int) PyArray_TYPE((PyArrayObject*) py_{name}),
(long int) PyArray_NDIM({name}),
(long int) (PyArray_NDIM({name}) >= 3 ?
PyArray_DIMS({name})[PyArray_NDIM({name})-3] : -1),
(long int) (PyArray_NDIM({name}) >= 2 ?
PyArray_DIMS({name})[PyArray_NDIM({name})-2] : -1),
(long int) (PyArray_NDIM({name}) >= 1 ?
PyArray_DIMS({name})[PyArray_NDIM({name})-1] : -1),
(long int) (PyArray_NDIM({name}) >= 3 ?
PyArray_STRIDES({name})[PyArray_NDIM({name})-3] : -1),
(long int) (PyArray_NDIM({name}) >= 2 ?
PyArray_STRIDES({name})[PyArray_NDIM({name})-2] : -1),
(long int) (PyArray_NDIM({name}) >= 1 ?
PyArray_STRIDES({name})[PyArray_NDIM({name})-1] : -1)
);
%(fail)s
}
""" % locals()
{fail}
}}
""".format(**locals())
def c_headers(self, **kwargs):
return ps.get_scalar_type(self.dtype).c_headers(**kwargs)
......
......@@ -103,12 +103,12 @@ class GetItem(COp):
output_name = out[0]
fail = sub["fail"]
return """
%(output_name)s = (typeof %(output_name)s) PyList_GetItem( (PyObject*) %(x_name)s, *((npy_int64 *) PyArray_DATA(%(index)s)));
if(%(output_name)s == NULL){
%(fail)s
}
Py_INCREF(%(output_name)s);
""" % locals()
{output_name} = (typeof {output_name}) PyList_GetItem( (PyObject*) {x_name}, *((npy_int64 *) PyArray_DATA({index})));
if({output_name} == NULL){{
{fail}
}}
Py_INCREF({output_name});
""".format(**locals())
def c_code_cache_version(self):
return (1,)
......@@ -170,8 +170,8 @@ class Append(COp):
fail = sub["fail"]
if not self.inplace:
init = """
%(output_name)s = (PyListObject*) PyList_GetSlice((PyObject*) %(x_name)s, 0, PyList_GET_SIZE((PyObject*) %(x_name)s)) ;
""" % locals()
{output_name} = (PyListObject*) PyList_GetSlice((PyObject*) {x_name}, 0, PyList_GET_SIZE((PyObject*) {x_name})) ;
""".format(**locals())
else:
init = f"""
{output_name} = {x_name};
......@@ -179,15 +179,14 @@ class Append(COp):
return (
init
+ """
if(%(output_name)s==NULL){
%(fail)s
};
if(PyList_Append( (PyObject*) %(output_name)s,(PyObject*) %(toAppend)s)){
%(fail)s
};
Py_INCREF(%(output_name)s);
"""
% locals()
if({output_name}==NULL){{
{fail}
}};
if(PyList_Append( (PyObject*) {output_name},(PyObject*) {toAppend})){{
{fail}
}};
Py_INCREF({output_name});
""".format(**locals())
)
def c_code_cache_version(self):
......@@ -252,8 +251,8 @@ class Extend(COp):
fail = sub["fail"]
if not self.inplace:
init = """
%(output_name)s = (PyListObject*) PyList_GetSlice((PyObject*) %(x_name)s, 0, PyList_GET_SIZE((PyObject*) %(x_name)s)) ;
""" % locals()
{output_name} = (PyListObject*) PyList_GetSlice((PyObject*) {x_name}, 0, PyList_GET_SIZE((PyObject*) {x_name})) ;
""".format(**locals())
else:
init = f"""
{output_name} = {x_name};
......@@ -262,18 +261,17 @@ class Extend(COp):
init
+ """
int i =0;
int length = PyList_GET_SIZE((PyObject*) %(toAppend)s);
if(%(output_name)s==NULL){
%(fail)s
};
for(i; i < length; i++){
if(PyList_Append( (PyObject*) %(output_name)s,(PyObject*) PyList_GetItem((PyObject*) %(toAppend)s,i))==-1){
%(fail)s
};
}
Py_INCREF(%(output_name)s);
"""
% locals()
int length = PyList_GET_SIZE((PyObject*) {toAppend});
if({output_name}==NULL){{
{fail}
}};
for(i; i < length; i++){{
if(PyList_Append( (PyObject*) {output_name},(PyObject*) PyList_GetItem((PyObject*) {toAppend},i))==-1){{
{fail}
}};
}}
Py_INCREF({output_name});
""".format(**locals())
)
def c_code_cache_version_(self):
......@@ -341,8 +339,8 @@ class Insert(COp):
fail = sub["fail"]
if not self.inplace:
init = """
%(output_name)s = (PyListObject*) PyList_GetSlice((PyObject*) %(x_name)s, 0, PyList_GET_SIZE((PyObject*) %(x_name)s)) ;
""" % locals()
{output_name} = (PyListObject*) PyList_GetSlice((PyObject*) {x_name}, 0, PyList_GET_SIZE((PyObject*) {x_name})) ;
""".format(**locals())
else:
init = f"""
{output_name} = {x_name};
......@@ -350,15 +348,14 @@ class Insert(COp):
return (
init
+ """
if(%(output_name)s==NULL){
%(fail)s
};
if(PyList_Insert((PyObject*) %(output_name)s, *((npy_int64 *) PyArray_DATA(%(index)s)), (PyObject*) %(toInsert)s)==-1){
%(fail)s
};
Py_INCREF(%(output_name)s);
"""
% locals()
if({output_name}==NULL){{
{fail}
}};
if(PyList_Insert((PyObject*) {output_name}, *((npy_int64 *) PyArray_DATA({index})), (PyObject*) {toInsert})==-1){{
{fail}
}};
Py_INCREF({output_name});
""".format(**locals())
)
def c_code_cache_version(self):
......@@ -470,8 +467,8 @@ class Reverse(COp):
fail = sub["fail"]
if not self.inplace:
init = """
%(output_name)s = (PyListObject*) PyList_GetSlice((PyObject*) %(x_name)s, 0, PyList_GET_SIZE((PyObject*) %(x_name)s)) ;
""" % locals()
{output_name} = (PyListObject*) PyList_GetSlice((PyObject*) {x_name}, 0, PyList_GET_SIZE((PyObject*) {x_name})) ;
""".format(**locals())
else:
init = f"""
{output_name} = {x_name};
......@@ -479,15 +476,14 @@ class Reverse(COp):
return (
init
+ """
if(%(output_name)s==NULL){
%(fail)s
};
if(PyList_Reverse((PyObject*) %(output_name)s)==-1){
%(fail)s
};
Py_INCREF(%(output_name)s);
"""
% locals()
if({output_name}==NULL){{
{fail}
}};
if(PyList_Reverse((PyObject*) {output_name})==-1){{
{fail}
}};
Py_INCREF({output_name});
""".format(**locals())
)
def c_code_cache_version(self):
......@@ -602,11 +598,11 @@ class Length(COp):
output_name = out[0]
fail = sub["fail"]
return """
if(!%(output_name)s)
%(output_name)s=(PyArrayObject*)PyArray_EMPTY(0, NULL, NPY_INT64, 0);
((npy_int64*)PyArray_DATA(%(output_name)s))[0]=PyList_Size((PyObject*)%(x_name)s);
Py_INCREF(%(output_name)s);
""" % locals()
if(!{output_name})
{output_name}=(PyArrayObject*)PyArray_EMPTY(0, NULL, NPY_INT64, 0);
((npy_int64*)PyArray_DATA({output_name}))[0]=PyList_Size((PyObject*){x_name});
Py_INCREF({output_name});
""".format(**locals())
def c_code_cache_version(self):
return (1,)
......
......@@ -111,26 +111,25 @@ class TypedListType(CType):
def c_extract(self, name, sub, check_input=True, **kwargs):
if check_input:
pre = """
if (!PyList_Check(py_%(name)s)) {
if (!PyList_Check(py_{name})) {{
PyErr_SetString(PyExc_TypeError, "expected a list");
%(fail)s
}""" % dict(name=name, fail=sub["fail"])
{fail}
}}""".format(**dict(name=name, fail=sub["fail"]))
else:
pre = ""
return (
pre
+ """
%(name)s = (PyListObject*) (py_%(name)s);
"""
% dict(name=name, fail=sub["fail"])
{name} = (PyListObject*) (py_{name});
""".format(**dict(name=name, fail=sub["fail"]))
)
def c_sync(self, name, sub):
return """
Py_XDECREF(py_%(name)s);
py_%(name)s = (PyObject*)(%(name)s);
Py_INCREF(py_%(name)s);
""" % dict(name=name)
Py_XDECREF(py_{name});
py_{name} = (PyObject*)({name});
Py_INCREF(py_{name});
""".format(**dict(name=name))
def c_cleanup(self, name, sub):
return ""
......
差异被折叠。
差异被折叠。
差异被折叠。
差异被折叠。
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论