提交 55b2f4fa authored 作者: Virgile Andreani's avatar Virgile Andreani 提交者: Virgile Andreani

Replace str.format with f-strings

上级 e2f9cb8e
......@@ -434,7 +434,7 @@ class OpFromGraph(Op, HasInnerGraph):
def __str__(self):
name = self.__class__.__name__ if self.name is None else self.name
is_inline = self.is_inline
return "{name}{{inline={is_inline}}}".format(**locals())
return f"{name}{{inline={is_inline}}}"
def _combine_list_overrides(self, default_outs, custom_outs, callable_args):
"""Combines default and custom overrides into a single list of outputs."""
......
......@@ -1890,11 +1890,7 @@ def get_info_on_inputs(named_inputs, n_unnamed_inputs):
)
else:
if n_unnamed_inputs == 0:
msg = "The function has {} named input{} ({}).".format(
n_named_inputs,
get_plural(n_named_inputs),
", ".join(named_inputs),
)
msg = f"The function has {n_named_inputs} named input{get_plural(n_named_inputs)} ({', '.join(named_inputs)})."
else:
msg = (
f"The function has {n_named_inputs} named input{get_plural(n_named_inputs)} ({', '.join(named_inputs)}), and {n_unnamed_inputs} unnamed "
......
......@@ -1664,15 +1664,12 @@ class PatternNodeRewriter(NodeRewriter):
def pattern_to_str(pattern):
if isinstance(pattern, list | tuple):
return "{}({})".format(
str(pattern[0]),
", ".join(pattern_to_str(p) for p in pattern[1:]),
)
args = ", ".join(pattern_to_str(p) for p in pattern[1:])
return f"{pattern[0]!s}({args})"
elif isinstance(pattern, dict):
return "{} subject to {}".format(
pattern_to_str(pattern["pattern"]),
str(pattern.get("constraint", "no conditions")),
)
a = pattern_to_str(pattern["pattern"])
b = pattern.get("constraint", "no conditions")
return f"{a} subject to {b}"
else:
return str(pattern)
......
......@@ -245,10 +245,9 @@ class MetaType(ABCMeta):
else:
def __str__(self):
return "{}{{{}}}".format(
self.__class__.__name__,
", ".join(f"{p}={getattr(self, p)!r}" for p in props),
)
classname = self.__class__.__name__
args = ", ".join(f"{p}={getattr(self, p)!r}" for p in props)
return f"{classname}{{{args}}}"
dct["__str__"] = __str__
......
......@@ -219,7 +219,6 @@ def struct_gen(args, struct_builders, blocks, sub):
"""
struct_decl = ""
struct_init_head = ""
struct_init_tail = ""
struct_cleanup = ""
for block in struct_builders:
......@@ -243,7 +242,6 @@ def struct_gen(args, struct_builders, blocks, sub):
# decrements the storage's refcount in the destructor
storage_decref = "\n".join(f"Py_XDECREF(this->{arg});" for arg in args)
args_names = ", ".join(args)
args_decl = ", ".join(f"PyObject* {arg}" for arg in args)
# The following code stores the exception data in __ERROR, which
......@@ -251,7 +249,8 @@ def struct_gen(args, struct_builders, blocks, sub):
# that holds the type, the value and the traceback. After storing
# the error, we return the failure code so we know which code
# block failed.
do_return = """
failure_var = sub["failure_var"]
do_return = f"""
if ({failure_var}) {{
// When there is a failure, this code puts the exception
// in __ERROR.
......@@ -274,15 +273,13 @@ def struct_gen(args, struct_builders, blocks, sub):
}}
// The failure code is returned to index what code block failed.
return {failure_var};
""".format(**sub)
sub = dict(sub)
sub.update(locals())
"""
# TODO: add some error checking to make sure storage_<x> are
# 1-element lists and __ERROR is a 3-elements list.
struct_code = """
name = sub["name"]
struct_code = f"""
namespace {{
struct {name} {{
PyObject* __ERROR;
......@@ -326,7 +323,7 @@ def struct_gen(args, struct_builders, blocks, sub):
}}
}};
}}
""".format(**sub)
"""
return struct_code
......@@ -370,13 +367,10 @@ def get_c_init(fgraph, r, name, sub):
Wrapper around c_init that initializes py_name to Py_None.
"""
pre = (
""
"""
pre = f"""
py_{name} = Py_None;
{{Py_XINCREF(py_{name});}}
""".format(**locals())
)
"""
return pre + r.type.c_init(name, sub)
......@@ -410,10 +404,10 @@ def get_c_extract(fgraph, r, name, sub):
else:
c_extract = r.type.c_extract(name, sub, False)
pre = """
pre = f"""
py_{name} = PyList_GET_ITEM(storage_{name}, 0);
{{Py_XINCREF(py_{name});}}
""".format(**locals())
"""
return pre + c_extract
......@@ -439,10 +433,10 @@ def get_c_extract_out(fgraph, r, name, sub):
else:
c_extract = r.type.c_extract_out(name, sub, check_input, check_broadcast=False)
pre = """
pre = f"""
py_{name} = PyList_GET_ITEM(storage_{name}, 0);
{{Py_XINCREF(py_{name});}}
""".format(**locals())
"""
return pre + c_extract
......@@ -451,9 +445,9 @@ def get_c_cleanup(fgraph, r, name, sub):
Wrapper around c_cleanup that decrefs py_name.
"""
post = """
post = f"""
{{Py_XDECREF(py_{name});}}
""".format(**locals())
"""
return r.type.c_cleanup(name, sub) + post
......@@ -462,7 +456,9 @@ def get_c_sync(fgraph, r, name, sub):
Wrapper around c_sync that syncs py_name with storage.
"""
return """
failure_var = sub["failure_var"]
sync = r.type.c_sync(name, sub)
return f"""
if (!{failure_var}) {{
{sync}
PyObject* old = PyList_GET_ITEM(storage_{name}, 0);
......@@ -470,7 +466,7 @@ def get_c_sync(fgraph, r, name, sub):
PyList_SET_ITEM(storage_{name}, 0, py_{name});
{{Py_XDECREF(old);}}
}}
""".format(**dict(sync=r.type.c_sync(name, sub), name=name, **sub))
"""
def apply_policy(fgraph, policy, r, name, sub):
......
......@@ -1966,14 +1966,14 @@ class Compiler:
return False
code = (
"""
f"""
{preamble}
int main(int argc, char** argv)
{{
{body}
return 0;
}}
""".format(**locals())
"""
).encode()
return cls._try_compile_tmp(
code,
......
......@@ -558,7 +558,9 @@ class CLinkerType(CLinkerObject):
uninitialized.
"""
return """
c_init_code = self.c_init(name, sub)
c_extract_code = self.c_extract(name, sub, check_input)
return f"""
if (py_{name} == Py_None)
{{
{c_init_code}
......@@ -567,13 +569,7 @@ class CLinkerType(CLinkerObject):
{{
{c_extract_code}
}}
""".format(
**dict(
name=name,
c_init_code=self.c_init(name, sub),
c_extract_code=self.c_extract(name, sub, check_input),
)
)
"""
def c_cleanup(self, name: str, sub: dict[str, str]) -> str:
"""Return C code to clean up after :meth:`CLinkerType.c_extract`.
......
......@@ -587,24 +587,15 @@ class ExternalCOp(COp):
params = f", {sub['params']}"
# Generate the C code
return """
return f"""
{define_macros}
{{
if ({func_name}({func_args}{params}) != 0) {{
{fail}
if ({self.func_name}({self.format_c_function_args(inp, out)}{params}) != 0) {{
{sub['fail']}
}}
}}
{undef_macros}
""".format(
**dict(
func_name=self.func_name,
fail=sub["fail"],
params=params,
func_args=self.format_c_function_args(inp, out),
define_macros=define_macros,
undef_macros=undef_macros,
)
)
"""
else:
if "code" in self.code_sections:
op_code = self.code_sections["code"]
......
......@@ -262,9 +262,10 @@ class Params(dict):
self.__dict__.update(__params_type__=params_type, __signatures__=None)
def __repr__(self):
return "Params({})".format(
", ".join((f"{k}:{type(self[k]).__name__}:{self[k]}") for k in sorted(self))
args = ", ".join(
(f"{k}:{type(self[k]).__name__}:{self[k]}") for k in sorted(self)
)
return f"Params({args})"
def __getattr__(self, key):
if key not in self:
......@@ -422,9 +423,10 @@ class ParamsType(CType):
return super().__getattr__(self, key)
def __repr__(self):
return "ParamsType<{}>".format(
", ".join((f"{self.fields[i]}:{self.types[i]}") for i in range(self.length))
args = ", ".join(
f"{self.fields[i]}:{self.types[i]}" for i in range(self.length)
)
return f"ParamsType<{args}>"
def __eq__(self, other):
return (
......@@ -730,11 +732,15 @@ class ParamsType(CType):
struct_init = "\n".join(c_init_list)
struct_cleanup = "\n".join(c_cleanup_list)
struct_extract = "\n\n".join(c_extract_list)
struct_extract_method = """
args = "\n".join(
f"case {i}: extract_{self.fields[i]}(object); break;"
for i in range(self.length)
)
struct_extract_method = f"""
void extract(PyObject* object, int field_pos) {{
switch(field_pos) {{
// Extraction cases.
{}
{args}
// Default case.
default:
PyErr_Format(PyExc_TypeError, "ParamsType: no extraction defined for a field %d.", field_pos);
......@@ -742,13 +748,8 @@ class ParamsType(CType):
break;
}}
}}
""".format(
"\n".join(
("case %d: extract_%s(object); break;" % (i, self.fields[i]))
for i in range(self.length)
)
)
final_struct_code = """
"""
final_struct_code = f"""
/** ParamsType {struct_name} **/
#ifndef {struct_name_defined}
#define {struct_name_defined}
......@@ -790,17 +791,7 @@ class ParamsType(CType):
}};
#endif
/** End ParamsType {struct_name} **/
""".format(
**dict(
struct_name_defined=struct_name_defined,
struct_name=struct_name,
struct_declare=struct_declare,
struct_init=struct_init,
struct_cleanup=struct_cleanup,
struct_extract=struct_extract,
struct_extract_method=struct_extract_method,
)
)
"""
return [*sorted(c_support_code_set), final_struct_code]
......@@ -813,9 +804,9 @@ class ParamsType(CType):
# pointers.
def c_declare(self, name, sub, check_input=True):
return """
{struct_name}* {name};
""".format(**dict(struct_name=self.name, name=name))
return f"""
{self.name}* {name};
"""
def c_init(self, name, sub):
# NB: It seems c_init() is not called for an op param.
......@@ -831,40 +822,33 @@ class ParamsType(CType):
"""
def c_extract(self, name, sub, check_input=True, **kwargs):
return """
fields_list = ", ".join(f'"{x}"' for x in self.fields)
return f"""
/* Seems c_init() is not called for a op param. So I call `new` here. */
{name} = new {struct_name};
{name} = new {self.name};
{{ // This need a separate namespace for Clinker
const char* fields[] = {{{fields_list}}};
if (py_{name} == Py_None) {{
PyErr_SetString(PyExc_ValueError, "ParamsType: expected an object, not None.");
{fail}
{sub['fail']}
}}
for (int i = 0; i < {length}; ++i) {{
for (int i = 0; i < {self.length}; ++i) {{
PyObject* o = PyDict_GetItemString(py_{name}, fields[i]);
if (o == NULL) {{
PyErr_Format(PyExc_TypeError, "ParamsType: missing expected attribute \\"%s\\" in object.", fields[i]);
{fail}
{sub['fail']}
}}
{name}->extract(o, i);
if ({name}->errorOccurred()) {{
/* The extract code from attribute type should have already raised a Python exception,
* so we just print the attribute name in stderr. */
fprintf(stderr, "\\nParamsType: error when extracting value for attribute \\"%s\\".\\n", fields[i]);
{fail}
{sub['fail']}
}}
}}
}}
""".format(
**dict(
name=name,
struct_name=self.name,
length=self.length,
fail=sub["fail"],
fields_list='"{}"'.format('", "'.join(self.fields)),
)
)
"""
def c_sync(self, name, sub):
# FIXME: Looks like we need to decrement a reference count our two.
......
......@@ -98,12 +98,12 @@ class Generic(CType, Singleton):
"""
def c_sync(self, name, sub):
return """
return f"""
assert(py_{name}->ob_refcnt > 1);
Py_DECREF(py_{name});
py_{name} = {name} ? {name} : Py_None;
Py_INCREF(py_{name});
""".format(**locals())
"""
def c_code_cache_version(self):
return (1,)
......@@ -190,18 +190,19 @@ class CDataType(CType[D]):
return data
def c_declare(self, name, sub, check_input=True):
return """
{ctype} {name};
""".format(**dict(ctype=self.ctype, name=name))
return f"""
{self.ctype} {name};
"""
def c_init(self, name, sub):
return f"{name} = NULL;"
def c_extract(self, name, sub, check_input=True, **kwargs):
return """
{name} = ({ctype})PyCapsule_GetPointer(py_{name}, NULL);
fail = sub["fail"]
return f"""
{name} = ({self.ctype})PyCapsule_GetPointer(py_{name}, NULL);
if ({name} == NULL) {fail}
""".format(**dict(name=name, ctype=self.ctype, fail=sub["fail"]))
"""
def c_sync(self, name, sub):
freefunc = self.freefunc
......@@ -478,11 +479,8 @@ class EnumType(CType, dict):
names_to_aliases = {constant_name: "" for constant_name in self}
for alias in self.aliases:
names_to_aliases[self.aliases[alias]] = f"({alias})"
return "{}<{}>({})".format(
type(self).__name__,
self.ctype,
", ".join(f"{k}{names_to_aliases[k]}:{self[k]}" for k in sorted(self)),
)
args = ", ".join(f"{k}{names_to_aliases[k]}:{self[k]}" for k in sorted(self))
return f"{type(self).__name__}<{self.ctype}>({args})"
def __getattr__(self, key):
if key in self:
......@@ -575,33 +573,27 @@ class EnumType(CType, dict):
This C function may be useful to retrieve some runtime information.
It is available in C code when pytensor flag ``config.cmodule__debug`` is set to ``True``.
"""
return """
cases = "".join(
f"""
case {name}: sprintf(out, "{name}"); break;
"""
for name in self
)
return f"""
#ifdef DEBUG
int pytensor_enum_to_string_{cname}({ctype} in, char* out) {{
int pytensor_enum_to_string_{self.cname}({self.ctype} in, char* out) {{
int ret = 0;
switch(in) {{
{cases}
default:
PyErr_SetString(PyExc_ValueError, "{classname}: unknown enum value.");
PyErr_SetString(PyExc_ValueError, "{type(self).__name__}: unknown enum value.");
ret = -1;
break;
}}
return ret;
}}
#endif
""".format(
**dict(
cname=self.cname,
ctype=self.ctype,
classname=type(self).__name__,
cases="".join(
"""
case {name}: sprintf(out, "{name}"); break;
""".format(**dict(name=name))
for name in self
),
)
)
"""
def c_support_code(self, **kwargs):
return (
......@@ -625,16 +617,17 @@ class EnumType(CType, dict):
return ""
def c_extract(self, name, sub, check_input=True, **kwargs):
return """
fail = sub["fail"]
return f"""
if (PyInt_Check(py_{name})) {{
{name} = ({ctype})PyInt_AsLong(py_{name});
{name} = ({self.ctype})PyInt_AsLong(py_{name});
}} else {{
{name} = ({ctype})PyFloat_AsDouble(py_{name});
{name} = ({self.ctype})PyFloat_AsDouble(py_{name});
}}
if (PyErr_Occurred()) {{
{fail}
}}
""".format(**dict(ctype=self.ctype, name=name, fail=sub["fail"]))
"""
def c_code_cache_version(self):
return (2, self.ctype, self.cname, tuple(self.items()))
......@@ -754,7 +747,14 @@ class CEnumType(EnumList):
swapped_dict = {v: k for (k, v) in self.items()}
# swapped_dict's keys are integers.
return """
fail = sub["fail"]
cases = "".join(
f"""
case {i}: {name} = {swapped_dict[i]}; break;
"""
for i in sorted(swapped_dict)
)
return f"""
switch(PyInt_AsLong(py_{name})) {{
{cases}
default:
......@@ -762,19 +762,7 @@ class CEnumType(EnumList):
{{{fail}}}
break;
}}
""".format(
**dict(
name=name,
cases="".join(
"""
case %(i)d: %(name)s = %(constant_cname)s; break;
"""
% dict(i=i, name=name, constant_cname=swapped_dict[i])
for i in sorted(swapped_dict)
),
fail=sub["fail"],
)
)
"""
def c_code_cache_version(self):
return (1, super().c_code_cache_version())
......@@ -136,20 +136,19 @@ def _check_scipy_linalg_matrix(a, func_name):
Adapted from https://github.com/numba/numba/blob/bd7ebcfd4b850208b627a3f75d4706000be36275/numba/np/linalg.py#L831
"""
prefix = "scipy.linalg"
interp = (prefix, func_name)
# Unpack optional type
if isinstance(a, types.Optional):
a = a.type
if not isinstance(a, types.Array):
msg = "{}.{}() only supported for array types".format(*interp)
msg = f"{prefix}.{func_name}() only supported for array types"
raise numba.TypingError(msg, highlighting=False)
if a.ndim not in [1, 2]:
msg = "{}.{}() only supported on 1d or 2d arrays, found {}.".format(
*interp, a.ndim
msg = (
f"{prefix}.{func_name}() only supported on 1d or 2d arrays, found {a.ndim}."
)
raise numba.TypingError(msg, highlighting=False)
if not isinstance(a.dtype, types.Float | types.Complex):
msg = "{}.{}() only supported on " "float and complex arrays.".format(*interp)
msg = "{prefix}.{func_name}() only supported on float and complex arrays."
raise numba.TypingError(msg, highlighting=False)
......
......@@ -355,14 +355,10 @@ def raise_with_op(
+ f"\nInputs values: {scalar_values}"
)
if verbosity == "high":
detailed_err_msg += "\nInputs type_num: {}".format(
str(
[
getattr(getattr(i[0], "dtype", ""), "num", "")
for i in thunk.inputs
]
)
)
inpts = [
getattr(getattr(i[0], "dtype", ""), "num", "") for i in thunk.inputs
]
detailed_err_msg += f"\nInputs type_num: {inpts}"
detailed_err_msg += f"\nOutputs clients: {clients}\n"
else:
......
......@@ -1048,10 +1048,8 @@ class DefaultPrinter(Printer):
if node is None:
return leaf_printer.process(output, pstate)
with set_precedence(pstate):
r = "{}({})".format(
str(node.op),
", ".join(pprinter.process(input, pstate) for input in node.inputs),
)
args = ", ".join(pprinter.process(input, pstate) for input in node.inputs)
r = f"{node.op}({args})"
pstate.memo[output] = r
return r
......
......@@ -464,33 +464,32 @@ class ScalarType(CType, HasDataType, HasShape):
raise NotImplementedError("float16")
specs = self.dtype_specs()
if check_input:
pre = """
fail = sub["fail"]
dtype = specs[1]
pyarr_type = f"Py{specs[2]}ArrType_Type"
pre = f"""
if (!PyObject_TypeCheck(py_{name}, &{pyarr_type}))
{{
PyErr_Format(PyExc_ValueError,
"Scalar check failed ({dtype})");
{fail}
}}
""".format(
**dict(
sub,
name=name,
dtype=specs[1],
pyarr_type=f"Py{specs[2]}ArrType_Type",
)
)
"""
else:
pre = ""
return (
pre
+ """
+ f"""
PyArray_ScalarAsCtype(py_{name}, &{name});
""".format(**dict(sub, name=name))
"""
)
def c_sync(self, name, sub):
specs = self.dtype_specs()
return """
fail = sub["fail"]
dtype = specs[1]
cls = specs[2]
return f"""
Py_XDECREF(py_{name});
py_{name} = PyArrayScalar_New({cls});
if (!py_{name})
......@@ -502,7 +501,7 @@ class ScalarType(CType, HasDataType, HasShape):
{fail}
}}
PyArrayScalar_ASSIGN(py_{name}, {cls}, {name});
""".format(**dict(sub, name=name, dtype=specs[1], cls=specs[2]))
"""
def c_cleanup(self, name, sub):
return ""
......@@ -1179,10 +1178,9 @@ class ScalarOp(COp):
not in ("name", "_op_use_c_code", "bool", "output_types_preference")
]
if param:
return "{}{{{}}}".format(
self.__class__.__name__,
", ".join(f"{k}={v}" for k, v in param),
)
classname = self.__class__.__name__
args = ", ".join(f"{k}={v}" for k, v in param)
return f"{classname}{{{args}}}"
else:
return self.__class__.__name__
......
......@@ -615,8 +615,8 @@ class Chi2SF(BinaryScalarOp):
(z,) = out
if node.inputs[0].type in float_types:
dtype = "npy_" + node.outputs[0].dtype
return """{z} =
({dtype}) 1 - GammaP({k}/2., {x}/2.);""".format(**locals())
return f"""{z} =
({dtype}) 1 - GammaP({k}/2., {x}/2.);"""
raise NotImplementedError("only floatingpoint is implemented")
def __eq__(self, other):
......@@ -666,8 +666,8 @@ class GammaInc(BinaryScalarOp):
(z,) = out
if node.inputs[0].type in float_types:
dtype = "npy_" + node.outputs[0].dtype
return """{z} =
({dtype}) GammaP({k}, {x});""".format(**locals())
return f"""{z} =
({dtype}) GammaP({k}, {x});"""
raise NotImplementedError("only floatingpoint is implemented")
def __eq__(self, other):
......@@ -717,8 +717,8 @@ class GammaIncC(BinaryScalarOp):
(z,) = out
if node.inputs[0].type in float_types:
dtype = "npy_" + node.outputs[0].dtype
return """{z} =
({dtype}) GammaQ({k}, {x});""".format(**locals())
return f"""{z} =
({dtype}) GammaQ({k}, {x});"""
raise NotImplementedError("only floatingpoint is implemented")
def __eq__(self, other):
......@@ -1028,8 +1028,8 @@ class GammaU(BinaryScalarOp):
(z,) = out
if node.inputs[0].type in float_types:
dtype = "npy_" + node.outputs[0].dtype
return """{z} =
({dtype}) upperGamma({k}, {x});""".format(**locals())
return f"""{z} =
({dtype}) upperGamma({k}, {x});"""
raise NotImplementedError("only floatingpoint is implemented")
def __eq__(self, other):
......@@ -1064,8 +1064,8 @@ class GammaL(BinaryScalarOp):
(z,) = out
if node.inputs[0].type in float_types:
dtype = "npy_" + node.outputs[0].dtype
return """{z} =
({dtype}) lowerGamma({k}, {x});""".format(**locals())
return f"""{z} =
({dtype}) lowerGamma({k}, {x});"""
raise NotImplementedError("only floatingpoint is implemented")
def __eq__(self, other):
......
......@@ -1072,7 +1072,8 @@ class ScanArgs:
for p in self.field_names
if p.startswith("outer_out")
]
res = "ScanArgs(\n{})".format(",\n".join(inner_arg_strs))
args = ",\n".join(inner_arg_strs)
res = f"ScanArgs(\n{args})"
return res
def __repr__(self):
......
......@@ -3617,7 +3617,8 @@ class StructuredDotGradCSC(COp):
if node.inputs[3].type.dtype in ("complex64", "complex128"):
raise NotImplementedError("Complex types are not supported for g_ab")
return """
fail = sub["fail"]
return f"""
if (PyArray_NDIM({_d}) != 2) {{PyErr_SetString(PyExc_NotImplementedError, "rank(d) != 2"); {fail};}}
if (PyArray_NDIM({_g}) != 2) {{PyErr_SetString(PyExc_NotImplementedError, "rank(g) != 2"); {fail};}}
if (PyArray_NDIM({_indices}) != 1) {{PyErr_SetString(PyExc_NotImplementedError, "rank(indices) != 1"); {fail};}}
......@@ -3689,7 +3690,7 @@ class StructuredDotGradCSC(COp):
}}
}}
""".format(**dict(locals(), **sub))
"""
def infer_shape(self, fgraph, node, shapes):
return [shapes[0]]
......@@ -3750,7 +3751,8 @@ class StructuredDotGradCSR(COp):
if node.inputs[3].type.dtype in ("complex64", "complex128"):
raise NotImplementedError("Complex types are not supported for g_ab")
return """
fail = sub["fail"]
return f"""
if (PyArray_NDIM({_d}) != 2) {{PyErr_SetString(PyExc_NotImplementedError, "rank(d) != 2"); {fail};}}
if (PyArray_NDIM({_g}) != 2) {{PyErr_SetString(PyExc_NotImplementedError, "rank(g) != 2"); {fail};}}
if (PyArray_NDIM({_indices}) != 1) {{PyErr_SetString(PyExc_NotImplementedError, "rank(indices) != 1"); {fail};}}
......@@ -3823,7 +3825,7 @@ class StructuredDotGradCSR(COp):
}}
}}
""".format(**dict(locals(), **sub))
"""
def infer_shape(self, fgraph, node, shapes):
return [shapes[0]]
......
......@@ -137,7 +137,7 @@ class AddSD_ccode(_NoPythonCOp):
inplace = int(self.inplace)
format = {"csc": 0, "csr": 1}[self.format]
out_typenum = node.outputs[0].type.dtype_specs()[2]
code = """
code = f"""
Py_XDECREF({z});
if (!{inplace}){{
if(PyArray_TYPE({y}) != {out_typenum}){{
......@@ -179,7 +179,7 @@ class AddSD_ccode(_NoPythonCOp):
}}
}}
}}
""".format(**dict(locals(), **sub))
"""
return code
def infer_shape(self, fgraph, node, shapes):
......@@ -310,7 +310,8 @@ class StructuredDotCSC(COp):
typenum_a_val = node.inputs[0].type.dtype_specs()[2] # retrieve dtype number
typenum_b = node.inputs[4].type.dtype_specs()[2] # retrieve dtype number
rval = """
fail = sub["fail"]
rval = f"""
if (PyArray_NDIM({a_val}) != 1) {{PyErr_SetString(PyExc_NotImplementedError, "rank(a_val) != 1"); {fail};}}
if (PyArray_NDIM({a_ind}) != 1) {{PyErr_SetString(PyExc_NotImplementedError, "rank(a_ind) != 1"); {fail};}}
......@@ -430,7 +431,7 @@ class StructuredDotCSC(COp):
}}
}}
}}
""".format(**dict(locals(), **sub))
"""
return rval
......@@ -517,7 +518,8 @@ class StructuredDotCSR(COp):
if node.inputs[3].type.dtype in ("complex64", "complex128"):
raise NotImplementedError("Complex types are not supported for b")
return """
fail = sub["fail"]
return f"""
if (PyArray_NDIM({a_val}) != 1) {{PyErr_SetString(PyExc_NotImplementedError, "rank(a_val) != 1"); {fail};}}
if (PyArray_NDIM({a_ind}) != 1) {{PyErr_SetString(PyExc_NotImplementedError, "rank(a_ind) != 1"); {fail};}}
if (PyArray_NDIM({a_ptr}) != 1) {{PyErr_SetString(PyExc_NotImplementedError, "rank(a_ptr) != 1"); {fail};}}
......@@ -609,7 +611,7 @@ class StructuredDotCSR(COp):
}}
}}
""".format(**dict(locals(), **sub))
"""
def c_code_cache_version(self):
return (2,)
......@@ -756,7 +758,8 @@ class UsmmCscDense(_NoPythonCOp):
inplace = int(self.inplace)
rval = """
fail = sub["fail"]
rval = f"""
if (PyArray_NDIM({x_val}) != 1) {{PyErr_SetString(PyExc_NotImplementedError, "rank(x_val) != 1"); {fail};}}
if (PyArray_NDIM({x_ind}) != 1) {{PyErr_SetString(PyExc_NotImplementedError, "rank(x_ind) != 1"); {fail};}}
......@@ -888,7 +891,7 @@ class UsmmCscDense(_NoPythonCOp):
}}
}}
}}
""".format(**dict(locals(), **sub))
"""
return rval
......@@ -985,7 +988,8 @@ class CSMGradC(_NoPythonCOp):
if node.inputs[3].type.dtype in ("complex64", "complex128"):
raise NotImplementedError("Complex types are not supported for b_val")
return """
fail = sub["fail"]
return f"""
if (PyArray_NDIM({a_val}) != 1) {{PyErr_SetString(PyExc_NotImplementedError, "rank(a_val) != 1"); {fail};}}
if (PyArray_NDIM({a_ind}) != 1) {{PyErr_SetString(PyExc_NotImplementedError, "rank(a_ind) != 1"); {fail};}}
if (PyArray_NDIM({a_ptr}) != 1) {{PyErr_SetString(PyExc_NotImplementedError, "rank(a_ptr) != 1"); {fail};}}
......@@ -1079,7 +1083,7 @@ class CSMGradC(_NoPythonCOp):
}}
}}
""".format(**dict(locals(), **sub))
"""
def c_code_cache_version(self):
return (3,)
......@@ -1165,7 +1169,8 @@ class MulSDCSC(_NoPythonCOp):
if node.inputs[3].type.dtype in ("complex64", "complex128"):
raise NotImplementedError("Complex types are not supported for b")
return """
fail = sub["fail"]
return f"""
if (PyArray_NDIM({_b}) != 2) {{
PyErr_SetString(PyExc_NotImplementedError, "rank(b) != 2");
{fail};}}
......@@ -1231,7 +1236,7 @@ class MulSDCSC(_NoPythonCOp):
}}
}}
""".format(**dict(locals(), **sub))
"""
def __str__(self):
return self.__class__.__name__
......@@ -1302,7 +1307,8 @@ class MulSDCSR(_NoPythonCOp):
if node.inputs[3].type.dtype in ("complex64", "complex128"):
raise NotImplementedError("Complex types are not supported for b")
return """
fail = sub["fail"]
return f"""
if (PyArray_NDIM({_b}) != 2) {{
PyErr_SetString(PyExc_NotImplementedError, "rank(b) != 2");
{fail};}}
......@@ -1368,7 +1374,7 @@ class MulSDCSR(_NoPythonCOp):
}}
}}
""".format(**dict(locals(), **sub))
"""
def __str__(self):
return self.__class__.__name__
......@@ -1491,7 +1497,8 @@ class MulSVCSR(_NoPythonCOp):
if node.inputs[3].type.dtype in ("complex64", "complex128"):
raise NotImplementedError("Complex types are not supported for b")
return """
fail = sub["fail"]
return f"""
if (PyArray_NDIM({_b}) != 1) {{
PyErr_SetString(PyExc_NotImplementedError, "rank(b) != 1");
{fail};
......@@ -1553,7 +1560,7 @@ class MulSVCSR(_NoPythonCOp):
}}
}}
""".format(**dict(locals(), **sub))
"""
def __str__(self):
return self.__class__.__name__
......@@ -1663,7 +1670,8 @@ class StructuredAddSVCSR(_NoPythonCOp):
if node.inputs[3].type.dtype in ("complex64", "complex128"):
raise NotImplementedError("Complex types are not supported for b")
return """
fail = sub["fail"]
return f"""
if (PyArray_NDIM({_b}) != 1) {{
PyErr_SetString(PyExc_NotImplementedError, "rank(b) != 1");
{fail};
......@@ -1732,7 +1740,7 @@ class StructuredAddSVCSR(_NoPythonCOp):
}}
}}
""".format(**dict(locals(), **sub))
"""
def __str__(self):
return self.__class__.__name__
......@@ -1904,7 +1912,8 @@ class SamplingDotCSR(_NoPythonCOp):
typenum_zi = TensorType(node.outputs[1].dtype, []).dtype_specs()[2]
typenum_zp = TensorType(node.outputs[2].dtype, []).dtype_specs()[2]
rval = """
fail = sub["fail"]
rval = f"""
if (PyArray_NDIM({x}) != 2) {{
PyErr_SetString(PyExc_NotImplementedError, "rank(x) != 2"); {fail};}}
if (PyArray_NDIM({y}) != 2) {{
......@@ -2024,7 +2033,7 @@ PyErr_SetString(PyExc_NotImplementedError, "rank(y) != 2"); {fail};}}
}}
}}
}}
""".format(**dict(locals(), **sub))
"""
return rval
......
......@@ -597,12 +597,12 @@ class TensorFromScalar(COp):
(z,) = outputs
fail = sub["fail"]
return """
return f"""
{z} = (PyArrayObject*)PyArray_FromScalar(py_{x}, NULL);
if({z} == NULL){{
{fail};
}}
""".format(**locals())
"""
def c_code_cache_version(self):
return (2,)
......@@ -646,10 +646,9 @@ class ScalarFromTensor(COp):
def c_code(self, node, name, inputs, outputs, sub):
(x,) = inputs
(z,) = outputs
fail = sub["fail"]
return """
return f"""
{z} = ((dtype_{x}*)(PyArray_DATA({x})))[0];
""".format(**locals())
"""
def c_code_cache_version(self):
return (1,)
......@@ -1836,18 +1835,18 @@ class MakeVector(COp):
assert self.dtype == node.inputs[0].dtype
out_num = f"PyArray_TYPE({inp[0]})"
ret = """
ret = f"""
npy_intp dims[1];
dims[0] = {out_shape};
if(!{out} || PyArray_DIMS({out})[0] != {out_shape}){{
Py_XDECREF({out});
{out} = (PyArrayObject*)PyArray_EMPTY(1, dims, {out_num}, 0);
}}
""".format(**locals())
"""
for idx, i in enumerate(inp):
ret += """
ret += f"""
*(({out_dtype} *)PyArray_GETPTR1({out}, {idx})) = *(({out_dtype} *) PyArray_DATA({i}));
""".format(**locals())
"""
return ret
def infer_shape(self, fgraph, node, ishapes):
......@@ -2225,7 +2224,7 @@ class Split(COp):
splits_dtype = node.inputs[2].type.dtype_specs()[1]
expected_splits_count = self.len_splits
return """
return f"""
int ndim = PyArray_NDIM({x});
int axis = (int)(*({axis_dtype}*)PyArray_GETPTR1({axis}, 0));
int splits_count = PyArray_DIM({splits}, 0);
......@@ -2322,7 +2321,7 @@ class Split(COp):
}}
free(split_dims);
""".format(**locals())
"""
class Join(COp):
......@@ -2373,10 +2372,9 @@ class Join(COp):
if self.view == -1:
return self.__class__.__name__
else:
return "{}{{{}}}".format(
self.__class__.__name__,
", ".join(f"{p}={getattr(self, p)!r}" for p in self.__props__),
)
classname = self.__class__.__name__
args = ", ".join(f"{p}={getattr(self, p)!r}" for p in self.__props__)
return f"{classname}{{{args}}}"
def __setstate__(self, d):
self.__dict__.update(d)
......@@ -2538,7 +2536,7 @@ class Join(COp):
copy_inputs_to_list = "\n".join(copy_to_list)
n = len(tens)
code = """
code = f"""
int axis = (({adtype} *)PyArray_DATA({axis}))[0];
PyObject* list = PyList_New({l});
{copy_inputs_to_list}
......@@ -2570,7 +2568,7 @@ class Join(COp):
{fail}
}}
}}
""".format(**locals())
"""
return code
def R_op(self, inputs, eval_points):
......@@ -4192,18 +4190,14 @@ class AllocEmpty(COp):
params = sub["params"]
str = f"npy_intp dims[{nd}];\n"
for idx, sh in enumerate(shps):
str += (
"dims[{idx}] ="
"((npy_intp)((dtype_{sh}*)"
" PyArray_DATA({sh}))[0]);\n".format(**locals())
)
str += f"dims[{idx}] = ((npy_intp)((dtype_{sh}*) PyArray_DATA({sh}))[0]);\n"
# Validate that the output storage exists
str += f"if({out}==NULL\n"
for idx, sh in enumerate(shps):
str += f"||PyArray_DIMS({out})[{idx}]!=dims[{idx}]"
str += """){{
str += f"""){{
/* Reference received to invalid output variable.
Decrease received reference's ref count and allocate new
output variable */
......@@ -4218,7 +4212,7 @@ class AllocEmpty(COp):
{fail};
}}
}}
""".format(**locals())
"""
return str
def infer_shape(self, fgraph, node, input_shapes):
......
......@@ -1806,19 +1806,12 @@ class BatchedDot(COp):
strides = f"PyArray_STRIDES({var})"
if ndim == 1:
return f"{strides}[0] == type_size"
return " && ".join(
[
" && ".join(
f"{strides}[{i}] > 0 && {strides}[{i}] % type_size == 0"
for i in range(1, ndim)
),
"({})".format(
" || ".join(
f"{strides}[{i}] == type_size" for i in range(1, ndim)
)
),
]
ands = " && ".join(
f"{strides}[{i}] > 0 && {strides}[{i}] % type_size == 0"
for i in range(1, ndim)
)
ors = " || ".join(f"{strides}[{i}] == type_size" for i in range(1, ndim))
return f"{ands} && ({ors})"
x_ndim, y_ndim, z_ndim = (
node.inputs[0].ndim,
......@@ -1838,7 +1831,7 @@ class BatchedDot(COp):
)
z_shape = ", ".join(z_dims)
z_contiguous = contiguous(_z, z_ndim)
allocate = """
allocate = f"""
if (NULL == {_z} || !({z_shape_correct}) || !({z_contiguous}))
{{
npy_intp dims[{z_ndim}] = {{{z_shape}}};
......@@ -1851,14 +1844,14 @@ class BatchedDot(COp):
{fail}
}}
}}
""".format(**locals())
"""
# code to reallocate inputs contiguously if necessary
contiguate = []
for var, ndim in [(_x, x_ndim), (_y, y_ndim)]:
_contiguous = contiguous(var, ndim)
contiguate.append(
"""
f"""
if (!({_contiguous})) {{
PyArrayObject * _copy = (PyArrayObject *) PyArray_Copy({var});
if (!_copy)
......@@ -1866,11 +1859,11 @@ class BatchedDot(COp):
Py_XDECREF({var});
{var} = _copy;
}}
""".format(**locals())
"""
)
contiguate = "\n".join(contiguate)
return """
return f"""
int type_num = PyArray_DESCR({_x})->type_num;
int type_size = PyArray_DESCR({_x})->elsize; // in bytes
......@@ -1927,7 +1920,7 @@ class BatchedDot(COp):
}}
break;
}}
""".format(**locals())
"""
def c_code_cache_version(self):
from pytensor.tensor.blas_headers import blas_header_version
......
......@@ -33,7 +33,7 @@ class BaseBLAS(COp):
def ger_c_code(A, a, x, y, Z, fail, params):
return """
return f"""
int elemsize ;
......@@ -309,7 +309,7 @@ def ger_c_code(A, a, x, y, Z, fail, params):
}}
}}
""".format(**locals())
"""
class CGer(BaseBLAS, Ger):
......
......@@ -1066,7 +1066,7 @@ def blas_header_version():
def ____gemm_code(check_ab, a_init, b_init):
mod = "%"
return """
return f"""
const char * error_string = NULL;
int type_num = PyArray_DESCR(_x)->type_num;
......@@ -1203,4 +1203,4 @@ def ____gemm_code(check_ab, a_init, b_init):
return -1;
/* v 1 */
""".format(**locals())
"""
......@@ -302,10 +302,7 @@ class DimShufflePrinter(Printer):
return pstate.pprinter.process(r)
if list(new_order) == list(reversed(range(r.type.ndim))):
return f"{pstate.pprinter.process(r)}.T"
return "DimShuffle{{{}}}({})".format(
", ".join(map(str, new_order)),
pstate.pprinter.process(r),
)
return f"DimShuffle{{{', '.join(str(o) for o in new_order)}}}({pstate.pprinter.process(r)})"
def process(self, r, pstate):
if r.owner is None:
......@@ -929,13 +926,13 @@ class Elemwise(OpenMPOp):
# We make the output point to the corresponding input and
# decrease the reference of whatever the output contained
# prior to this
alloc += """
alloc += f"""
if ({oname}) {{
Py_XDECREF({oname});
}}
{oname} = {iname};
Py_XINCREF({oname});
""".format(**locals())
"""
# We alias the scalar variables
defines += f"#define {oname}_i {iname}_i\n"
undefs += f"#undef {oname}_i\n"
......@@ -958,13 +955,13 @@ class Elemwise(OpenMPOp):
[f"{s}_i" for s in onames],
dict(sub, fail=fail),
)
code = """
code = f"""
{{
{defines}
{task_code}
{undefs}
}}
""".format(**locals())
"""
loop_orders = orders + [list(range(nnested))] * len(real_onames)
dtypes = idtypes + list(real_odtypes)
......@@ -994,19 +991,17 @@ class Elemwise(OpenMPOp):
if index != "x":
preloops.setdefault(j, "")
preloops[j] += (
"%(lv{i})s_iter = ({dtype}*)"
"(PyArray_DATA(%(lv{i})s));\n".format(**locals())
f"%(lv{i})s_iter = ({dtype}*)(PyArray_DATA(%(lv{i})s));\n"
) % sub
break
else: # all broadcastable
preloops.setdefault(0, "")
preloops[0] += (
"%(lv{i})s_iter = ({dtype}*)"
"(PyArray_DATA(%(lv{i})s));\n".format(**locals())
f"%(lv{i})s_iter = ({dtype}*)(PyArray_DATA(%(lv{i})s));\n"
) % sub
init_array = preloops.get(0, " ")
loop = """
loop = f"""
{{
{defines}
{init_array}
......@@ -1014,7 +1009,7 @@ class Elemwise(OpenMPOp):
{task_code}
{undefs}
}}
""".format(**locals())
"""
else:
loop = cgen.make_loop(
loop_orders=loop_orders,
......@@ -1074,25 +1069,25 @@ class Elemwise(OpenMPOp):
index = ""
for x, var in zip(inames + onames, inputs + node.outputs):
if not all(s == 1 for s in var.type.shape):
contig += """
contig += f"""
dtype_{x} * {x}_ptr = (dtype_{x}*) PyArray_DATA({x});
""".format(**locals())
index += """
"""
index += f"""
dtype_{x}& {x}_i = {x}_ptr[i];
""".format(**locals())
"""
else:
contig += """
contig += f"""
dtype_{x}& {x}_i = ((dtype_{x}*) PyArray_DATA({x}))[0];
""".format(**locals())
"""
if self.openmp:
contig += f"""#pragma omp parallel for if(n>={int(config.openmp_elemwise_minsize)})
"""
contig += """
contig += f"""
for(int i=0; i<n; i++){{
{index}
{task_code};
}}
""".format(**locals())
"""
if contig is not None:
z = list(zip(inames + onames, inputs + node.outputs))
all_broadcastable = all(s == 1 for s in var.type.shape)
......@@ -1106,13 +1101,13 @@ class Elemwise(OpenMPOp):
for arr, var in z
if not all_broadcastable
)
loop = """
loop = f"""
if(({cond1}) || ({cond2})){{
{contig}
}}else{{
{loop}
}}
""".format(**locals())
"""
return decl, checks, alloc, loop, ""
def c_code(self, node, nodename, inames, onames, sub):
......
......@@ -170,12 +170,14 @@ def make_alloc(loop_orders, dtype, sub, fortran="0"):
type = type.replace("PYTENSOR_COMPLEX", "NPY_COMPLEX")
nd = len(loop_orders[0])
init_dims = compute_output_dims_lengths("dims", loop_orders, sub)
olv = sub["olv"]
fail = sub["fail"]
# TODO: it would be interesting to allocate the output in such a
# way that its contiguous dimensions match one of the input's
# contiguous dimensions, or the dimension with the smallest
# stride. Right now, it is allocated to be C_CONTIGUOUS.
return """
return f"""
{{
npy_intp dims[{nd}];
//npy_intp* dims = (npy_intp*)malloc({nd} * sizeof(npy_intp));
......@@ -203,7 +205,7 @@ def make_alloc(loop_orders, dtype, sub, fortran="0"):
{fail}
}}
}}
""".format(**dict(locals(), **sub))
"""
def make_loop(loop_orders, dtypes, loop_tasks, sub, openmp=None):
......
......@@ -68,7 +68,7 @@ class CpuContiguous(COp):
def c_code(self, node, name, inames, onames, sub):
(x,) = inames
(y,) = onames
code = """
code = f"""
if (!PyArray_CHKFLAGS({x}, NPY_ARRAY_C_CONTIGUOUS)){{
// check to see if output is contiguous first
if ({y} != NULL &&
......@@ -86,7 +86,7 @@ class CpuContiguous(COp):
Py_XDECREF({y});
{y} = {x};
}}
""".format(**locals())
"""
return code
def c_code_cache_version(self):
......@@ -161,13 +161,13 @@ class SearchsortedOp(COp):
def c_init_code_struct(self, node, name, sub):
side = sub["params"]
fail = sub["fail"]
return """
return f"""
PyObject* tmp_{name} = PyUnicode_FromString("right");
if (tmp_{name} == NULL)
{fail};
right_{name} = PyUnicode_Compare({side}, tmp_{name});
Py_DECREF(tmp_{name});
""".format(**locals())
"""
def c_code(self, node, name, inames, onames, sub):
sorter = None
......@@ -180,7 +180,7 @@ class SearchsortedOp(COp):
(z,) = onames
fail = sub["fail"]
return """
return f"""
Py_XDECREF({z});
{z} = (PyArrayObject*) PyArray_SearchSorted({x}, (PyObject*) {v},
right_{name} ? NPY_SEARCHLEFT : NPY_SEARCHRIGHT, (PyObject*) {sorter});
......@@ -191,7 +191,7 @@ class SearchsortedOp(COp):
Py_XDECREF({z});
{z} = (PyArrayObject*) tmp;
}}
""".format(**locals())
"""
def c_code_cache_version(self):
return (2,)
......@@ -348,11 +348,10 @@ class CumOp(COp):
def c_code(self, node, name, inames, onames, sub):
(x,) = inames
(z,) = onames
axis = self.axis
fail = sub["fail"]
params = sub["params"]
code = """
code = f"""
int axis = {params}->c_axis;
if (axis == 0 && PyArray_NDIM({x}) == 1)
axis = NPY_MAXDIMS;
......@@ -389,7 +388,7 @@ class CumOp(COp):
// Because PyArray_CumSum/CumProd returns a newly created reference on t.
Py_XDECREF(t);
}}
""".format(**locals())
"""
return code
......
......@@ -215,14 +215,14 @@ class Argmax(COp):
if len(self.axis) != 1:
raise NotImplementedError()
# params is only used here for now
axis_code = """
axis_code = f"""
axis = {params}->c_axis;
if(axis > PyArray_NDIM({x})-1 || axis < -PyArray_NDIM({x})){{
PyErr_SetString(PyExc_ValueError,
"Argmax, bad axis argument");
{fail}
}}
""".format(**locals())
"""
ret = """
int axis;
......@@ -1314,7 +1314,8 @@ class Mean(FixedOpCAReduce):
def __str__(self):
if self.axis is not None:
return "Mean{{{}}}".format(", ".join(str(x) for x in self.axis))
args = ", ".join(str(x) for x in self.axis)
return f"Mean{{{args}}}"
else:
return "Mean"
......
......@@ -1137,7 +1137,7 @@ class Subtensor(COp):
"""
rval += """
rval += f"""
// One more argument of the view
npy_intp xview_offset = 0;
......@@ -1264,7 +1264,7 @@ class Subtensor(COp):
inner_ii += 1;
outer_ii += 1;
}}
""".format(**locals())
"""
# print rval
return rval
......@@ -1284,19 +1284,19 @@ class Subtensor(COp):
decl = "PyArrayObject * xview = NULL;"
checkNDim = """
checkNDim = f"""
if (PyArray_NDIM({x}) != {ndim}){{
PyErr_SetString(PyExc_ValueError,
"Expected {ndim} dimensions input"
);
{fail}
}}
""".format(**locals())
"""
get_xview = self.helper_c_code(
node, name, inputs, outputs, sub, self.idx_list, view_ndim
)
build_view = """
build_view = f"""
//TODO: give this Op a second output so that this view can be cached
//TODO: alternatively, fix the memory leak on failure
Py_INCREF(PyArray_DESCR({x}));
......@@ -1314,7 +1314,7 @@ class Subtensor(COp):
{{
{fail};
}}
""".format(**locals())
"""
finish_view = f"""
Py_XDECREF({z});
......@@ -1761,7 +1761,7 @@ class IncSubtensor(COp):
copy_of_x = self.copy_of_x(x)
copy_input_if_necessary = """
copy_input_if_necessary = f"""
if ({inplace})
{{
if ({x} != {z})
......@@ -1780,7 +1780,7 @@ class IncSubtensor(COp):
{fail}
}}
}}
""".format(**locals())
"""
# get info needed to make zview: a view of %(z)s
helper_args = self.get_helper_c_code_args()
......@@ -1799,7 +1799,7 @@ class IncSubtensor(COp):
# Make a view on the output, as we will write into it.
alloc_zview = self.make_view_array(z, view_ndim)
build_view = """
build_view = f"""
//TODO: give this Op a second output so that this view can be cached
//TODO: alternatively, fix the memory leak on failure
{alloc_zview};
......@@ -1807,13 +1807,13 @@ class IncSubtensor(COp):
{{
{fail};
}}
""".format(**locals())
"""
copy_into = self.copy_into("zview", y)
add_to_zview = self.add_to_zview(name, y, fail)
make_modification = """
make_modification = f"""
if ({op_is_set})
{{
if ({copy_into}) // does broadcasting
......@@ -1826,7 +1826,7 @@ class IncSubtensor(COp):
{{
{add_to_zview}
}}
""".format(**locals())
"""
return (
self.decl_view()
+ copy_input_if_necessary
......@@ -1896,7 +1896,7 @@ class IncSubtensor(COp):
"""
return """Py_INCREF(PyArray_DESCR({x}));
return f"""Py_INCREF(PyArray_DESCR({x}));
zview = (PyArrayObject*)PyArray_NewFromDescr(
&PyArray_Type,
PyArray_DESCR({x}),
......@@ -1906,7 +1906,7 @@ class IncSubtensor(COp):
PyArray_BYTES({x}) + xview_offset, //PyArray_DATA({x}),
PyArray_FLAGS({x}),
NULL);
""".format(**locals())
"""
def get_helper_c_code_args(self):
"""
......@@ -1939,7 +1939,7 @@ class IncSubtensor(COp):
"""
return """
return f"""
PyArrayObject * add_rval = (PyArrayObject*)PyNumber_InPlaceAdd(
(PyObject*)zview, py_{x});
if (add_rval)
......@@ -1952,7 +1952,7 @@ class IncSubtensor(COp):
{{
Py_DECREF(zview);
{fail};
}}""".format(**locals())
}}"""
def infer_shape(self, fgraph, node, shapes):
return [shapes[0]]
......@@ -2162,7 +2162,7 @@ class AdvancedSubtensor1(COp):
a_name, i_name = input_names[0], input_names[1]
output_name = output_names[0]
fail = sub["fail"]
return """
return f"""
PyArrayObject *indices;
int i_type = PyArray_TYPE({i_name});
if (i_type != NPY_INTP) {{
......@@ -2237,7 +2237,7 @@ class AdvancedSubtensor1(COp):
{a_name}, (PyObject*)indices, 0, {output_name}, NPY_RAISE);
Py_DECREF(indices);
if ({output_name} == NULL) {fail};
""".format(**locals())
"""
def c_code_cache_version(self):
return (0, 1, 2)
......@@ -2523,8 +2523,10 @@ class AdvancedIncSubtensor1(COp):
x, y, idx = input_names
out = output_names[0]
copy_of_x = self.copy_of_x(x)
params = sub["params"]
fail = sub["fail"]
return """
return f"""
PyObject* rval = NULL;
if ({params}->inplace)
{{
......@@ -2548,17 +2550,7 @@ class AdvancedIncSubtensor1(COp):
{fail};
}}
Py_XDECREF(rval);
""".format(
**dict(
x=x,
y=y,
idx=idx,
out=out,
copy_of_x=copy_of_x,
params=sub["params"],
fail=sub["fail"],
)
)
"""
def c_code_cache_version(self):
return (8,)
......
......@@ -475,25 +475,28 @@ class TensorType(CType[np.ndarray], HasDataType, HasShape):
def c_declare(self, name, sub, check_input=True):
if check_input:
check = """
dtype = self.dtype_specs()[1]
check = f"""
typedef {dtype} dtype_{name};
""".format(**dict(sub, name=name, dtype=self.dtype_specs()[1]))
"""
else:
check = ""
declaration = """
declaration = f"""
PyArrayObject* {name};
""".format(**dict(sub, name=name, dtype=self.dtype_specs()[1]))
"""
return declaration + check
def c_init(self, name, sub):
return """
return f"""
{name} = NULL;
""".format(**dict(sub, name=name, type_num=self.dtype_specs()[2]))
"""
def c_extract(self, name, sub, check_input=True, **kwargs):
if check_input:
check = """
fail = sub["fail"]
type_num = self.dtype_specs()[2]
check = f"""
{name} = NULL;
if (py_{name} == Py_None) {{
// We can either fail here or set {name} to NULL and rely on Ops
......@@ -541,28 +544,27 @@ class TensorType(CType[np.ndarray], HasDataType, HasShape):
{type_num}, PyArray_TYPE((PyArrayObject*) py_{name}));
{fail}
}}
""".format(**dict(sub, name=name, type_num=self.dtype_specs()[2]))
"""
else:
check = ""
return (
check
+ """
+ f"""
{name} = (PyArrayObject*)(py_{name});
Py_XINCREF({name});
""".format(**dict(sub, name=name, type_num=self.dtype_specs()[2]))
"""
)
def c_cleanup(self, name, sub):
return """
return f"""
if ({name}) {{
Py_XDECREF({name});
}}
""".format(**locals())
"""
def c_sync(self, name, sub):
fail = sub["fail"]
type_num = self.dtype_specs()[2]
return """
return f"""
{{Py_XDECREF(py_{name});}}
if (!{name}) {{
Py_INCREF(Py_None);
......@@ -597,7 +599,7 @@ class TensorType(CType[np.ndarray], HasDataType, HasShape):
);
{fail}
}}
""".format(**locals())
"""
def c_headers(self, **kwargs):
return ps.get_scalar_type(self.dtype).c_headers(**kwargs)
......
......@@ -102,13 +102,13 @@ class GetItem(COp):
x_name, index = inp[0], inp[1]
output_name = out[0]
fail = sub["fail"]
return """
return f"""
{output_name} = (typeof {output_name}) PyList_GetItem( (PyObject*) {x_name}, *((npy_int64 *) PyArray_DATA({index})));
if({output_name} == NULL){{
{fail}
}}
Py_INCREF({output_name});
""".format(**locals())
"""
def c_code_cache_version(self):
return (1,)
......@@ -169,16 +169,16 @@ class Append(COp):
output_name = out[0]
fail = sub["fail"]
if not self.inplace:
init = """
init = f"""
{output_name} = (PyListObject*) PyList_GetSlice((PyObject*) {x_name}, 0, PyList_GET_SIZE((PyObject*) {x_name})) ;
""".format(**locals())
"""
else:
init = f"""
{output_name} = {x_name};
"""
return (
init
+ """
+ f"""
if({output_name}==NULL){{
{fail}
}};
......@@ -186,7 +186,7 @@ class Append(COp):
{fail}
}};
Py_INCREF({output_name});
""".format(**locals())
"""
)
def c_code_cache_version(self):
......@@ -249,16 +249,16 @@ class Extend(COp):
output_name = out[0]
fail = sub["fail"]
if not self.inplace:
init = """
init = f"""
{output_name} = (PyListObject*) PyList_GetSlice((PyObject*) {x_name}, 0, PyList_GET_SIZE((PyObject*) {x_name})) ;
""".format(**locals())
"""
else:
init = f"""
{output_name} = {x_name};
"""
return (
init
+ """
+ f"""
int i =0;
int length = PyList_GET_SIZE((PyObject*) {toAppend});
if({output_name}==NULL){{
......@@ -270,7 +270,7 @@ class Extend(COp):
}};
}}
Py_INCREF({output_name});
""".format(**locals())
"""
)
def c_code_cache_version_(self):
......@@ -337,16 +337,16 @@ class Insert(COp):
output_name = out[0]
fail = sub["fail"]
if not self.inplace:
init = """
init = f"""
{output_name} = (PyListObject*) PyList_GetSlice((PyObject*) {x_name}, 0, PyList_GET_SIZE((PyObject*) {x_name})) ;
""".format(**locals())
"""
else:
init = f"""
{output_name} = {x_name};
"""
return (
init
+ """
+ f"""
if({output_name}==NULL){{
{fail}
}};
......@@ -354,7 +354,7 @@ class Insert(COp):
{fail}
}};
Py_INCREF({output_name});
""".format(**locals())
"""
)
def c_code_cache_version(self):
......@@ -465,16 +465,16 @@ class Reverse(COp):
output_name = out[0]
fail = sub["fail"]
if not self.inplace:
init = """
init = f"""
{output_name} = (PyListObject*) PyList_GetSlice((PyObject*) {x_name}, 0, PyList_GET_SIZE((PyObject*) {x_name})) ;
""".format(**locals())
"""
else:
init = f"""
{output_name} = {x_name};
"""
return (
init
+ """
+ f"""
if({output_name}==NULL){{
{fail}
}};
......@@ -482,7 +482,7 @@ class Reverse(COp):
{fail}
}};
Py_INCREF({output_name});
""".format(**locals())
"""
)
def c_code_cache_version(self):
......@@ -595,13 +595,12 @@ class Length(COp):
def c_code(self, node, name, inp, out, sub):
x_name = inp[0]
output_name = out[0]
fail = sub["fail"]
return """
return f"""
if(!{output_name})
{output_name}=(PyArrayObject*)PyArray_EMPTY(0, NULL, NPY_INT64, 0);
((npy_int64*)PyArray_DATA({output_name}))[0]=PyList_Size((PyObject*){x_name});
Py_INCREF({output_name});
""".format(**locals())
"""
def c_code_cache_version(self):
return (1,)
......
......@@ -110,26 +110,27 @@ class TypedListType(CType):
def c_extract(self, name, sub, check_input=True, **kwargs):
if check_input:
pre = """
fail = sub["fail"]
pre = f"""
if (!PyList_Check(py_{name})) {{
PyErr_SetString(PyExc_TypeError, "expected a list");
{fail}
}}""".format(**dict(name=name, fail=sub["fail"]))
}}"""
else:
pre = ""
return (
pre
+ """
+ f"""
{name} = (PyListObject*) (py_{name});
""".format(**dict(name=name, fail=sub["fail"]))
"""
)
def c_sync(self, name, sub):
return """
return f"""
Py_XDECREF(py_{name});
py_{name} = (PyObject*)({name});
Py_INCREF(py_{name});
""".format(**dict(name=name))
"""
def c_cleanup(self, name, sub):
return ""
......
......@@ -64,7 +64,8 @@ class BROKEN_ON_PURPOSE_Add(COp):
def c_code(self, node, name, inp, out, sub):
a, b = inp
(z,) = out
return """
fail = sub["fail"]
return f"""
if (PyArray_NDIM({a}) != 1) {{PyErr_SetString(PyExc_NotImplementedError, "rank(a) != 1"); {fail};}}
if (PyArray_NDIM({b}) != 1) {{PyErr_SetString(PyExc_NotImplementedError, "rank(b) != 1"); {fail};}}
......@@ -96,7 +97,7 @@ class BROKEN_ON_PURPOSE_Add(COp):
+ ((double*)PyArray_GETPTR1({b}, m))[0] ;
}}
}}
""".format(**dict(locals(), **sub))
"""
# inconsistent is a invalid op, whose perform and c_code do not match
......@@ -632,7 +633,8 @@ class BrokenCImplementationAdd(COp):
a, b = inp
(z,) = out
debug = 0
return """
fail = sub["fail"]
return f"""
//printf("executing c_code\\n");
if (PyArray_NDIM({a}) != 2) {{PyErr_SetString(PyExc_NotImplementedError, "rank(a) != 2"); {fail};}}
if (PyArray_NDIM({b}) != 2) {{PyErr_SetString(PyExc_NotImplementedError, "rank(b) != 2"); {fail};}}
......@@ -690,7 +692,7 @@ class BrokenCImplementationAdd(COp):
}}
}}
}}
""".format(**dict(locals(), **sub))
"""
class VecAsRowAndCol(Op):
......
......@@ -39,7 +39,8 @@ class TDouble(CType):
return str(data)
def c_extract(self, name, sub, check_input=True, **kwargs):
return """
fail = sub["fail"]
return f"""
if (!PyFloat_Check(py_{name})) {{
PyErr_SetString(PyExc_TypeError, "not a double!");
{fail}
......@@ -47,7 +48,7 @@ class TDouble(CType):
{name} = PyFloat_AsDouble(py_{name});
{name}_bad_thing = NULL;
//printf("Extracting {name}\\n");
""".format(**dict(locals(), **sub))
"""
def c_sync(self, name, sub):
return f"""
......
......@@ -189,12 +189,11 @@ def mock_system(request):
@pytest.fixture()
def cxx_search_dirs(blas_libs, mock_system):
libext = {"Linux": "so", "Windows": "dll", "Darwin": "dylib"}
libtemplate = f"{{lib}}.{libext[mock_system]}"
libraries = []
with tempfile.TemporaryDirectory() as d:
flags = None
for lib in blas_libs:
lib_path = Path(d) / libtemplate.format(lib=lib)
lib_path = Path(d) / f"{lib}.{libext[mock_system]}"
lib_path.write_bytes(b"1")
libraries.append(lib_path)
if flags is None:
......@@ -262,14 +261,13 @@ def test_default_blas_ldflags_no_cxx():
@pytest.fixture()
def windows_conda_libs(blas_libs):
libtemplate = "{lib}.dll"
libraries = []
with tempfile.TemporaryDirectory() as d:
subdir = Path(d) / "Library" / "bin"
subdir.mkdir(exist_ok=True, parents=True)
flags = f'-L"{subdir}"'
for lib in blas_libs:
lib_path = subdir / libtemplate.format(lib=lib)
lib_path = subdir / f"{lib}.dll"
lib_path.write_bytes(b"1")
libraries.append(lib_path)
flags += f" -l{lib}"
......
......@@ -79,10 +79,10 @@ class StructOp(COp):
return f"counter{name} = 0;"
def c_code(self, node, name, input_names, outputs_names, sub):
return """
{out} = counter{name};
return f"""
{outputs_names[0]} = counter{name};
counter{name}++;
""".format(**dict(out=outputs_names[0], name=name))
"""
def c_code_cache_version(self):
return (1,)
......
......@@ -73,30 +73,24 @@ class QuadraticOpFunc(COp):
"""
def c_code(self, node, name, inputs, outputs, sub):
return """
{float_type} a = ({float_type}) (*(npy_float64*) PyArray_GETPTR1({coeff}->a, 0)); // 0-D TensorType.
{float_type} b = {coeff}->b; // ScalarType.
{float_type} c = ({float_type}) PyFloat_AsDouble({coeff}->c); // Generic.
X = inputs[0]
Y = outputs[0]
float_type = node.inputs[0].type.c_element_type()
return f"""
{float_type} a = ({float_type}) (*(npy_float64*) PyArray_GETPTR1({sub['params']}->a, 0)); // 0-D TensorType.
{float_type} b = {sub['params']}->b; // ScalarType.
{float_type} c = ({float_type}) PyFloat_AsDouble({sub['params']}->c); // Generic.
Py_XDECREF({Y});
{Y} = (PyArrayObject*)PyArray_EMPTY(PyArray_NDIM({X}), PyArray_DIMS({X}), PyArray_TYPE({X}), PyArray_IS_F_CONTIGUOUS({X}));
if (PyArray_CopyInto({Y}, {X}) != 0) {{
PyErr_SetString(PyExc_RuntimeError, "Unable to copy input into output.");
{fail}
{sub['fail']}
}};
if (quadratic_{name}({Y}, a, b, c) != 0) {{
PyErr_SetString(PyExc_RuntimeError, "Unable to compute quadratic function.");
{fail}
{sub['fail']}
}}
""".format(
**dict(
name=name,
coeff=sub["params"],
fail=sub["fail"],
X=inputs[0],
Y=outputs[0],
float_type=node.inputs[0].type.c_element_type(),
)
)
"""
# Same op as above, but implemented as a ExternalCOp (with C code in an
......
......@@ -25,11 +25,12 @@ Py_XDECREF((PyObject *)p);
"""
def c_code(self, node, name, inps, outs, sub):
return """
Py_XDECREF({out});
{out} = (void *){inp};
Py_INCREF({inp});
""".format(**dict(out=outs[0], inp=inps[0]))
return f"""
Py_XDECREF({outs[0]});
{outs[0]} = (void *){inps[0]};
Py_INCREF({inps[0]});
"""
# FIXME: should it not be outs[0]?
def c_code_cache_version(self):
return (0,)
......@@ -52,11 +53,11 @@ Py_XDECREF((PyObject *)p);
"""
def c_code(self, node, name, inps, outs, sub):
return """
Py_XDECREF({out});
{out} = (PyArrayObject *){inp};
Py_INCREF({out});
""".format(**dict(out=outs[0], inp=inps[0]))
return f"""
Py_XDECREF({outs[0]});
{outs[0]} = (PyArrayObject *){inps[0]};
Py_INCREF({outs[0]});
"""
def c_code_cache_version(self):
return (0,)
......@@ -136,7 +137,12 @@ class MyOpEnumList(COp):
return (1,)
def c_code(self, node, name, inputs, outputs, sub):
return """
op = sub["params"]
o = outputs[0]
a = inputs[0]
b = inputs[1]
fail = sub["fail"]
return f"""
switch({op}) {{
case ADD:
{o} = {a} + {b};
......@@ -154,15 +160,7 @@ class MyOpEnumList(COp):
{{{fail}}}
break;
}}
""".format(
**dict(
op=sub["params"],
o=outputs[0],
a=inputs[0],
b=inputs[1],
fail=sub["fail"],
)
)
"""
class MyOpCEnumType(COp):
......@@ -201,15 +199,10 @@ class MyOpCEnumType(COp):
return (3,)
def c_code(self, node, name, inputs, outputs, sub):
return """
{o} = {val};
""".format(
**dict(
o=outputs[0],
# params in C code will already contains expected C constant value.
val=sub["params"],
)
)
# params in C code will already contains expected C constant value.
return f"""
{outputs[0]} = {sub['params']};
"""
class TestEnumTypes:
......
......@@ -323,7 +323,9 @@ class BaseCorr3dMM(OpenMPOp, _NoPythonOp):
)
depth = "-1"
return """
fail = sub["fail"]
params = sub["params"]
return f"""
// Mandatory args
int direction = {params}->direction; // forward, bprop weights, bprop inputs
......@@ -568,18 +570,7 @@ class BaseCorr3dMM(OpenMPOp, _NoPythonOp):
}}
assert (out2 == *out);
""".format(
**dict(
bottom=bottom,
weights=weights,
top=top,
height=height,
width=width,
depth=depth,
fail=sub["fail"],
params=sub["params"],
)
)
"""
class Corr3dMM(BaseCorr3dMM):
......
......@@ -322,7 +322,9 @@ class BaseCorrMM(OpenMPOp, _NoPythonOp):
)
width = "-1"
return """
fail = sub["fail"]
params = sub["params"]
return f"""
// Mandatory args
int direction = {params}->direction; // forward, bprop weights, bprop inputs
......@@ -614,17 +616,7 @@ class BaseCorrMM(OpenMPOp, _NoPythonOp):
}}
assert (out2 == *out);
""".format(
**dict(
bottom=bottom,
weights=weights,
top=top,
height=height,
width=width,
fail=sub["fail"],
params=sub["params"],
)
)
"""
class CorrMM(BaseCorrMM):
......
......@@ -1391,9 +1391,9 @@ class TimesN(ps.basic.UnaryScalarOp):
def c_support_code_apply(self, node, nodename):
n = str(self.n)
return """
return f"""
float {nodename}_timesn(float x) {{ return x * {n}; }}
""".format(**locals())
"""
def c_code(self, node, name, inputs, outputs, sub):
(x,) = inputs
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论