提交 c8839c8b authored 作者: Virgile Andreani's avatar Virgile Andreani 提交者: Ricardo Vieira

Fix RUF015

上级 d3dd34e7
...@@ -1038,7 +1038,7 @@ class ModuleCache: ...@@ -1038,7 +1038,7 @@ class ModuleCache:
_logger.info(f"deleting ModuleCache entry {entry}") _logger.info(f"deleting ModuleCache entry {entry}")
key_data.delete_keys_from(self.entry_from_key) key_data.delete_keys_from(self.entry_from_key)
del self.module_hash_to_key_data[module_hash] del self.module_hash_to_key_data[module_hash]
if key_data.keys and list(key_data.keys)[0][0]: if key_data.keys and next(iter(key_data.keys))[0]:
# this is a versioned entry, so should have been on # this is a versioned entry, so should have been on
# disk. Something weird happened to cause this, so we # disk. Something weird happened to cause this, so we
# are responding by printing a warning, removing # are responding by printing a warning, removing
......
...@@ -151,7 +151,7 @@ def broadcast_static_dim_lengths( ...@@ -151,7 +151,7 @@ def broadcast_static_dim_lengths(
dim_lengths_set = set(dim_lengths) dim_lengths_set = set(dim_lengths)
# All dim_lengths are the same # All dim_lengths are the same
if len(dim_lengths_set) == 1: if len(dim_lengths_set) == 1:
return tuple(dim_lengths_set)[0] return next(iter(dim_lengths_set))
# Only valid indeterminate case # Only valid indeterminate case
if dim_lengths_set == {None, 1}: if dim_lengths_set == {None, 1}:
...@@ -161,7 +161,7 @@ def broadcast_static_dim_lengths( ...@@ -161,7 +161,7 @@ def broadcast_static_dim_lengths(
dim_lengths_set.discard(None) dim_lengths_set.discard(None)
if len(dim_lengths_set) > 1: if len(dim_lengths_set) > 1:
raise ValueError raise ValueError
return tuple(dim_lengths_set)[0] return next(iter(dim_lengths_set))
# Copied verbatim from numpy.lib.function_base # Copied verbatim from numpy.lib.function_base
......
...@@ -275,7 +275,7 @@ def test_allow_gc_cvm(): ...@@ -275,7 +275,7 @@ def test_allow_gc_cvm():
f = function([v], v + 1, mode=mode) f = function([v], v + 1, mode=mode)
f([1]) f([1])
n = list(f.maker.fgraph.apply_nodes)[0].outputs[0] n = next(iter(f.maker.fgraph.apply_nodes)).outputs[0]
assert f.vm.storage_map[n][0] is None assert f.vm.storage_map[n][0] is None
assert f.vm.allow_gc is True assert f.vm.allow_gc is True
......
...@@ -1630,7 +1630,7 @@ class TestScan: ...@@ -1630,7 +1630,7 @@ class TestScan:
# Also validate that the mappings outer_inp_from_outer_out and # Also validate that the mappings outer_inp_from_outer_out and
# outer_inp_from_inner_inp produce the correct results # outer_inp_from_inner_inp produce the correct results
scan_node = list(updates.values())[0].owner scan_node = next(iter(updates.values())).owner
var_mappings = scan_node.op.get_oinp_iinp_iout_oout_mappings() var_mappings = scan_node.op.get_oinp_iinp_iout_oout_mappings()
result = var_mappings["outer_inp_from_outer_out"] result = var_mappings["outer_inp_from_outer_out"]
...@@ -1922,7 +1922,7 @@ class TestScan: ...@@ -1922,7 +1922,7 @@ class TestScan:
_, updates = scan( _, updates = scan(
inner_fn, n_steps=10, truncate_gradient=-1, go_backwards=False inner_fn, n_steps=10, truncate_gradient=-1, go_backwards=False
) )
cost = list(updates.values())[0] cost = next(iter(updates.values()))
g_sh = grad(cost, shared_var) g_sh = grad(cost, shared_var)
fgrad = function([], g_sh) fgrad = function([], g_sh)
assert fgrad() == 1 assert fgrad() == 1
......
...@@ -270,7 +270,7 @@ class TestPushOutDot: ...@@ -270,7 +270,7 @@ class TestPushOutDot:
f = function([h0, W1, W2], o, mode=self.mode) f = function([h0, W1, W2], o, mode=self.mode)
scan_node = [x for x in f.maker.fgraph.toposort() if isinstance(x.op, Scan)][0] scan_node = next(x for x in f.maker.fgraph.toposort() if isinstance(x.op, Scan))
assert ( assert (
len( len(
[ [
...@@ -444,9 +444,9 @@ class TestPushOutNonSeqScan: ...@@ -444,9 +444,9 @@ class TestPushOutNonSeqScan:
# Ensure that the optimization was performed correctly in f_opt # Ensure that the optimization was performed correctly in f_opt
# The inner function of scan should have only one output and it should # The inner function of scan should have only one output and it should
# not be the result of a Dot # not be the result of a Dot
scan_node = [ scan_node = next(
node for node in f_opt.maker.fgraph.toposort() if isinstance(node.op, Scan) node for node in f_opt.maker.fgraph.toposort() if isinstance(node.op, Scan)
][0] )
assert len(scan_node.op.inner_outputs) == 1 assert len(scan_node.op.inner_outputs) == 1
assert not isinstance(scan_node.op.inner_outputs[0], Dot) assert not isinstance(scan_node.op.inner_outputs[0], Dot)
...@@ -488,9 +488,9 @@ class TestPushOutNonSeqScan: ...@@ -488,9 +488,9 @@ class TestPushOutNonSeqScan:
# Ensure that the optimization was performed correctly in f_opt # Ensure that the optimization was performed correctly in f_opt
# The inner function of scan should have only one output and it should # The inner function of scan should have only one output and it should
# not be the result of a Dot # not be the result of a Dot
scan_node = [ scan_node = next(
node for node in f_opt.maker.fgraph.toposort() if isinstance(node.op, Scan) node for node in f_opt.maker.fgraph.toposort() if isinstance(node.op, Scan)
][0] )
# NOTE: WHEN INFER_SHAPE IS RE-ENABLED, BELOW THE SCAN MUST # NOTE: WHEN INFER_SHAPE IS RE-ENABLED, BELOW THE SCAN MUST
# HAVE ONLY 1 OUTPUT. # HAVE ONLY 1 OUTPUT.
assert len(scan_node.op.inner_outputs) == 2 assert len(scan_node.op.inner_outputs) == 2
...@@ -536,9 +536,9 @@ class TestPushOutNonSeqScan: ...@@ -536,9 +536,9 @@ class TestPushOutNonSeqScan:
# Ensure that the optimization was performed correctly in f_opt # Ensure that the optimization was performed correctly in f_opt
# The inner function of scan should have only one output and it should # The inner function of scan should have only one output and it should
# not be the result of a Dot # not be the result of a Dot
scan_node = [ scan_node = next(
node for node in f_opt.maker.fgraph.toposort() if isinstance(node.op, Scan) node for node in f_opt.maker.fgraph.toposort() if isinstance(node.op, Scan)
][0] )
assert len(scan_node.op.inner_outputs) == 2 assert len(scan_node.op.inner_outputs) == 2
assert not isinstance(scan_node.op.inner_outputs[0], Dot) assert not isinstance(scan_node.op.inner_outputs[0], Dot)
...@@ -1639,7 +1639,7 @@ def test_alloc_inputs1(): ...@@ -1639,7 +1639,7 @@ def test_alloc_inputs1():
) )
f = function([h0, W1, W2], o, mode=get_default_mode().including("scan")) f = function([h0, W1, W2], o, mode=get_default_mode().including("scan"))
scan_node = [x for x in f.maker.fgraph.toposort() if isinstance(x.op, Scan)][0] scan_node = next(x for x in f.maker.fgraph.toposort() if isinstance(x.op, Scan))
assert ( assert (
len( len(
[ [
...@@ -1673,7 +1673,7 @@ def test_alloc_inputs2(): ...@@ -1673,7 +1673,7 @@ def test_alloc_inputs2():
) )
f = function([h0, W1, W2], o, mode=get_default_mode().including("scan")) f = function([h0, W1, W2], o, mode=get_default_mode().including("scan"))
scan_node = [x for x in f.maker.fgraph.toposort() if isinstance(x.op, Scan)][0] scan_node = next(x for x in f.maker.fgraph.toposort() if isinstance(x.op, Scan))
assert ( assert (
len( len(
...@@ -1709,7 +1709,7 @@ def test_alloc_inputs3(): ...@@ -1709,7 +1709,7 @@ def test_alloc_inputs3():
# TODO FIXME: This result depends on unrelated rewrites in the "fast" mode. # TODO FIXME: This result depends on unrelated rewrites in the "fast" mode.
f = function([_h0, _W1, _W2], o, mode="FAST_RUN") f = function([_h0, _W1, _W2], o, mode="FAST_RUN")
scan_node = [x for x in f.maker.fgraph.toposort() if isinstance(x.op, Scan)][0] scan_node = next(x for x in f.maker.fgraph.toposort() if isinstance(x.op, Scan))
assert len(scan_node.op.inner_inputs) == 1 assert len(scan_node.op.inner_inputs) == 1
......
...@@ -3274,7 +3274,7 @@ class TestLocalReduce: ...@@ -3274,7 +3274,7 @@ class TestLocalReduce:
order = f.maker.fgraph.toposort() order = f.maker.fgraph.toposort()
assert 1 == sum(isinstance(node.op, CAReduce) for node in order) assert 1 == sum(isinstance(node.op, CAReduce) for node in order)
node = [node for node in order if isinstance(node.op, CAReduce)][0] node = next(node for node in order if isinstance(node.op, CAReduce))
op = node.op op = node.op
assert isinstance(op, CAReduce) assert isinstance(op, CAReduce)
......
...@@ -75,7 +75,7 @@ def test_merge_with_weird_eq(): ...@@ -75,7 +75,7 @@ def test_merge_with_weird_eq():
MergeOptimizer().rewrite(g) MergeOptimizer().rewrite(g)
assert len(g.apply_nodes) == 1 assert len(g.apply_nodes) == 1
node = list(g.apply_nodes)[0] node = next(iter(g.apply_nodes))
assert len(node.inputs) == 2 assert len(node.inputs) == 2
assert node.inputs[0] is node.inputs[1] assert node.inputs[0] is node.inputs[1]
...@@ -87,6 +87,6 @@ def test_merge_with_weird_eq(): ...@@ -87,6 +87,6 @@ def test_merge_with_weird_eq():
MergeOptimizer().rewrite(g) MergeOptimizer().rewrite(g)
assert len(g.apply_nodes) == 1 assert len(g.apply_nodes) == 1
node = list(g.apply_nodes)[0] node = next(iter(g.apply_nodes))
assert len(node.inputs) == 2 assert len(node.inputs) == 2
assert node.inputs[0] is node.inputs[1] assert node.inputs[0] is node.inputs[1]
...@@ -465,7 +465,7 @@ class TestSpecifyShape(utt.InferShapeTester): ...@@ -465,7 +465,7 @@ class TestSpecifyShape(utt.InferShapeTester):
f(xval) f(xval)
assert isinstance( assert isinstance(
[n for n in f.maker.fgraph.toposort() if isinstance(n.op, SpecifyShape)][0] next(n for n in f.maker.fgraph.toposort() if isinstance(n.op, SpecifyShape))
.inputs[0] .inputs[0]
.type, .type,
self.input_type, self.input_type,
...@@ -475,7 +475,7 @@ class TestSpecifyShape(utt.InferShapeTester): ...@@ -475,7 +475,7 @@ class TestSpecifyShape(utt.InferShapeTester):
xval = np.random.random((2, 3)).astype(config.floatX) xval = np.random.random((2, 3)).astype(config.floatX)
f = pytensor.function([x], specify_shape(x, 2, 3), mode=self.mode) f = pytensor.function([x], specify_shape(x, 2, 3), mode=self.mode)
assert isinstance( assert isinstance(
[n for n in f.maker.fgraph.toposort() if isinstance(n.op, SpecifyShape)][0] next(n for n in f.maker.fgraph.toposort() if isinstance(n.op, SpecifyShape))
.inputs[0] .inputs[0]
.type, .type,
self.input_type, self.input_type,
......
...@@ -194,7 +194,7 @@ class TestIfelse(utt.OptimizationTestMixin): ...@@ -194,7 +194,7 @@ class TestIfelse(utt.OptimizationTestMixin):
f = function([c, x1, x2, y1, y2], z, mode=self.mode) f = function([c, x1, x2, y1, y2], z, mode=self.mode)
self.assertFunctionContains1(f, self.get_ifelse(2)) self.assertFunctionContains1(f, self.get_ifelse(2))
ifnode = [x for x in f.maker.fgraph.toposort() if isinstance(x.op, IfElse)][0] ifnode = next(x for x in f.maker.fgraph.toposort() if isinstance(x.op, IfElse))
assert len(ifnode.outputs) == 2 assert len(ifnode.outputs) == 2
rng = np.random.default_rng(utt.fetch_seed()) rng = np.random.default_rng(utt.fetch_seed())
...@@ -369,7 +369,7 @@ class TestIfelse(utt.OptimizationTestMixin): ...@@ -369,7 +369,7 @@ class TestIfelse(utt.OptimizationTestMixin):
z = ifelse(c, (x, x), (y, y)) z = ifelse(c, (x, x), (y, y))
f = function([c, x, y], z) f = function([c, x, y], z)
ifnode = [n for n in f.maker.fgraph.toposort() if isinstance(n.op, IfElse)][0] ifnode = next(n for n in f.maker.fgraph.toposort() if isinstance(n.op, IfElse))
assert len(ifnode.inputs) == 3 assert len(ifnode.inputs) == 3
@pytest.mark.skip(reason="Optimization temporarily disabled") @pytest.mark.skip(reason="Optimization temporarily disabled")
...@@ -382,7 +382,7 @@ class TestIfelse(utt.OptimizationTestMixin): ...@@ -382,7 +382,7 @@ class TestIfelse(utt.OptimizationTestMixin):
z = ifelse(c, (x1, x1, x1, x2, x2), (y1, y1, y2, y2, y2)) z = ifelse(c, (x1, x1, x1, x2, x2), (y1, y1, y2, y2, y2))
f = function([c, x1, x2, y1, y2], z) f = function([c, x1, x2, y1, y2], z)
ifnode = [x for x in f.maker.fgraph.toposort() if isinstance(x.op, IfElse)][0] ifnode = next(x for x in f.maker.fgraph.toposort() if isinstance(x.op, IfElse))
assert len(ifnode.outputs) == 3 assert len(ifnode.outputs) == 3
@pytest.mark.skip(reason="Optimization temporarily disabled") @pytest.mark.skip(reason="Optimization temporarily disabled")
......
Markdown 格式
0%
您将 0 人添加到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论