提交 8ae2a195 authored 作者: Virgile Andreani's avatar Virgile Andreani 提交者: Virgile Andreani

Remove dict.keys() when unnecessary

上级 63da6d16
......@@ -659,7 +659,7 @@ class Function:
exist_svs = [i.variable for i in maker.inputs]
# Check if given ShareVariables exist
for sv in swap.keys():
for sv in swap:
if sv not in exist_svs:
raise ValueError(f"SharedVariable: {sv.name} not found")
......@@ -711,9 +711,9 @@ class Function:
# it is well tested, we don't share the part of the storage_map.
if share_memory:
i_o_vars = maker.fgraph.inputs + maker.fgraph.outputs
for key in storage_map.keys():
for key, val in storage_map.items():
if key not in i_o_vars:
new_storage_map[memo[key]] = storage_map[key]
new_storage_map[memo[key]] = val
if not name and self.name:
name = self.name + " copy"
......@@ -1446,7 +1446,7 @@ class FunctionMaker:
if not hasattr(mode.linker, "accept"):
raise ValueError(
"'linker' parameter of FunctionMaker should be "
f"a Linker with an accept method or one of {list(pytensor.compile.mode.predefined_linkers.keys())}"
f"a Linker with an accept method or one of {list(pytensor.compile.mode.predefined_linkers)}"
)
def __init__(
......
......@@ -1446,7 +1446,7 @@ class ProfileStats:
file=file,
)
if config.profiling__debugprint:
fcts = {fgraph for (fgraph, n) in self.apply_time.keys()}
fcts = {fgraph for (fgraph, n) in self.apply_time}
pytensor.printing.debugprint(fcts, print_type=True)
if self.variable_shape or self.variable_strides:
self.summary_memory(file, n_apply_to_print)
......
......@@ -1318,7 +1318,7 @@ def add_caching_dir_configvars():
_compiledir_format_dict["short_platform"] = short_platform()
# Allow to have easily one compiledir per device.
_compiledir_format_dict["device"] = config.device
compiledir_format_keys = ", ".join(sorted(_compiledir_format_dict.keys()))
compiledir_format_keys = ", ".join(sorted(_compiledir_format_dict))
_default_compiledir_format = (
"compiledir_%(short_platform)s-%(processor)s-"
"%(python_version)s-%(python_bitwidth)s"
......
......@@ -214,7 +214,7 @@ class PyTensorConfigParser:
return _ChangeFlagsDecorator(*args, _root=self, **kwargs)
def warn_unused_flags(self):
for key in self._flags_dict.keys():
for key in self._flags_dict:
warnings.warn(f"PyTensor does not recognise this flag: {key}")
......
......@@ -500,7 +500,7 @@ def grad(
if cost is not None:
outputs.append(cost)
if known_grads is not None:
outputs.extend(list(known_grads.keys()))
outputs.extend(list(known_grads))
var_to_app_to_idx = _populate_var_to_app_to_idx(outputs, _wrt, consider_constant)
......@@ -966,7 +966,7 @@ def _populate_var_to_app_to_idx(outputs, wrt, consider_constant):
visit(elem)
# Remove variables that don't have wrt as a true ancestor
orig_vars = list(var_to_app_to_idx.keys())
orig_vars = list(var_to_app_to_idx)
for var in orig_vars:
if var not in visited:
del var_to_app_to_idx[var]
......
......@@ -631,7 +631,7 @@ class Variable(Node, Generic[_TypeType, OptionalApplyType]):
if not hasattr(self, "_fn_cache"):
self._fn_cache: dict = dict()
inputs = tuple(sorted(parsed_inputs_to_values.keys(), key=id))
inputs = tuple(sorted(parsed_inputs_to_values, key=id))
cache_key = (inputs, tuple(kwargs.items()))
try:
fn = self._fn_cache[cache_key]
......
......@@ -406,7 +406,7 @@ class DestroyHandler(Bookkeeper):
# If True means that the apply node, destroys the protected_var.
if idx in [dmap for sublist in destroy_maps for dmap in sublist]:
return True
for var_idx in app.op.view_map.keys():
for var_idx in app.op.view_map:
if idx in app.op.view_map[var_idx]:
# We need to recursively check the destroy_map of all the
# outputs that we have a view_map on.
......
......@@ -15,7 +15,7 @@ from collections.abc import Callable, Iterable, Sequence
from collections.abc import Iterable as IterableType
from functools import _compose_mro, partial, reduce # type: ignore
from itertools import chain
from typing import TYPE_CHECKING, Literal, cast
from typing import TYPE_CHECKING, Literal
import pytensor
from pytensor.configdefaults import config
......@@ -1924,9 +1924,9 @@ class NodeProcessingGraphRewriter(GraphRewriter):
remove: list[Variable] = []
if isinstance(replacements, dict):
if "remove" in replacements:
remove = list(cast(Sequence[Variable], replacements.pop("remove")))
old_vars = list(cast(Sequence[Variable], replacements.keys()))
replacements = list(cast(Sequence[Variable], replacements.values()))
remove = list(replacements.pop("remove"))
old_vars = list(replacements)
replacements = list(replacements.values())
elif not isinstance(replacements, tuple | list):
raise TypeError(
f"Node rewriter {node_rewriter} gave wrong type of replacement. "
......
......@@ -168,7 +168,7 @@ class MissingInputError(Exception):
def __init__(self, *args, **kwargs):
if kwargs:
# The call to list is needed for Python 3
assert list(kwargs.keys()) == ["variable"]
assert list(kwargs) == ["variable"]
error_msg = get_variable_trace_string(kwargs["variable"])
if error_msg:
args = (*args, error_msg)
......
......@@ -264,10 +264,7 @@ class Params(dict):
def __repr__(self):
return "Params({})".format(
", ".join(
[
(f"{k}:{type(self[k]).__name__}:{self[k]}")
for k in sorted(self.keys())
]
[(f"{k}:{type(self[k]).__name__}:{self[k]}") for k in sorted(self)]
)
)
......@@ -365,7 +362,7 @@ class ParamsType(CType):
)
self.length = len(kwargs)
self.fields = tuple(sorted(kwargs.keys()))
self.fields = tuple(sorted(kwargs))
self.types = tuple(kwargs[field] for field in self.fields)
self.name = self.generate_struct_name()
......
......@@ -472,7 +472,7 @@ class EnumType(CType, dict):
"""
Return the sorted tuple of all aliases in this enumeration.
"""
return tuple(sorted(self.aliases.keys()))
return tuple(sorted(self.aliases))
def __repr__(self):
names_to_aliases = {constant_name: "" for constant_name in self}
......@@ -481,9 +481,7 @@ class EnumType(CType, dict):
return "{}<{}>({})".format(
type(self).__name__,
self.ctype,
", ".join(
f"{k}{names_to_aliases[k]}:{self[k]}" for k in sorted(self.keys())
),
", ".join(f"{k}{names_to_aliases[k]}:{self[k]}" for k in sorted(self)),
)
def __getattr__(self, key):
......@@ -612,7 +610,7 @@ class EnumType(CType, dict):
f"""
#define {k} {self[k]!s}
"""
for k in sorted(self.keys())
for k in sorted(self)
)
+ self.c_to_string()
)
......@@ -772,7 +770,7 @@ class CEnumType(EnumList):
case %(i)d: %(name)s = %(constant_cname)s; break;
"""
% dict(i=i, name=name, constant_cname=swapped_dict[i])
for i in sorted(swapped_dict.keys())
for i in sorted(swapped_dict)
),
fail=sub["fail"],
)
......
......@@ -117,9 +117,7 @@ def {scalar_op_fn_name}({input_names}):
converted_call_args = ", ".join(
[
f"direct_cast({i_name}, {i_tmp_dtype_name})"
for i_name, i_tmp_dtype_name in zip(
input_names, input_tmp_dtype_names.keys()
)
for i_name, i_tmp_dtype_name in zip(input_names, input_tmp_dtype_names)
]
)
if not has_pyx_skip_dispatch:
......
......@@ -70,7 +70,7 @@ def numba_funcify_Scan(op, node, **kwargs):
outer_in_names_to_vars = {
(f"outer_in_{i}" if i > 0 else "n_steps"): v for i, v in enumerate(node.inputs)
}
outer_in_names = list(outer_in_names_to_vars.keys())
outer_in_names = list(outer_in_names_to_vars)
outer_in_seqs_names = op.outer_seqs(outer_in_names)
outer_in_mit_mot_names = op.outer_mitmot(outer_in_names)
outer_in_mit_sot_names = op.outer_mitsot(outer_in_names)
......
......@@ -990,7 +990,7 @@ class VMLinker(LocalLinker):
for pair in reallocated_info.values():
storage_map[pair[1]] = storage_map[pair[0]]
return tuple(reallocated_info.keys())
return tuple(reallocated_info)
def make_vm(
self,
......
......@@ -1106,9 +1106,7 @@ class PPrinter(Printer):
outputs = [outputs]
current = None
if display_inputs:
strings = [
(0, "inputs: " + ", ".join(map(str, list(inputs) + updates.keys())))
]
strings = [(0, "inputs: " + ", ".join(str(x) for x in [*inputs, *updates]))]
else:
strings = []
pprinter = self.clone_assign(
......@@ -1116,9 +1114,7 @@ class PPrinter(Printer):
)
inv_updates = {b: a for (a, b) in updates.items()}
i = 1
for node in io_toposort(
list(inputs) + updates.keys(), list(outputs) + updates.values()
):
for node in io_toposort([*inputs, *updates], [*outputs, *updates.values()]):
for output in node.outputs:
if output in inv_updates:
name = str(inv_updates[output])
......
......@@ -4426,7 +4426,7 @@ class Compositef32:
else:
ni = i
mapping[i] = ni
if isinstance(node.op, tuple(self.special.keys())):
if isinstance(node.op, tuple(self.special)):
self.special[type(node.op)](node, mapping)
continue
new_node = node.clone_with_new_inputs(
......
......@@ -1284,14 +1284,14 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
def __str__(self):
inplace = "none"
if len(self.destroy_map.keys()) > 0:
if self.destroy_map:
# Check if all outputs are inplace
if sorted(self.destroy_map.keys()) == sorted(
if sorted(self.destroy_map) == sorted(
range(self.info.n_mit_mot + self.info.n_mit_sot + self.info.n_sit_sot)
):
inplace = "all"
else:
inplace = str(list(self.destroy_map.keys()))
inplace = str(list(self.destroy_map))
return (
f"Scan{{{self.name}, while_loop={self.info.as_while}, inplace={inplace}}}"
)
......
......@@ -269,7 +269,7 @@ class Validator:
# Mapping from invalid variables to equivalent valid ones.
self.valid_equivalent = valid_equivalent.copy()
self.valid.update(list(valid_equivalent.values()))
self.invalid.update(list(valid_equivalent.keys()))
self.invalid.update(list(valid_equivalent))
def check(self, out):
"""
......
......@@ -524,7 +524,7 @@ csc_fmatrix = SparseTensorType(format="csc", dtype="float32")
csr_fmatrix = SparseTensorType(format="csr", dtype="float32")
bsr_fmatrix = SparseTensorType(format="bsr", dtype="float32")
all_dtypes = list(SparseTensorType.dtype_specs_map.keys())
all_dtypes = list(SparseTensorType.dtype_specs_map)
complex_dtypes = [t for t in all_dtypes if t[:7] == "complex"]
float_dtypes = [t for t in all_dtypes if t[:5] == "float"]
int_dtypes = [t for t in all_dtypes if t[:3] == "int"]
......
......@@ -71,7 +71,7 @@ class InplaceElemwiseOptimizer(GraphRewriter):
ndim = prof["ndim"]
if ndim:
print(blanc, "ndim", "nb", file=stream)
for n in sorted(ndim.keys()):
for n in sorted(ndim):
print(blanc, n, ndim[n], file=stream)
def candidate_input_idxs(self, node):
......
......@@ -88,7 +88,7 @@ def shape_of_variables(
compute_shapes = pytensor.function(input_dims, output_dims)
if any(i not in fgraph.inputs for i in input_shapes.keys()):
if any(i not in fgraph.inputs for i in input_shapes):
raise ValueError(
"input_shapes keys aren't in the fgraph.inputs. FunctionGraph()"
" interface changed. Now by default, it clones the graph it receives."
......
......@@ -889,9 +889,9 @@ class TestPicklefunction:
return
else:
raise
# if they both return, assume that they return equivalent things.
# print [(k,id(k)) for k in f.finder.keys()]
# print [(k,id(k)) for k in g.finder.keys()]
# if they both return, assume that they return equivalent things.
# print [(k, id(k)) for k in f.finder]
# print [(k, id(k)) for k in g.finder]
assert g.container[0].storage is not f.container[0].storage
assert g.container[1].storage is not f.container[1].storage
......@@ -1012,9 +1012,9 @@ class TestPicklefunction:
return
else:
raise
# if they both return, assume that they return equivalent things.
# print [(k,id(k)) for k in f.finder.keys()]
# print [(k,id(k)) for k in g.finder.keys()]
# if they both return, assume that they return equivalent things.
# print [(k, id(k)) for k in f.finder]
# print [(k, id(k)) for k in g.finder]
assert g.container[0].storage is not f.container[0].storage
assert g.container[1].storage is not f.container[1].storage
......
......@@ -829,7 +829,7 @@ def test_config_options_fastmath():
with config.change_flags(numba__fastmath=True):
pytensor_numba_fn = function([x], pt.sum(x), mode=numba_mode)
print(list(pytensor_numba_fn.vm.jit_fn.py_func.__globals__.keys()))
print(list(pytensor_numba_fn.vm.jit_fn.py_func.__globals__))
numba_mul_fn = pytensor_numba_fn.vm.jit_fn.py_func.__globals__["impl_sum"]
assert numba_mul_fn.targetoptions["fastmath"] is True
......
......@@ -479,14 +479,14 @@ def test_vector_taps_benchmark(benchmark):
sitsot_init: rng.normal(),
}
numba_fn = pytensor.function(list(test.keys()), outs, mode=get_mode("NUMBA"))
numba_fn = pytensor.function(list(test), outs, mode=get_mode("NUMBA"))
scan_nodes = [
node for node in numba_fn.maker.fgraph.apply_nodes if isinstance(node.op, Scan)
]
assert len(scan_nodes) == 1
numba_res = numba_fn(*test.values())
ref_fn = pytensor.function(list(test.keys()), outs, mode=get_mode("FAST_COMPILE"))
ref_fn = pytensor.function(list(test), outs, mode=get_mode("FAST_COMPILE"))
ref_res = ref_fn(*test.values())
for numba_r, ref_r in zip(numba_res, ref_res):
np.testing.assert_array_almost_equal(numba_r, ref_r)
......
......@@ -57,7 +57,7 @@ def test_fgraph_to_python_names():
"scalar_variable",
"tensor_variable_1",
r.name,
) == tuple(sig.parameters.keys())
) == tuple(sig.parameters)
assert (1, 2, 3, 4, 5) == out_jx(1, 2, 3, 4, 5)
obj = object()
......
......@@ -337,7 +337,7 @@ def test_reallocation():
def check_storage(storage_map):
for i in storage_map:
if not isinstance(i, TensorConstant):
keys_copy = list(storage_map.keys())[:]
keys_copy = list(storage_map)[:]
keys_copy.remove(i)
for o in keys_copy:
if storage_map[i][0] and storage_map[i][0] is storage_map[o][0]:
......
......@@ -1097,8 +1097,8 @@ class TestScanInplaceOptimizer:
allow_input_downcast=True,
)
scan_node = [x for x in f9.maker.fgraph.toposort() if isinstance(x.op, Scan)]
assert 0 in scan_node[0].op.destroy_map.keys()
assert 1 in scan_node[0].op.destroy_map.keys()
assert 0 in scan_node[0].op.destroy_map
assert 1 in scan_node[0].op.destroy_map
# compute output in numpy
numpy_x0 = np.zeros((3,))
numpy_x1 = np.zeros((3,))
......@@ -1163,8 +1163,8 @@ class TestScanInplaceOptimizer:
)
scan_node = [x for x in f9.maker.fgraph.toposort() if isinstance(x.op, Scan)]
assert 0 in scan_node[0].op.destroy_map.keys()
assert 1 in scan_node[0].op.destroy_map.keys()
assert 0 in scan_node[0].op.destroy_map
assert 1 in scan_node[0].op.destroy_map
# compute output in numpy
numpy_x0 = np.zeros((3,))
numpy_x1 = np.zeros((3,))
......@@ -1203,8 +1203,8 @@ class TestScanInplaceOptimizer:
f9 = function([], outputs, updates=updates, mode=self.mode)
scan_node = [x for x in f9.maker.fgraph.toposort() if isinstance(x.op, Scan)]
assert 0 not in scan_node[0].op.destroy_map.keys()
assert 1 in scan_node[0].op.destroy_map.keys()
assert 0 not in scan_node[0].op.destroy_map
assert 1 in scan_node[0].op.destroy_map
class TestSaveMem:
......
......@@ -222,7 +222,7 @@ def test_config_pickling():
buffer.seek(0)
restored = pickle.load(buffer)
# ...without a change in the config values
for name in root._config_var_dict.keys():
for name in root._config_var_dict:
v_original = getattr(root, name)
v_restored = getattr(restored, name)
assert (
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论