提交 10f285a1 authored 作者: Virgile Andreani's avatar Virgile Andreani 提交者: Virgile Andreani

Use generators when appropriate

上级 8ae2a195
......@@ -104,7 +104,7 @@ class PyTensorConfigParser:
)
return hash_from_code(
"\n".join(
[f"{cv.name} = {cv.__get__(self, self.__class__)}" for cv in all_opts]
f"{cv.name} = {cv.__get__(self, self.__class__)}" for cv in all_opts
)
)
......
......@@ -360,7 +360,7 @@ def dict_to_pdnode(d):
for k, v in d.items():
if v is not None:
if isinstance(v, list):
v = "\t".join([str(x) for x in v])
v = "\t".join(str(x) for x in v)
else:
v = str(v)
v = str(v)
......
......@@ -1264,7 +1264,7 @@ class SequentialNodeRewriter(NodeRewriter):
return getattr(
self,
"__name__",
f"{type(self).__name__}({','.join([str(o) for o in self.rewrites])})",
f"{type(self).__name__}({','.join(str(o) for o in self.rewrites)})",
)
def tracks(self):
......@@ -1666,7 +1666,7 @@ class PatternNodeRewriter(NodeRewriter):
if isinstance(pattern, list | tuple):
return "{}({})".format(
str(pattern[0]),
", ".join([pattern_to_str(p) for p in pattern[1:]]),
", ".join(pattern_to_str(p) for p in pattern[1:]),
)
elif isinstance(pattern, dict):
return "{} subject to {}".format(
......@@ -2569,7 +2569,7 @@ class EquilibriumGraphRewriter(NodeProcessingGraphRewriter):
d = sorted(
loop_process_count[i].items(), key=lambda a: a[1], reverse=True
)
loop_times = " ".join([str((str(k), v)) for k, v in d[:5]])
loop_times = " ".join(str((str(k), v)) for k, v in d[:5])
if len(d) > 5:
loop_times += " ..."
print(
......
......@@ -235,16 +235,16 @@ def struct_gen(args, struct_builders, blocks, sub):
behavior = code_gen(blocks)
# declares the storage
storage_decl = "\n".join([f"PyObject* {arg};" for arg in args])
storage_decl = "\n".join(f"PyObject* {arg};" for arg in args)
# in the constructor, sets the storage to the arguments
storage_set = "\n".join([f"this->{arg} = {arg};" for arg in args])
storage_set = "\n".join(f"this->{arg} = {arg};" for arg in args)
# increments the storage's refcount in the constructor
storage_incref = "\n".join([f"Py_XINCREF({arg});" for arg in args])
storage_incref = "\n".join(f"Py_XINCREF({arg});" for arg in args)
# decrements the storage's refcount in the destructor
storage_decref = "\n".join([f"Py_XDECREF(this->{arg});" for arg in args])
storage_decref = "\n".join(f"Py_XDECREF(this->{arg});" for arg in args)
args_names = ", ".join(args)
args_decl = ", ".join([f"PyObject* {arg}" for arg in args])
args_decl = ", ".join(f"PyObject* {arg}" for arg in args)
# The following code stores the exception data in __ERROR, which
# is a special field of the struct. __ERROR is a list of length 3
......
......@@ -2003,7 +2003,7 @@ def try_blas_flag(flags):
cflags = list(flags)
# to support path that includes spaces, we need to wrap it with double quotes on Windows
path_wrapper = '"' if os.name == "nt" else ""
cflags.extend([f"-L{path_wrapper}{d}{path_wrapper}" for d in std_lib_dirs()])
cflags.extend(f"-L{path_wrapper}{d}{path_wrapper}" for d in std_lib_dirs())
res = GCC_compiler.try_compile_tmp(
test_code, tmp_prefix="try_blas_", flags=cflags, try_run=True
......@@ -2573,8 +2573,8 @@ class GCC_compiler(Compiler):
cmd.extend(preargs)
# to support path that includes spaces, we need to wrap it with double quotes on Windows
path_wrapper = '"' if os.name == "nt" else ""
cmd.extend([f"-I{path_wrapper}{idir}{path_wrapper}" for idir in include_dirs])
cmd.extend([f"-L{path_wrapper}{ldir}{path_wrapper}" for ldir in lib_dirs])
cmd.extend(f"-I{path_wrapper}{idir}{path_wrapper}" for idir in include_dirs)
cmd.extend(f"-L{path_wrapper}{ldir}{path_wrapper}" for ldir in lib_dirs)
if hide_symbols and sys.platform != "win32":
# This has been available since gcc 4.0 so we suppose it
# is always available. We pass it here since it
......
......@@ -263,9 +263,7 @@ class Params(dict):
def __repr__(self):
return "Params({})".format(
", ".join(
[(f"{k}:{type(self[k]).__name__}:{self[k]}") for k in sorted(self)]
)
", ".join((f"{k}:{type(self[k]).__name__}:{self[k]}") for k in sorted(self))
)
def __getattr__(self, key):
......@@ -425,9 +423,7 @@ class ParamsType(CType):
def __repr__(self):
return "ParamsType<{}>".format(
", ".join(
[(f"{self.fields[i]}:{self.types[i]}") for i in range(self.length)]
)
", ".join((f"{self.fields[i]}:{self.types[i]}") for i in range(self.length))
)
def __eq__(self, other):
......@@ -748,10 +744,8 @@ class ParamsType(CType):
}}
""".format(
"\n".join(
[
("case %d: extract_%s(object); break;" % (i, self.fields[i]))
for i in range(self.length)
]
("case %d: extract_%s(object); break;" % (i, self.fields[i]))
for i in range(self.length)
)
)
final_struct_code = """
......
......@@ -485,8 +485,8 @@ def numba_funcify_Elemwise(op, node, **kwargs):
nout = len(node.outputs)
core_op_fn = store_core_outputs(scalar_op_fn, nin=nin, nout=nout)
input_bc_patterns = tuple([inp.type.broadcastable for inp in node.inputs])
output_bc_patterns = tuple([out.type.broadcastable for out in node.outputs])
input_bc_patterns = tuple(inp.type.broadcastable for inp in node.inputs)
output_bc_patterns = tuple(out.type.broadcastable for out in node.outputs)
output_dtypes = tuple(out.type.dtype for out in node.outputs)
inplace_pattern = tuple(op.inplace_pattern.items())
core_output_shapes = tuple(() for _ in range(nout))
......
......@@ -85,9 +85,7 @@ def numba_funcify_ScalarOp(op, node, **kwargs):
unique_names = unique_name_generator(
[scalar_op_fn_name, "scalar_func_numba"], suffix_sep="_"
)
input_names = ", ".join(
[unique_names(v, force_unique=True) for v in node.inputs]
)
input_names = ", ".join(unique_names(v, force_unique=True) for v in node.inputs)
if not has_pyx_skip_dispatch:
scalar_op_src = f"""
def {scalar_op_fn_name}({input_names}):
......@@ -115,10 +113,8 @@ def {scalar_op_fn_name}({input_names}):
input_names = [unique_names(v, force_unique=True) for v in node.inputs]
converted_call_args = ", ".join(
[
f"direct_cast({i_name}, {i_tmp_dtype_name})"
for i_name, i_tmp_dtype_name in zip(input_names, input_tmp_dtype_names)
]
f"direct_cast({i_name}, {i_tmp_dtype_name})"
for i_name, i_tmp_dtype_name in zip(input_names, input_tmp_dtype_names)
)
if not has_pyx_skip_dispatch:
scalar_op_src = f"""
......
......@@ -373,7 +373,7 @@ def numba_funcify_Scan(op, node, **kwargs):
inner_out_post_processing_block = "\n".join(inner_out_post_processing_stmts)
inner_out_to_outer_out_stmts = "\n".join(
[f"{s} = {d}" for s, d in zip(inner_out_to_outer_in_stmts, inner_output_names)]
f"{s} = {d}" for s, d in zip(inner_out_to_outer_in_stmts, inner_output_names)
)
scan_op_src = f"""
......
......@@ -35,10 +35,8 @@ def numba_funcify_AllocEmpty(op, node, **kwargs):
shape_var_item_names = [f"{name}_item" for name in shape_var_names]
shapes_to_items_src = indent(
"\n".join(
[
f"{item_name} = to_scalar({shape_name})"
for item_name, shape_name in zip(shape_var_item_names, shape_var_names)
]
f"{item_name} = to_scalar({shape_name})"
for item_name, shape_name in zip(shape_var_item_names, shape_var_names)
),
" " * 4,
)
......@@ -69,10 +67,8 @@ def numba_funcify_Alloc(op, node, **kwargs):
shape_var_item_names = [f"{name}_item" for name in shape_var_names]
shapes_to_items_src = indent(
"\n".join(
[
f"{item_name} = to_scalar({shape_name})"
for item_name, shape_name in zip(shape_var_item_names, shape_var_names)
]
f"{item_name} = to_scalar({shape_name})"
for item_name, shape_name in zip(shape_var_item_names, shape_var_names)
),
" " * 4,
)
......
......@@ -43,10 +43,8 @@ def store_core_outputs(core_op_fn: Callable, nin: int, nout: int) -> Callable:
out_signature = ", ".join(outputs)
inner_out_signature = ", ".join(inner_outputs)
store_outputs = "\n".join(
[
f"{output}[...] = {inner_output}"
for output, inner_output in zip(outputs, inner_outputs)
]
f"{output}[...] = {inner_output}"
for output, inner_output in zip(outputs, inner_outputs)
)
func_src = f"""
def store_core_outputs({inp_signature}, {out_signature}):
......
......@@ -1112,7 +1112,7 @@ class VMLinker(LocalLinker):
for i, node in enumerate(nodes):
prereq_var_idxs = []
for prereq_node in ords.get(node, []):
prereq_var_idxs.extend([vars_idx[v] for v in prereq_node.outputs])
prereq_var_idxs.extend(vars_idx[v] for v in prereq_node.outputs)
prereq_var_idxs = list(set(prereq_var_idxs))
prereq_var_idxs.sort() # TODO: why sort?
node_prereqs.append(prereq_var_idxs)
......@@ -1323,9 +1323,7 @@ class VMLinker(LocalLinker):
def __repr__(self):
args_str = ", ".join(
[
f"{name}={getattr(self, name)}"
for name in ("use_cloop", "lazy", "allow_partial_eval", "allow_gc")
]
f"{name}={getattr(self, name)}"
for name in ("use_cloop", "lazy", "allow_partial_eval", "allow_gc")
)
return f"{type(self).__name__}({args_str})"
......@@ -9,10 +9,10 @@ from pytensor.configdefaults import config
DISPLAY_DUPLICATE_KEYS = False
DISPLAY_MOST_FREQUENT_DUPLICATE_CCODE = False
dirs = []
dirs: list = []
if len(sys.argv) > 1:
for compiledir in sys.argv[1:]:
dirs.extend([os.path.join(compiledir, d) for d in os.listdir(compiledir)])
dirs.extend(os.path.join(compiledir, d) for d in os.listdir(compiledir))
else:
dirs = os.listdir(config.compiledir)
dirs = [os.path.join(config.compiledir, d) for d in dirs]
......
......@@ -229,32 +229,32 @@ def debugprint(
topo_orders.append(None)
elif isinstance(obj, Apply):
outputs_to_print.extend(obj.outputs)
profile_list.extend([None for item in obj.outputs])
storage_maps.extend([None for item in obj.outputs])
topo_orders.extend([None for item in obj.outputs])
profile_list.extend(None for item in obj.outputs)
storage_maps.extend(None for item in obj.outputs)
topo_orders.extend(None for item in obj.outputs)
elif isinstance(obj, Function):
if print_fgraph_inputs:
inputs_to_print.extend(obj.maker.fgraph.inputs)
outputs_to_print.extend(obj.maker.fgraph.outputs)
profile_list.extend([obj.profile for item in obj.maker.fgraph.outputs])
profile_list.extend(obj.profile for item in obj.maker.fgraph.outputs)
if print_storage:
storage_maps.extend(
[obj.vm.storage_map for item in obj.maker.fgraph.outputs]
obj.vm.storage_map for item in obj.maker.fgraph.outputs
)
else:
storage_maps.extend([None for item in obj.maker.fgraph.outputs])
storage_maps.extend(None for item in obj.maker.fgraph.outputs)
topo = obj.maker.fgraph.toposort()
topo_orders.extend([topo for item in obj.maker.fgraph.outputs])
topo_orders.extend(topo for item in obj.maker.fgraph.outputs)
elif isinstance(obj, FunctionGraph):
if print_fgraph_inputs:
inputs_to_print.extend(obj.inputs)
outputs_to_print.extend(obj.outputs)
profile_list.extend([getattr(obj, "profile", None) for item in obj.outputs])
profile_list.extend(getattr(obj, "profile", None) for item in obj.outputs)
storage_maps.extend(
[getattr(obj, "storage_map", None) for item in obj.outputs]
getattr(obj, "storage_map", None) for item in obj.outputs
)
topo = obj.toposort()
topo_orders.extend([topo for item in obj.outputs])
topo_orders.extend(topo for item in obj.outputs)
elif isinstance(obj, int | float | np.ndarray):
print(obj, file=_file)
elif isinstance(obj, In | Out):
......@@ -980,10 +980,10 @@ class FunctionPrinter(Printer):
name = self.names[idx]
with set_precedence(pstate):
inputs_str = ", ".join(
[pprinter.process(input, pstate) for input in node.inputs]
pprinter.process(input, pstate) for input in node.inputs
)
keywords_str = ", ".join(
[f"{kw}={getattr(node.op, kw)}" for kw in self.keywords]
f"{kw}={getattr(node.op, kw)}" for kw in self.keywords
)
if keywords_str and inputs_str:
......@@ -1050,7 +1050,7 @@ class DefaultPrinter(Printer):
with set_precedence(pstate):
r = "{}({})".format(
str(node.op),
", ".join([pprinter.process(input, pstate) for input in node.inputs]),
", ".join(pprinter.process(input, pstate) for input in node.inputs),
)
pstate.memo[output] = r
......
......@@ -4224,8 +4224,8 @@ class Composite(ScalarInnerGraphOp):
inputs, outputs = res[0], res2[1]
self.inputs, self.outputs = self._cleanup_graph(inputs, outputs)
self.inputs_type = tuple([input.type for input in self.inputs])
self.outputs_type = tuple([output.type for output in self.outputs])
self.inputs_type = tuple(input.type for input in self.inputs)
self.outputs_type = tuple(output.type for output in self.outputs)
self.nin = len(inputs)
self.nout = len(outputs)
super().__init__()
......@@ -4247,7 +4247,7 @@ class Composite(ScalarInnerGraphOp):
if len(self.fgraph.outputs) > 1 or len(self.fgraph.apply_nodes) > 10:
self._name = "Composite{...}"
else:
outputs_str = ", ".join([pprint(output) for output in self.fgraph.outputs])
outputs_str = ", ".join(pprint(output) for output in self.fgraph.outputs)
self._name = f"Composite{{{outputs_str}}}"
return self._name
......@@ -4295,7 +4295,7 @@ class Composite(ScalarInnerGraphOp):
return self.outputs_type
def make_node(self, *inputs):
if tuple([i.type for i in self.inputs]) == tuple([i.type for i in inputs]):
if tuple(i.type for i in self.inputs) == tuple(i.type for i in inputs):
return super().make_node(*inputs)
else:
# Make a new op with the right input type.
......
......@@ -160,7 +160,7 @@ class ScalarLoop(ScalarInnerGraphOp):
f"Got {n_steps.type.dtype}",
)
if self.inputs_type == tuple([i.type for i in inputs]):
if self.inputs_type == tuple(i.type for i in inputs):
return super().make_node(n_steps, *inputs)
else:
# Make a new op with the right input types.
......
......@@ -1936,7 +1936,7 @@ class ScanMerge(GraphRewriter):
profile=old_op.profile,
truncate_gradient=old_op.truncate_gradient,
allow_gc=old_op.allow_gc,
name="&".join([nd.op.name for nd in nodes]),
name="&".join(nd.op.name for nd in nodes),
)
new_outs = new_op(*outer_ins)
......
......@@ -749,15 +749,13 @@ class ScanArgs:
def field_names(self):
res = ["mit_mot_out_slices", "mit_mot_in_slices", "mit_sot_in_slices"]
res.extend(
[
attr
for attr in self.__dict__
if attr.startswith("inner_in")
or attr.startswith("inner_out")
or attr.startswith("outer_in")
or attr.startswith("outer_out")
or attr == "n_steps"
]
attr
for attr in self.__dict__
if attr.startswith("inner_in")
or attr.startswith("inner_out")
or attr.startswith("outer_in")
or attr.startswith("outer_out")
or attr == "n_steps"
)
return res
......
......@@ -1554,7 +1554,7 @@ class Alloc(COp):
def perform(self, node, inputs, out_):
(out,) = out_
v = inputs[0]
sh = tuple([int(i) for i in inputs[1:]])
sh = tuple(int(i) for i in inputs[1:])
self._check_runtime_broadcast(node, v, sh)
if out[0] is None or out[0].shape != sh:
......@@ -4180,7 +4180,7 @@ class AllocEmpty(COp):
def perform(self, node, inputs, out_):
(out,) = out_
sh = tuple([int(i) for i in inputs])
sh = tuple(int(i) for i in inputs)
if out[0] is None or out[0].shape != sh:
out[0] = np.empty(sh, dtype=self.dtype)
......
......@@ -1691,7 +1691,7 @@ class BatchedDot(COp):
if x.shape[0] != y.shape[0]:
raise TypeError(
f"Inputs [{', '.join(map(str, inp))}] must have the"
f" same size in axis 0, but have sizes [{', '.join([str(i.shape[0]) for i in inp])}]."
f" same size in axis 0, but have sizes [{', '.join(str(i.shape[0]) for i in inp)}]."
)
z[0] = np.matmul(x, y)
......
......@@ -139,10 +139,8 @@ class Blockwise(Op):
try:
batch_shape = tuple(
[
broadcast_static_dim_lengths(batch_dims)
for batch_dims in zip(*batch_shapes)
]
broadcast_static_dim_lengths(batch_dims)
for batch_dims in zip(*batch_shapes)
)
except ValueError:
raise ValueError(
......
......@@ -182,7 +182,7 @@ class DimShuffle(ExternalCOp):
self.transposition = self.shuffle + drop
# List of dimensions of the output that are broadcastable and were not
# in the original input
self.augment = sorted([i for i, x in enumerate(new_order) if x == "x"])
self.augment = sorted(i for i, x in enumerate(new_order) if x == "x")
self.drop = drop
if self.inplace:
......@@ -893,11 +893,9 @@ class Elemwise(OpenMPOp):
# In that case, create a fortran output ndarray.
z = list(zip(inames, inputs))
alloc_fortran = " && ".join(
[
f"PyArray_ISFORTRAN({arr})"
for arr, var in z
if not all(s == 1 for s in var.type.shape)
]
f"PyArray_ISFORTRAN({arr})"
for arr, var in z
if not all(s == 1 for s in var.type.shape)
)
# If it is a scalar, make it c contig to prevent problem with
# NumPy C and F contig not always set as both of them.
......@@ -984,12 +982,10 @@ class Elemwise(OpenMPOp):
if len(all_code) == 1:
# No loops
task_decl = "".join(
[
f"{dtype}& {name}_i = *{name}_iter;\n"
for name, dtype in zip(
inames + list(real_onames), idtypes + list(real_odtypes)
)
]
f"{dtype}& {name}_i = *{name}_iter;\n"
for name, dtype in zip(
inames + list(real_onames), idtypes + list(real_odtypes)
)
)
preloops = {}
......@@ -1101,18 +1097,14 @@ class Elemwise(OpenMPOp):
z = list(zip(inames + onames, inputs + node.outputs))
all_broadcastable = all(s == 1 for s in var.type.shape)
cond1 = " && ".join(
[
f"PyArray_ISCONTIGUOUS({arr})"
for arr, var in z
if not all_broadcastable
]
f"PyArray_ISCONTIGUOUS({arr})"
for arr, var in z
if not all_broadcastable
)
cond2 = " && ".join(
[
f"PyArray_ISFORTRAN({arr})"
for arr, var in z
if not all_broadcastable
]
f"PyArray_ISFORTRAN({arr})"
for arr, var in z
if not all_broadcastable
)
loop = """
if(({cond1}) || ({cond2})){{
......
......@@ -1248,7 +1248,7 @@ class Unique(Op):
f"Unique axis `{self.axis}` is outside of input ndim = {ndim}."
)
ret[0] = tuple(
[fgraph.shape_feature.shape_ir(i, node.outputs[0]) for i in range(ndim)]
fgraph.shape_feature.shape_ir(i, node.outputs[0]) for i in range(ndim)
)
if self.return_inverse:
if self.axis is None:
......
......@@ -258,11 +258,9 @@ class Argmax(COp):
if self.axis is None:
return [()]
rval = tuple(
[
ishape[i]
for (i, b) in enumerate(node.inputs[0].type.broadcastable)
if i not in self.axis
]
ishape[i]
for (i, b) in enumerate(node.inputs[0].type.broadcastable)
if i not in self.axis
)
return [rval]
......
......@@ -800,10 +800,8 @@ class Reshape(COp):
rest_size = input_size // maximum(requ_size, 1)
return [
tuple(
[
ptb.switch(eq(requ[i], -1), rest_size, requ[i])
for i in range(self.ndim)
]
ptb.switch(eq(requ[i], -1), rest_size, requ[i])
for i in range(self.ndim)
)
]
......
......@@ -879,7 +879,7 @@ class BaseBlockDiagonal(Op):
__props__ = ("n_inputs",)
def __init__(self, n_inputs):
input_sig = ",".join([f"(m{i},n{i})" for i in range(n_inputs)])
input_sig = ",".join(f"(m{i},n{i})" for i in range(n_inputs))
self.gufunc_signature = f"{input_sig}->(m,n)"
if n_inputs == 0:
......
......@@ -1113,7 +1113,7 @@ class Subtensor(COp):
if is_slice:
is_slice_init = (
"int is_slice[] = {" + ",".join([str(s) for s in is_slice]) + "};"
"int is_slice[] = {" + ",".join(str(s) for s in is_slice) + "};"
)
else:
is_slice_init = "int* is_slice = NULL;"
......@@ -2401,9 +2401,7 @@ class AdvancedIncSubtensor1(COp):
fn_array = (
"static inplace_map_binop addition_funcs[] = {"
+ "".join(
[gen_binop(type=t, typen=t.upper()) for t in types + complex_types]
)
+ "".join(gen_binop(type=t, typen=t.upper()) for t in types + complex_types)
+ "NULL};\n"
)
......@@ -2416,7 +2414,7 @@ class AdvancedIncSubtensor1(COp):
type_number_array = (
"static int type_numbers[] = {"
+ "".join([gen_num(typen=t.upper()) for t in types + complex_types])
+ "".join(gen_num(typen=t.upper()) for t in types + complex_types)
+ "-1000};"
)
......
......@@ -401,7 +401,7 @@ class TensorType(CType[np.ndarray], HasDataType, HasShape):
else:
return str(s)
formatted_shape = ", ".join([shape_str(s) for s in shape])
formatted_shape = ", ".join(shape_str(s) for s in shape)
if len_shape == 1:
formatted_shape += ","
......
......@@ -521,12 +521,10 @@ class _tensor_py_operators:
# Else leave it as is if it is a real number
# Convert python literals to pytensor constants
args = tuple(
[
pt.subtensor.as_index_constant(
np.array(inp, dtype=np.uint8) if is_empty_array(inp) else inp
)
for inp in args
]
pt.subtensor.as_index_constant(
np.array(inp, dtype=np.uint8) if is_empty_array(inp) else inp
)
for inp in args
)
# Determine if advanced indexing is needed or not. The logic is
......
......@@ -3418,7 +3418,7 @@ class TestSumMeanMaxMinArgMaxVarReduceAxes:
def reduce_bitwise_and(x, axis=-1, dtype="int8"):
identity = np.array((-1,), dtype=dtype)[0]
shape_without_axis = tuple([s for i, s in enumerate(x.shape) if i != axis])
shape_without_axis = tuple(s for i, s in enumerate(x.shape) if i != axis)
if 0 in shape_without_axis:
return np.empty(shape=shape_without_axis, dtype=x.dtype)
......
......@@ -667,7 +667,7 @@ def makeBroadcastTester(op, expected, checks=None, name=None, **kwargs):
# For instance: sub_inplace -> SubInplace
capitalize = True
if capitalize:
name = "".join([x.capitalize() for x in name.split("_")])
name = "".join(x.capitalize() for x in name.split("_"))
# Some tests specify a name that already ends with 'Tester', while in other
# cases we need to add it manually.
if not name.endswith("Tester"):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论