提交 10f285a1 authored 作者: Virgile Andreani's avatar Virgile Andreani 提交者: Virgile Andreani

Use generators when appropriate

上级 8ae2a195
...@@ -104,7 +104,7 @@ class PyTensorConfigParser: ...@@ -104,7 +104,7 @@ class PyTensorConfigParser:
) )
return hash_from_code( return hash_from_code(
"\n".join( "\n".join(
[f"{cv.name} = {cv.__get__(self, self.__class__)}" for cv in all_opts] f"{cv.name} = {cv.__get__(self, self.__class__)}" for cv in all_opts
) )
) )
......
...@@ -360,7 +360,7 @@ def dict_to_pdnode(d): ...@@ -360,7 +360,7 @@ def dict_to_pdnode(d):
for k, v in d.items(): for k, v in d.items():
if v is not None: if v is not None:
if isinstance(v, list): if isinstance(v, list):
v = "\t".join([str(x) for x in v]) v = "\t".join(str(x) for x in v)
else: else:
v = str(v) v = str(v)
v = str(v) v = str(v)
......
...@@ -1264,7 +1264,7 @@ class SequentialNodeRewriter(NodeRewriter): ...@@ -1264,7 +1264,7 @@ class SequentialNodeRewriter(NodeRewriter):
return getattr( return getattr(
self, self,
"__name__", "__name__",
f"{type(self).__name__}({','.join([str(o) for o in self.rewrites])})", f"{type(self).__name__}({','.join(str(o) for o in self.rewrites)})",
) )
def tracks(self): def tracks(self):
...@@ -1666,7 +1666,7 @@ class PatternNodeRewriter(NodeRewriter): ...@@ -1666,7 +1666,7 @@ class PatternNodeRewriter(NodeRewriter):
if isinstance(pattern, list | tuple): if isinstance(pattern, list | tuple):
return "{}({})".format( return "{}({})".format(
str(pattern[0]), str(pattern[0]),
", ".join([pattern_to_str(p) for p in pattern[1:]]), ", ".join(pattern_to_str(p) for p in pattern[1:]),
) )
elif isinstance(pattern, dict): elif isinstance(pattern, dict):
return "{} subject to {}".format( return "{} subject to {}".format(
...@@ -2569,7 +2569,7 @@ class EquilibriumGraphRewriter(NodeProcessingGraphRewriter): ...@@ -2569,7 +2569,7 @@ class EquilibriumGraphRewriter(NodeProcessingGraphRewriter):
d = sorted( d = sorted(
loop_process_count[i].items(), key=lambda a: a[1], reverse=True loop_process_count[i].items(), key=lambda a: a[1], reverse=True
) )
loop_times = " ".join([str((str(k), v)) for k, v in d[:5]]) loop_times = " ".join(str((str(k), v)) for k, v in d[:5])
if len(d) > 5: if len(d) > 5:
loop_times += " ..." loop_times += " ..."
print( print(
......
...@@ -235,16 +235,16 @@ def struct_gen(args, struct_builders, blocks, sub): ...@@ -235,16 +235,16 @@ def struct_gen(args, struct_builders, blocks, sub):
behavior = code_gen(blocks) behavior = code_gen(blocks)
# declares the storage # declares the storage
storage_decl = "\n".join([f"PyObject* {arg};" for arg in args]) storage_decl = "\n".join(f"PyObject* {arg};" for arg in args)
# in the constructor, sets the storage to the arguments # in the constructor, sets the storage to the arguments
storage_set = "\n".join([f"this->{arg} = {arg};" for arg in args]) storage_set = "\n".join(f"this->{arg} = {arg};" for arg in args)
# increments the storage's refcount in the constructor # increments the storage's refcount in the constructor
storage_incref = "\n".join([f"Py_XINCREF({arg});" for arg in args]) storage_incref = "\n".join(f"Py_XINCREF({arg});" for arg in args)
# decrements the storage's refcount in the destructor # decrements the storage's refcount in the destructor
storage_decref = "\n".join([f"Py_XDECREF(this->{arg});" for arg in args]) storage_decref = "\n".join(f"Py_XDECREF(this->{arg});" for arg in args)
args_names = ", ".join(args) args_names = ", ".join(args)
args_decl = ", ".join([f"PyObject* {arg}" for arg in args]) args_decl = ", ".join(f"PyObject* {arg}" for arg in args)
# The following code stores the exception data in __ERROR, which # The following code stores the exception data in __ERROR, which
# is a special field of the struct. __ERROR is a list of length 3 # is a special field of the struct. __ERROR is a list of length 3
......
...@@ -2003,7 +2003,7 @@ def try_blas_flag(flags): ...@@ -2003,7 +2003,7 @@ def try_blas_flag(flags):
cflags = list(flags) cflags = list(flags)
# to support path that includes spaces, we need to wrap it with double quotes on Windows # to support path that includes spaces, we need to wrap it with double quotes on Windows
path_wrapper = '"' if os.name == "nt" else "" path_wrapper = '"' if os.name == "nt" else ""
cflags.extend([f"-L{path_wrapper}{d}{path_wrapper}" for d in std_lib_dirs()]) cflags.extend(f"-L{path_wrapper}{d}{path_wrapper}" for d in std_lib_dirs())
res = GCC_compiler.try_compile_tmp( res = GCC_compiler.try_compile_tmp(
test_code, tmp_prefix="try_blas_", flags=cflags, try_run=True test_code, tmp_prefix="try_blas_", flags=cflags, try_run=True
...@@ -2573,8 +2573,8 @@ class GCC_compiler(Compiler): ...@@ -2573,8 +2573,8 @@ class GCC_compiler(Compiler):
cmd.extend(preargs) cmd.extend(preargs)
# to support path that includes spaces, we need to wrap it with double quotes on Windows # to support path that includes spaces, we need to wrap it with double quotes on Windows
path_wrapper = '"' if os.name == "nt" else "" path_wrapper = '"' if os.name == "nt" else ""
cmd.extend([f"-I{path_wrapper}{idir}{path_wrapper}" for idir in include_dirs]) cmd.extend(f"-I{path_wrapper}{idir}{path_wrapper}" for idir in include_dirs)
cmd.extend([f"-L{path_wrapper}{ldir}{path_wrapper}" for ldir in lib_dirs]) cmd.extend(f"-L{path_wrapper}{ldir}{path_wrapper}" for ldir in lib_dirs)
if hide_symbols and sys.platform != "win32": if hide_symbols and sys.platform != "win32":
# This has been available since gcc 4.0 so we suppose it # This has been available since gcc 4.0 so we suppose it
# is always available. We pass it here since it # is always available. We pass it here since it
......
...@@ -263,9 +263,7 @@ class Params(dict): ...@@ -263,9 +263,7 @@ class Params(dict):
def __repr__(self): def __repr__(self):
return "Params({})".format( return "Params({})".format(
", ".join( ", ".join((f"{k}:{type(self[k]).__name__}:{self[k]}") for k in sorted(self))
[(f"{k}:{type(self[k]).__name__}:{self[k]}") for k in sorted(self)]
)
) )
def __getattr__(self, key): def __getattr__(self, key):
...@@ -425,9 +423,7 @@ class ParamsType(CType): ...@@ -425,9 +423,7 @@ class ParamsType(CType):
def __repr__(self): def __repr__(self):
return "ParamsType<{}>".format( return "ParamsType<{}>".format(
", ".join( ", ".join((f"{self.fields[i]}:{self.types[i]}") for i in range(self.length))
[(f"{self.fields[i]}:{self.types[i]}") for i in range(self.length)]
)
) )
def __eq__(self, other): def __eq__(self, other):
...@@ -748,10 +744,8 @@ class ParamsType(CType): ...@@ -748,10 +744,8 @@ class ParamsType(CType):
}} }}
""".format( """.format(
"\n".join( "\n".join(
[ ("case %d: extract_%s(object); break;" % (i, self.fields[i]))
("case %d: extract_%s(object); break;" % (i, self.fields[i])) for i in range(self.length)
for i in range(self.length)
]
) )
) )
final_struct_code = """ final_struct_code = """
......
...@@ -485,8 +485,8 @@ def numba_funcify_Elemwise(op, node, **kwargs): ...@@ -485,8 +485,8 @@ def numba_funcify_Elemwise(op, node, **kwargs):
nout = len(node.outputs) nout = len(node.outputs)
core_op_fn = store_core_outputs(scalar_op_fn, nin=nin, nout=nout) core_op_fn = store_core_outputs(scalar_op_fn, nin=nin, nout=nout)
input_bc_patterns = tuple([inp.type.broadcastable for inp in node.inputs]) input_bc_patterns = tuple(inp.type.broadcastable for inp in node.inputs)
output_bc_patterns = tuple([out.type.broadcastable for out in node.outputs]) output_bc_patterns = tuple(out.type.broadcastable for out in node.outputs)
output_dtypes = tuple(out.type.dtype for out in node.outputs) output_dtypes = tuple(out.type.dtype for out in node.outputs)
inplace_pattern = tuple(op.inplace_pattern.items()) inplace_pattern = tuple(op.inplace_pattern.items())
core_output_shapes = tuple(() for _ in range(nout)) core_output_shapes = tuple(() for _ in range(nout))
......
...@@ -85,9 +85,7 @@ def numba_funcify_ScalarOp(op, node, **kwargs): ...@@ -85,9 +85,7 @@ def numba_funcify_ScalarOp(op, node, **kwargs):
unique_names = unique_name_generator( unique_names = unique_name_generator(
[scalar_op_fn_name, "scalar_func_numba"], suffix_sep="_" [scalar_op_fn_name, "scalar_func_numba"], suffix_sep="_"
) )
input_names = ", ".join( input_names = ", ".join(unique_names(v, force_unique=True) for v in node.inputs)
[unique_names(v, force_unique=True) for v in node.inputs]
)
if not has_pyx_skip_dispatch: if not has_pyx_skip_dispatch:
scalar_op_src = f""" scalar_op_src = f"""
def {scalar_op_fn_name}({input_names}): def {scalar_op_fn_name}({input_names}):
...@@ -115,10 +113,8 @@ def {scalar_op_fn_name}({input_names}): ...@@ -115,10 +113,8 @@ def {scalar_op_fn_name}({input_names}):
input_names = [unique_names(v, force_unique=True) for v in node.inputs] input_names = [unique_names(v, force_unique=True) for v in node.inputs]
converted_call_args = ", ".join( converted_call_args = ", ".join(
[ f"direct_cast({i_name}, {i_tmp_dtype_name})"
f"direct_cast({i_name}, {i_tmp_dtype_name})" for i_name, i_tmp_dtype_name in zip(input_names, input_tmp_dtype_names)
for i_name, i_tmp_dtype_name in zip(input_names, input_tmp_dtype_names)
]
) )
if not has_pyx_skip_dispatch: if not has_pyx_skip_dispatch:
scalar_op_src = f""" scalar_op_src = f"""
......
...@@ -373,7 +373,7 @@ def numba_funcify_Scan(op, node, **kwargs): ...@@ -373,7 +373,7 @@ def numba_funcify_Scan(op, node, **kwargs):
inner_out_post_processing_block = "\n".join(inner_out_post_processing_stmts) inner_out_post_processing_block = "\n".join(inner_out_post_processing_stmts)
inner_out_to_outer_out_stmts = "\n".join( inner_out_to_outer_out_stmts = "\n".join(
[f"{s} = {d}" for s, d in zip(inner_out_to_outer_in_stmts, inner_output_names)] f"{s} = {d}" for s, d in zip(inner_out_to_outer_in_stmts, inner_output_names)
) )
scan_op_src = f""" scan_op_src = f"""
......
...@@ -35,10 +35,8 @@ def numba_funcify_AllocEmpty(op, node, **kwargs): ...@@ -35,10 +35,8 @@ def numba_funcify_AllocEmpty(op, node, **kwargs):
shape_var_item_names = [f"{name}_item" for name in shape_var_names] shape_var_item_names = [f"{name}_item" for name in shape_var_names]
shapes_to_items_src = indent( shapes_to_items_src = indent(
"\n".join( "\n".join(
[ f"{item_name} = to_scalar({shape_name})"
f"{item_name} = to_scalar({shape_name})" for item_name, shape_name in zip(shape_var_item_names, shape_var_names)
for item_name, shape_name in zip(shape_var_item_names, shape_var_names)
]
), ),
" " * 4, " " * 4,
) )
...@@ -69,10 +67,8 @@ def numba_funcify_Alloc(op, node, **kwargs): ...@@ -69,10 +67,8 @@ def numba_funcify_Alloc(op, node, **kwargs):
shape_var_item_names = [f"{name}_item" for name in shape_var_names] shape_var_item_names = [f"{name}_item" for name in shape_var_names]
shapes_to_items_src = indent( shapes_to_items_src = indent(
"\n".join( "\n".join(
[ f"{item_name} = to_scalar({shape_name})"
f"{item_name} = to_scalar({shape_name})" for item_name, shape_name in zip(shape_var_item_names, shape_var_names)
for item_name, shape_name in zip(shape_var_item_names, shape_var_names)
]
), ),
" " * 4, " " * 4,
) )
......
...@@ -43,10 +43,8 @@ def store_core_outputs(core_op_fn: Callable, nin: int, nout: int) -> Callable: ...@@ -43,10 +43,8 @@ def store_core_outputs(core_op_fn: Callable, nin: int, nout: int) -> Callable:
out_signature = ", ".join(outputs) out_signature = ", ".join(outputs)
inner_out_signature = ", ".join(inner_outputs) inner_out_signature = ", ".join(inner_outputs)
store_outputs = "\n".join( store_outputs = "\n".join(
[ f"{output}[...] = {inner_output}"
f"{output}[...] = {inner_output}" for output, inner_output in zip(outputs, inner_outputs)
for output, inner_output in zip(outputs, inner_outputs)
]
) )
func_src = f""" func_src = f"""
def store_core_outputs({inp_signature}, {out_signature}): def store_core_outputs({inp_signature}, {out_signature}):
......
...@@ -1112,7 +1112,7 @@ class VMLinker(LocalLinker): ...@@ -1112,7 +1112,7 @@ class VMLinker(LocalLinker):
for i, node in enumerate(nodes): for i, node in enumerate(nodes):
prereq_var_idxs = [] prereq_var_idxs = []
for prereq_node in ords.get(node, []): for prereq_node in ords.get(node, []):
prereq_var_idxs.extend([vars_idx[v] for v in prereq_node.outputs]) prereq_var_idxs.extend(vars_idx[v] for v in prereq_node.outputs)
prereq_var_idxs = list(set(prereq_var_idxs)) prereq_var_idxs = list(set(prereq_var_idxs))
prereq_var_idxs.sort() # TODO: why sort? prereq_var_idxs.sort() # TODO: why sort?
node_prereqs.append(prereq_var_idxs) node_prereqs.append(prereq_var_idxs)
...@@ -1323,9 +1323,7 @@ class VMLinker(LocalLinker): ...@@ -1323,9 +1323,7 @@ class VMLinker(LocalLinker):
def __repr__(self): def __repr__(self):
args_str = ", ".join( args_str = ", ".join(
[ f"{name}={getattr(self, name)}"
f"{name}={getattr(self, name)}" for name in ("use_cloop", "lazy", "allow_partial_eval", "allow_gc")
for name in ("use_cloop", "lazy", "allow_partial_eval", "allow_gc")
]
) )
return f"{type(self).__name__}({args_str})" return f"{type(self).__name__}({args_str})"
...@@ -9,10 +9,10 @@ from pytensor.configdefaults import config ...@@ -9,10 +9,10 @@ from pytensor.configdefaults import config
DISPLAY_DUPLICATE_KEYS = False DISPLAY_DUPLICATE_KEYS = False
DISPLAY_MOST_FREQUENT_DUPLICATE_CCODE = False DISPLAY_MOST_FREQUENT_DUPLICATE_CCODE = False
dirs = [] dirs: list = []
if len(sys.argv) > 1: if len(sys.argv) > 1:
for compiledir in sys.argv[1:]: for compiledir in sys.argv[1:]:
dirs.extend([os.path.join(compiledir, d) for d in os.listdir(compiledir)]) dirs.extend(os.path.join(compiledir, d) for d in os.listdir(compiledir))
else: else:
dirs = os.listdir(config.compiledir) dirs = os.listdir(config.compiledir)
dirs = [os.path.join(config.compiledir, d) for d in dirs] dirs = [os.path.join(config.compiledir, d) for d in dirs]
......
...@@ -229,32 +229,32 @@ def debugprint( ...@@ -229,32 +229,32 @@ def debugprint(
topo_orders.append(None) topo_orders.append(None)
elif isinstance(obj, Apply): elif isinstance(obj, Apply):
outputs_to_print.extend(obj.outputs) outputs_to_print.extend(obj.outputs)
profile_list.extend([None for item in obj.outputs]) profile_list.extend(None for item in obj.outputs)
storage_maps.extend([None for item in obj.outputs]) storage_maps.extend(None for item in obj.outputs)
topo_orders.extend([None for item in obj.outputs]) topo_orders.extend(None for item in obj.outputs)
elif isinstance(obj, Function): elif isinstance(obj, Function):
if print_fgraph_inputs: if print_fgraph_inputs:
inputs_to_print.extend(obj.maker.fgraph.inputs) inputs_to_print.extend(obj.maker.fgraph.inputs)
outputs_to_print.extend(obj.maker.fgraph.outputs) outputs_to_print.extend(obj.maker.fgraph.outputs)
profile_list.extend([obj.profile for item in obj.maker.fgraph.outputs]) profile_list.extend(obj.profile for item in obj.maker.fgraph.outputs)
if print_storage: if print_storage:
storage_maps.extend( storage_maps.extend(
[obj.vm.storage_map for item in obj.maker.fgraph.outputs] obj.vm.storage_map for item in obj.maker.fgraph.outputs
) )
else: else:
storage_maps.extend([None for item in obj.maker.fgraph.outputs]) storage_maps.extend(None for item in obj.maker.fgraph.outputs)
topo = obj.maker.fgraph.toposort() topo = obj.maker.fgraph.toposort()
topo_orders.extend([topo for item in obj.maker.fgraph.outputs]) topo_orders.extend(topo for item in obj.maker.fgraph.outputs)
elif isinstance(obj, FunctionGraph): elif isinstance(obj, FunctionGraph):
if print_fgraph_inputs: if print_fgraph_inputs:
inputs_to_print.extend(obj.inputs) inputs_to_print.extend(obj.inputs)
outputs_to_print.extend(obj.outputs) outputs_to_print.extend(obj.outputs)
profile_list.extend([getattr(obj, "profile", None) for item in obj.outputs]) profile_list.extend(getattr(obj, "profile", None) for item in obj.outputs)
storage_maps.extend( storage_maps.extend(
[getattr(obj, "storage_map", None) for item in obj.outputs] getattr(obj, "storage_map", None) for item in obj.outputs
) )
topo = obj.toposort() topo = obj.toposort()
topo_orders.extend([topo for item in obj.outputs]) topo_orders.extend(topo for item in obj.outputs)
elif isinstance(obj, int | float | np.ndarray): elif isinstance(obj, int | float | np.ndarray):
print(obj, file=_file) print(obj, file=_file)
elif isinstance(obj, In | Out): elif isinstance(obj, In | Out):
...@@ -980,10 +980,10 @@ class FunctionPrinter(Printer): ...@@ -980,10 +980,10 @@ class FunctionPrinter(Printer):
name = self.names[idx] name = self.names[idx]
with set_precedence(pstate): with set_precedence(pstate):
inputs_str = ", ".join( inputs_str = ", ".join(
[pprinter.process(input, pstate) for input in node.inputs] pprinter.process(input, pstate) for input in node.inputs
) )
keywords_str = ", ".join( keywords_str = ", ".join(
[f"{kw}={getattr(node.op, kw)}" for kw in self.keywords] f"{kw}={getattr(node.op, kw)}" for kw in self.keywords
) )
if keywords_str and inputs_str: if keywords_str and inputs_str:
...@@ -1050,7 +1050,7 @@ class DefaultPrinter(Printer): ...@@ -1050,7 +1050,7 @@ class DefaultPrinter(Printer):
with set_precedence(pstate): with set_precedence(pstate):
r = "{}({})".format( r = "{}({})".format(
str(node.op), str(node.op),
", ".join([pprinter.process(input, pstate) for input in node.inputs]), ", ".join(pprinter.process(input, pstate) for input in node.inputs),
) )
pstate.memo[output] = r pstate.memo[output] = r
......
...@@ -4224,8 +4224,8 @@ class Composite(ScalarInnerGraphOp): ...@@ -4224,8 +4224,8 @@ class Composite(ScalarInnerGraphOp):
inputs, outputs = res[0], res2[1] inputs, outputs = res[0], res2[1]
self.inputs, self.outputs = self._cleanup_graph(inputs, outputs) self.inputs, self.outputs = self._cleanup_graph(inputs, outputs)
self.inputs_type = tuple([input.type for input in self.inputs]) self.inputs_type = tuple(input.type for input in self.inputs)
self.outputs_type = tuple([output.type for output in self.outputs]) self.outputs_type = tuple(output.type for output in self.outputs)
self.nin = len(inputs) self.nin = len(inputs)
self.nout = len(outputs) self.nout = len(outputs)
super().__init__() super().__init__()
...@@ -4247,7 +4247,7 @@ class Composite(ScalarInnerGraphOp): ...@@ -4247,7 +4247,7 @@ class Composite(ScalarInnerGraphOp):
if len(self.fgraph.outputs) > 1 or len(self.fgraph.apply_nodes) > 10: if len(self.fgraph.outputs) > 1 or len(self.fgraph.apply_nodes) > 10:
self._name = "Composite{...}" self._name = "Composite{...}"
else: else:
outputs_str = ", ".join([pprint(output) for output in self.fgraph.outputs]) outputs_str = ", ".join(pprint(output) for output in self.fgraph.outputs)
self._name = f"Composite{{{outputs_str}}}" self._name = f"Composite{{{outputs_str}}}"
return self._name return self._name
...@@ -4295,7 +4295,7 @@ class Composite(ScalarInnerGraphOp): ...@@ -4295,7 +4295,7 @@ class Composite(ScalarInnerGraphOp):
return self.outputs_type return self.outputs_type
def make_node(self, *inputs): def make_node(self, *inputs):
if tuple([i.type for i in self.inputs]) == tuple([i.type for i in inputs]): if tuple(i.type for i in self.inputs) == tuple(i.type for i in inputs):
return super().make_node(*inputs) return super().make_node(*inputs)
else: else:
# Make a new op with the right input type. # Make a new op with the right input type.
......
...@@ -160,7 +160,7 @@ class ScalarLoop(ScalarInnerGraphOp): ...@@ -160,7 +160,7 @@ class ScalarLoop(ScalarInnerGraphOp):
f"Got {n_steps.type.dtype}", f"Got {n_steps.type.dtype}",
) )
if self.inputs_type == tuple([i.type for i in inputs]): if self.inputs_type == tuple(i.type for i in inputs):
return super().make_node(n_steps, *inputs) return super().make_node(n_steps, *inputs)
else: else:
# Make a new op with the right input types. # Make a new op with the right input types.
......
...@@ -1936,7 +1936,7 @@ class ScanMerge(GraphRewriter): ...@@ -1936,7 +1936,7 @@ class ScanMerge(GraphRewriter):
profile=old_op.profile, profile=old_op.profile,
truncate_gradient=old_op.truncate_gradient, truncate_gradient=old_op.truncate_gradient,
allow_gc=old_op.allow_gc, allow_gc=old_op.allow_gc,
name="&".join([nd.op.name for nd in nodes]), name="&".join(nd.op.name for nd in nodes),
) )
new_outs = new_op(*outer_ins) new_outs = new_op(*outer_ins)
......
...@@ -749,15 +749,13 @@ class ScanArgs: ...@@ -749,15 +749,13 @@ class ScanArgs:
def field_names(self): def field_names(self):
res = ["mit_mot_out_slices", "mit_mot_in_slices", "mit_sot_in_slices"] res = ["mit_mot_out_slices", "mit_mot_in_slices", "mit_sot_in_slices"]
res.extend( res.extend(
[ attr
attr for attr in self.__dict__
for attr in self.__dict__ if attr.startswith("inner_in")
if attr.startswith("inner_in") or attr.startswith("inner_out")
or attr.startswith("inner_out") or attr.startswith("outer_in")
or attr.startswith("outer_in") or attr.startswith("outer_out")
or attr.startswith("outer_out") or attr == "n_steps"
or attr == "n_steps"
]
) )
return res return res
......
...@@ -1554,7 +1554,7 @@ class Alloc(COp): ...@@ -1554,7 +1554,7 @@ class Alloc(COp):
def perform(self, node, inputs, out_): def perform(self, node, inputs, out_):
(out,) = out_ (out,) = out_
v = inputs[0] v = inputs[0]
sh = tuple([int(i) for i in inputs[1:]]) sh = tuple(int(i) for i in inputs[1:])
self._check_runtime_broadcast(node, v, sh) self._check_runtime_broadcast(node, v, sh)
if out[0] is None or out[0].shape != sh: if out[0] is None or out[0].shape != sh:
...@@ -4180,7 +4180,7 @@ class AllocEmpty(COp): ...@@ -4180,7 +4180,7 @@ class AllocEmpty(COp):
def perform(self, node, inputs, out_): def perform(self, node, inputs, out_):
(out,) = out_ (out,) = out_
sh = tuple([int(i) for i in inputs]) sh = tuple(int(i) for i in inputs)
if out[0] is None or out[0].shape != sh: if out[0] is None or out[0].shape != sh:
out[0] = np.empty(sh, dtype=self.dtype) out[0] = np.empty(sh, dtype=self.dtype)
......
...@@ -1691,7 +1691,7 @@ class BatchedDot(COp): ...@@ -1691,7 +1691,7 @@ class BatchedDot(COp):
if x.shape[0] != y.shape[0]: if x.shape[0] != y.shape[0]:
raise TypeError( raise TypeError(
f"Inputs [{', '.join(map(str, inp))}] must have the" f"Inputs [{', '.join(map(str, inp))}] must have the"
f" same size in axis 0, but have sizes [{', '.join([str(i.shape[0]) for i in inp])}]." f" same size in axis 0, but have sizes [{', '.join(str(i.shape[0]) for i in inp)}]."
) )
z[0] = np.matmul(x, y) z[0] = np.matmul(x, y)
......
...@@ -139,10 +139,8 @@ class Blockwise(Op): ...@@ -139,10 +139,8 @@ class Blockwise(Op):
try: try:
batch_shape = tuple( batch_shape = tuple(
[ broadcast_static_dim_lengths(batch_dims)
broadcast_static_dim_lengths(batch_dims) for batch_dims in zip(*batch_shapes)
for batch_dims in zip(*batch_shapes)
]
) )
except ValueError: except ValueError:
raise ValueError( raise ValueError(
......
...@@ -182,7 +182,7 @@ class DimShuffle(ExternalCOp): ...@@ -182,7 +182,7 @@ class DimShuffle(ExternalCOp):
self.transposition = self.shuffle + drop self.transposition = self.shuffle + drop
# List of dimensions of the output that are broadcastable and were not # List of dimensions of the output that are broadcastable and were not
# in the original input # in the original input
self.augment = sorted([i for i, x in enumerate(new_order) if x == "x"]) self.augment = sorted(i for i, x in enumerate(new_order) if x == "x")
self.drop = drop self.drop = drop
if self.inplace: if self.inplace:
...@@ -893,11 +893,9 @@ class Elemwise(OpenMPOp): ...@@ -893,11 +893,9 @@ class Elemwise(OpenMPOp):
# In that case, create a fortran output ndarray. # In that case, create a fortran output ndarray.
z = list(zip(inames, inputs)) z = list(zip(inames, inputs))
alloc_fortran = " && ".join( alloc_fortran = " && ".join(
[ f"PyArray_ISFORTRAN({arr})"
f"PyArray_ISFORTRAN({arr})" for arr, var in z
for arr, var in z if not all(s == 1 for s in var.type.shape)
if not all(s == 1 for s in var.type.shape)
]
) )
# If it is a scalar, make it c contig to prevent problem with # If it is a scalar, make it c contig to prevent problem with
# NumPy C and F contig not always set as both of them. # NumPy C and F contig not always set as both of them.
...@@ -984,12 +982,10 @@ class Elemwise(OpenMPOp): ...@@ -984,12 +982,10 @@ class Elemwise(OpenMPOp):
if len(all_code) == 1: if len(all_code) == 1:
# No loops # No loops
task_decl = "".join( task_decl = "".join(
[ f"{dtype}& {name}_i = *{name}_iter;\n"
f"{dtype}& {name}_i = *{name}_iter;\n" for name, dtype in zip(
for name, dtype in zip( inames + list(real_onames), idtypes + list(real_odtypes)
inames + list(real_onames), idtypes + list(real_odtypes) )
)
]
) )
preloops = {} preloops = {}
...@@ -1101,18 +1097,14 @@ class Elemwise(OpenMPOp): ...@@ -1101,18 +1097,14 @@ class Elemwise(OpenMPOp):
z = list(zip(inames + onames, inputs + node.outputs)) z = list(zip(inames + onames, inputs + node.outputs))
all_broadcastable = all(s == 1 for s in var.type.shape) all_broadcastable = all(s == 1 for s in var.type.shape)
cond1 = " && ".join( cond1 = " && ".join(
[ f"PyArray_ISCONTIGUOUS({arr})"
f"PyArray_ISCONTIGUOUS({arr})" for arr, var in z
for arr, var in z if not all_broadcastable
if not all_broadcastable
]
) )
cond2 = " && ".join( cond2 = " && ".join(
[ f"PyArray_ISFORTRAN({arr})"
f"PyArray_ISFORTRAN({arr})" for arr, var in z
for arr, var in z if not all_broadcastable
if not all_broadcastable
]
) )
loop = """ loop = """
if(({cond1}) || ({cond2})){{ if(({cond1}) || ({cond2})){{
......
...@@ -1248,7 +1248,7 @@ class Unique(Op): ...@@ -1248,7 +1248,7 @@ class Unique(Op):
f"Unique axis `{self.axis}` is outside of input ndim = {ndim}." f"Unique axis `{self.axis}` is outside of input ndim = {ndim}."
) )
ret[0] = tuple( ret[0] = tuple(
[fgraph.shape_feature.shape_ir(i, node.outputs[0]) for i in range(ndim)] fgraph.shape_feature.shape_ir(i, node.outputs[0]) for i in range(ndim)
) )
if self.return_inverse: if self.return_inverse:
if self.axis is None: if self.axis is None:
......
...@@ -258,11 +258,9 @@ class Argmax(COp): ...@@ -258,11 +258,9 @@ class Argmax(COp):
if self.axis is None: if self.axis is None:
return [()] return [()]
rval = tuple( rval = tuple(
[ ishape[i]
ishape[i] for (i, b) in enumerate(node.inputs[0].type.broadcastable)
for (i, b) in enumerate(node.inputs[0].type.broadcastable) if i not in self.axis
if i not in self.axis
]
) )
return [rval] return [rval]
......
...@@ -800,10 +800,8 @@ class Reshape(COp): ...@@ -800,10 +800,8 @@ class Reshape(COp):
rest_size = input_size // maximum(requ_size, 1) rest_size = input_size // maximum(requ_size, 1)
return [ return [
tuple( tuple(
[ ptb.switch(eq(requ[i], -1), rest_size, requ[i])
ptb.switch(eq(requ[i], -1), rest_size, requ[i]) for i in range(self.ndim)
for i in range(self.ndim)
]
) )
] ]
......
...@@ -879,7 +879,7 @@ class BaseBlockDiagonal(Op): ...@@ -879,7 +879,7 @@ class BaseBlockDiagonal(Op):
__props__ = ("n_inputs",) __props__ = ("n_inputs",)
def __init__(self, n_inputs): def __init__(self, n_inputs):
input_sig = ",".join([f"(m{i},n{i})" for i in range(n_inputs)]) input_sig = ",".join(f"(m{i},n{i})" for i in range(n_inputs))
self.gufunc_signature = f"{input_sig}->(m,n)" self.gufunc_signature = f"{input_sig}->(m,n)"
if n_inputs == 0: if n_inputs == 0:
......
...@@ -1113,7 +1113,7 @@ class Subtensor(COp): ...@@ -1113,7 +1113,7 @@ class Subtensor(COp):
if is_slice: if is_slice:
is_slice_init = ( is_slice_init = (
"int is_slice[] = {" + ",".join([str(s) for s in is_slice]) + "};" "int is_slice[] = {" + ",".join(str(s) for s in is_slice) + "};"
) )
else: else:
is_slice_init = "int* is_slice = NULL;" is_slice_init = "int* is_slice = NULL;"
...@@ -2401,9 +2401,7 @@ class AdvancedIncSubtensor1(COp): ...@@ -2401,9 +2401,7 @@ class AdvancedIncSubtensor1(COp):
fn_array = ( fn_array = (
"static inplace_map_binop addition_funcs[] = {" "static inplace_map_binop addition_funcs[] = {"
+ "".join( + "".join(gen_binop(type=t, typen=t.upper()) for t in types + complex_types)
[gen_binop(type=t, typen=t.upper()) for t in types + complex_types]
)
+ "NULL};\n" + "NULL};\n"
) )
...@@ -2416,7 +2414,7 @@ class AdvancedIncSubtensor1(COp): ...@@ -2416,7 +2414,7 @@ class AdvancedIncSubtensor1(COp):
type_number_array = ( type_number_array = (
"static int type_numbers[] = {" "static int type_numbers[] = {"
+ "".join([gen_num(typen=t.upper()) for t in types + complex_types]) + "".join(gen_num(typen=t.upper()) for t in types + complex_types)
+ "-1000};" + "-1000};"
) )
......
...@@ -401,7 +401,7 @@ class TensorType(CType[np.ndarray], HasDataType, HasShape): ...@@ -401,7 +401,7 @@ class TensorType(CType[np.ndarray], HasDataType, HasShape):
else: else:
return str(s) return str(s)
formatted_shape = ", ".join([shape_str(s) for s in shape]) formatted_shape = ", ".join(shape_str(s) for s in shape)
if len_shape == 1: if len_shape == 1:
formatted_shape += "," formatted_shape += ","
......
...@@ -521,12 +521,10 @@ class _tensor_py_operators: ...@@ -521,12 +521,10 @@ class _tensor_py_operators:
# Else leave it as is if it is a real number # Else leave it as is if it is a real number
# Convert python literals to pytensor constants # Convert python literals to pytensor constants
args = tuple( args = tuple(
[ pt.subtensor.as_index_constant(
pt.subtensor.as_index_constant( np.array(inp, dtype=np.uint8) if is_empty_array(inp) else inp
np.array(inp, dtype=np.uint8) if is_empty_array(inp) else inp )
) for inp in args
for inp in args
]
) )
# Determine if advanced indexing is needed or not. The logic is # Determine if advanced indexing is needed or not. The logic is
......
...@@ -3418,7 +3418,7 @@ class TestSumMeanMaxMinArgMaxVarReduceAxes: ...@@ -3418,7 +3418,7 @@ class TestSumMeanMaxMinArgMaxVarReduceAxes:
def reduce_bitwise_and(x, axis=-1, dtype="int8"): def reduce_bitwise_and(x, axis=-1, dtype="int8"):
identity = np.array((-1,), dtype=dtype)[0] identity = np.array((-1,), dtype=dtype)[0]
shape_without_axis = tuple([s for i, s in enumerate(x.shape) if i != axis]) shape_without_axis = tuple(s for i, s in enumerate(x.shape) if i != axis)
if 0 in shape_without_axis: if 0 in shape_without_axis:
return np.empty(shape=shape_without_axis, dtype=x.dtype) return np.empty(shape=shape_without_axis, dtype=x.dtype)
......
...@@ -667,7 +667,7 @@ def makeBroadcastTester(op, expected, checks=None, name=None, **kwargs): ...@@ -667,7 +667,7 @@ def makeBroadcastTester(op, expected, checks=None, name=None, **kwargs):
# For instance: sub_inplace -> SubInplace # For instance: sub_inplace -> SubInplace
capitalize = True capitalize = True
if capitalize: if capitalize:
name = "".join([x.capitalize() for x in name.split("_")]) name = "".join(x.capitalize() for x in name.split("_"))
# Some tests specify a name that already ends with 'Tester', while in other # Some tests specify a name that already ends with 'Tester', while in other
# cases we need to add it manually. # cases we need to add it manually.
if not name.endswith("Tester"): if not name.endswith("Tester"):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论