提交 c8839c8b authored 作者: Virgile Andreani's avatar Virgile Andreani 提交者: Ricardo Vieira

Fix RUF015

上级 d3dd34e7
......@@ -1038,7 +1038,7 @@ class ModuleCache:
_logger.info(f"deleting ModuleCache entry {entry}")
key_data.delete_keys_from(self.entry_from_key)
del self.module_hash_to_key_data[module_hash]
if key_data.keys and list(key_data.keys)[0][0]:
if key_data.keys and next(iter(key_data.keys))[0]:
# this is a versioned entry, so should have been on
# disk. Something weird happened to cause this, so we
# are responding by printing a warning, removing
......
......@@ -151,7 +151,7 @@ def broadcast_static_dim_lengths(
dim_lengths_set = set(dim_lengths)
# All dim_lengths are the same
if len(dim_lengths_set) == 1:
return tuple(dim_lengths_set)[0]
return next(iter(dim_lengths_set))
# Only valid indeterminate case
if dim_lengths_set == {None, 1}:
......@@ -161,7 +161,7 @@ def broadcast_static_dim_lengths(
dim_lengths_set.discard(None)
if len(dim_lengths_set) > 1:
raise ValueError
return tuple(dim_lengths_set)[0]
return next(iter(dim_lengths_set))
# Copied verbatim from numpy.lib.function_base
......
......@@ -275,7 +275,7 @@ def test_allow_gc_cvm():
f = function([v], v + 1, mode=mode)
f([1])
n = list(f.maker.fgraph.apply_nodes)[0].outputs[0]
n = next(iter(f.maker.fgraph.apply_nodes)).outputs[0]
assert f.vm.storage_map[n][0] is None
assert f.vm.allow_gc is True
......
......@@ -1630,7 +1630,7 @@ class TestScan:
# Also validate that the mappings outer_inp_from_outer_out and
# outer_inp_from_inner_inp produce the correct results
scan_node = list(updates.values())[0].owner
scan_node = next(iter(updates.values())).owner
var_mappings = scan_node.op.get_oinp_iinp_iout_oout_mappings()
result = var_mappings["outer_inp_from_outer_out"]
......@@ -1922,7 +1922,7 @@ class TestScan:
_, updates = scan(
inner_fn, n_steps=10, truncate_gradient=-1, go_backwards=False
)
cost = list(updates.values())[0]
cost = next(iter(updates.values()))
g_sh = grad(cost, shared_var)
fgrad = function([], g_sh)
assert fgrad() == 1
......
......@@ -270,7 +270,7 @@ class TestPushOutDot:
f = function([h0, W1, W2], o, mode=self.mode)
scan_node = [x for x in f.maker.fgraph.toposort() if isinstance(x.op, Scan)][0]
scan_node = next(x for x in f.maker.fgraph.toposort() if isinstance(x.op, Scan))
assert (
len(
[
......@@ -444,9 +444,9 @@ class TestPushOutNonSeqScan:
# Ensure that the optimization was performed correctly in f_opt
# The inner function of scan should have only one output and it should
# not be the result of a Dot
scan_node = [
scan_node = next(
node for node in f_opt.maker.fgraph.toposort() if isinstance(node.op, Scan)
][0]
)
assert len(scan_node.op.inner_outputs) == 1
assert not isinstance(scan_node.op.inner_outputs[0], Dot)
......@@ -488,9 +488,9 @@ class TestPushOutNonSeqScan:
# Ensure that the optimization was performed correctly in f_opt
# The inner function of scan should have only one output and it should
# not be the result of a Dot
scan_node = [
scan_node = next(
node for node in f_opt.maker.fgraph.toposort() if isinstance(node.op, Scan)
][0]
)
# NOTE: WHEN INFER_SHAPE IS RE-ENABLED, BELOW THE SCAN MUST
# HAVE ONLY 1 OUTPUT.
assert len(scan_node.op.inner_outputs) == 2
......@@ -536,9 +536,9 @@ class TestPushOutNonSeqScan:
# Ensure that the optimization was performed correctly in f_opt
# The inner function of scan should have only one output and it should
# not be the result of a Dot
scan_node = [
scan_node = next(
node for node in f_opt.maker.fgraph.toposort() if isinstance(node.op, Scan)
][0]
)
assert len(scan_node.op.inner_outputs) == 2
assert not isinstance(scan_node.op.inner_outputs[0], Dot)
......@@ -1639,7 +1639,7 @@ def test_alloc_inputs1():
)
f = function([h0, W1, W2], o, mode=get_default_mode().including("scan"))
scan_node = [x for x in f.maker.fgraph.toposort() if isinstance(x.op, Scan)][0]
scan_node = next(x for x in f.maker.fgraph.toposort() if isinstance(x.op, Scan))
assert (
len(
[
......@@ -1673,7 +1673,7 @@ def test_alloc_inputs2():
)
f = function([h0, W1, W2], o, mode=get_default_mode().including("scan"))
scan_node = [x for x in f.maker.fgraph.toposort() if isinstance(x.op, Scan)][0]
scan_node = next(x for x in f.maker.fgraph.toposort() if isinstance(x.op, Scan))
assert (
len(
......@@ -1709,7 +1709,7 @@ def test_alloc_inputs3():
# TODO FIXME: This result depends on unrelated rewrites in the "fast" mode.
f = function([_h0, _W1, _W2], o, mode="FAST_RUN")
scan_node = [x for x in f.maker.fgraph.toposort() if isinstance(x.op, Scan)][0]
scan_node = next(x for x in f.maker.fgraph.toposort() if isinstance(x.op, Scan))
assert len(scan_node.op.inner_inputs) == 1
......
......@@ -3274,7 +3274,7 @@ class TestLocalReduce:
order = f.maker.fgraph.toposort()
assert 1 == sum(isinstance(node.op, CAReduce) for node in order)
node = [node for node in order if isinstance(node.op, CAReduce)][0]
node = next(node for node in order if isinstance(node.op, CAReduce))
op = node.op
assert isinstance(op, CAReduce)
......
......@@ -75,7 +75,7 @@ def test_merge_with_weird_eq():
MergeOptimizer().rewrite(g)
assert len(g.apply_nodes) == 1
node = list(g.apply_nodes)[0]
node = next(iter(g.apply_nodes))
assert len(node.inputs) == 2
assert node.inputs[0] is node.inputs[1]
......@@ -87,6 +87,6 @@ def test_merge_with_weird_eq():
MergeOptimizer().rewrite(g)
assert len(g.apply_nodes) == 1
node = list(g.apply_nodes)[0]
node = next(iter(g.apply_nodes))
assert len(node.inputs) == 2
assert node.inputs[0] is node.inputs[1]
......@@ -465,7 +465,7 @@ class TestSpecifyShape(utt.InferShapeTester):
f(xval)
assert isinstance(
[n for n in f.maker.fgraph.toposort() if isinstance(n.op, SpecifyShape)][0]
next(n for n in f.maker.fgraph.toposort() if isinstance(n.op, SpecifyShape))
.inputs[0]
.type,
self.input_type,
......@@ -475,7 +475,7 @@ class TestSpecifyShape(utt.InferShapeTester):
xval = np.random.random((2, 3)).astype(config.floatX)
f = pytensor.function([x], specify_shape(x, 2, 3), mode=self.mode)
assert isinstance(
[n for n in f.maker.fgraph.toposort() if isinstance(n.op, SpecifyShape)][0]
next(n for n in f.maker.fgraph.toposort() if isinstance(n.op, SpecifyShape))
.inputs[0]
.type,
self.input_type,
......
......@@ -194,7 +194,7 @@ class TestIfelse(utt.OptimizationTestMixin):
f = function([c, x1, x2, y1, y2], z, mode=self.mode)
self.assertFunctionContains1(f, self.get_ifelse(2))
ifnode = [x for x in f.maker.fgraph.toposort() if isinstance(x.op, IfElse)][0]
ifnode = next(x for x in f.maker.fgraph.toposort() if isinstance(x.op, IfElse))
assert len(ifnode.outputs) == 2
rng = np.random.default_rng(utt.fetch_seed())
......@@ -369,7 +369,7 @@ class TestIfelse(utt.OptimizationTestMixin):
z = ifelse(c, (x, x), (y, y))
f = function([c, x, y], z)
ifnode = [n for n in f.maker.fgraph.toposort() if isinstance(n.op, IfElse)][0]
ifnode = next(n for n in f.maker.fgraph.toposort() if isinstance(n.op, IfElse))
assert len(ifnode.inputs) == 3
@pytest.mark.skip(reason="Optimization temporarily disabled")
......@@ -382,7 +382,7 @@ class TestIfelse(utt.OptimizationTestMixin):
z = ifelse(c, (x1, x1, x1, x2, x2), (y1, y1, y2, y2, y2))
f = function([c, x1, x2, y1, y2], z)
ifnode = [x for x in f.maker.fgraph.toposort() if isinstance(x.op, IfElse)][0]
ifnode = next(x for x in f.maker.fgraph.toposort() if isinstance(x.op, IfElse))
assert len(ifnode.outputs) == 3
@pytest.mark.skip(reason="Optimization temporarily disabled")
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论