提交 d8c5af89 authored 作者: Virgile Andreani's avatar Virgile Andreani 提交者: Ricardo Vieira

Upgrade to ruff 0.2.0 and fix RUF017

* --show-source -> --output-format=full * renaming of some config options * removing --line-length because it is already in the pyproject file * taking care of some list quadratic summations
上级 4801eff3
......@@ -20,12 +20,11 @@ repos:
)$
- id: check-merge-conflict
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.1.14
rev: v0.2.0
hooks:
- id: ruff
args: ["--fix", "--show-source"]
args: ["--fix", "--output-format=full"]
- id: ruff-format
args: ["--line-length=88"]
- repo: https://github.com/pre-commit/mirrors-mypy
rev: v1.8.0
hooks:
......
......@@ -121,23 +121,19 @@ tag_prefix = "rel-"
addopts = "--durations=50"
testpaths = "tests/"
[tool.pylint]
max-line-length = 88
[tool.pylint.messages_control]
disable = ["C0330", "C0326"]
[tool.ruff]
line-length = 88
exclude = ["doc/", "pytensor/_version.py", "bin/pytensor_cache.py"]
[tool.ruff.lint]
select = ["C", "E", "F", "I", "UP", "W", "RUF"]
ignore = ["C408", "C901", "E501", "E741", "RUF012"]
exclude = ["doc/", "pytensor/_version.py", "bin/pytensor_cache.py"]
[tool.ruff.isort]
[tool.ruff.lint.isort]
lines-after-imports = 2
[tool.ruff.per-file-ignores]
[tool.ruff.lint.per-file-ignores]
# TODO: Get rid of these:
"**/__init__.py" = ["F401", "E402", "F403"]
"pytensor/tensor/linalg.py" = ["F403"]
......
......@@ -233,10 +233,11 @@ def fast_inplace_check(fgraph, inputs):
"""
Supervisor = pytensor.compile.function.types.Supervisor
protected_inputs = [
protected_inputs = list(
itertools.chain.from_iterable(
f.protected for f in fgraph._features if isinstance(f, Supervisor)
]
protected_inputs = sum(protected_inputs, []) # flatten the list
)
)
protected_inputs.extend(fgraph.outputs)
inputs = [
......
......@@ -4096,9 +4096,10 @@ class ScalarInnerGraphOp(ScalarOp, HasInnerGraph):
return tuple(rval)
def c_header_dirs(self, **kwargs):
rval = sum(
(subnode.op.c_header_dirs(**kwargs) for subnode in self.fgraph.toposort()),
[],
rval = list(
chain.from_iterable(
subnode.op.c_header_dirs(**kwargs) for subnode in self.fgraph.toposort()
)
)
return rval
......
......@@ -1067,7 +1067,9 @@ class ScanInplaceOptimizer(GraphRewriter):
if client.op.destroy_map:
# This flattens the content of destroy_map.values()
# which is a list of lists
inplace_inp_indices = sum(client.op.destroy_map.values(), [])
inplace_inp_indices = chain.from_iterable(
client.op.destroy_map.values()
)
inplace_inps = [client.inputs[i] for i in inplace_inp_indices]
if original_node.inputs[inp_idx] in inplace_inps:
......@@ -1860,8 +1862,8 @@ class ScanMerge(GraphRewriter):
# Clone the inner graph of each node independently
for idx, nd in enumerate(nodes):
# concatenate all inner_ins and inner_outs of nd
flat_inner_ins = sum(inner_ins[idx], [])
flat_inner_outs = sum(inner_outs[idx], [])
flat_inner_ins = list(chain.from_iterable(inner_ins[idx]))
flat_inner_outs = list(chain.from_iterable(inner_outs[idx]))
# clone
flat_inner_ins, flat_inner_outs = reconstruct_graph(
flat_inner_ins, flat_inner_outs
......
......@@ -765,8 +765,8 @@ class ScanArgs:
def inner_inputs(self):
return (
self.inner_in_seqs
+ sum(self.inner_in_mit_mot, [])
+ sum(self.inner_in_mit_sot, [])
+ list(chain.from_iterable(self.inner_in_mit_mot))
+ list(chain.from_iterable(self.inner_in_mit_sot))
+ self.inner_in_sit_sot
+ self.inner_in_shared
+ self.inner_in_non_seqs
......@@ -788,7 +788,7 @@ class ScanArgs:
@property
def inner_outputs(self):
return (
sum(self.inner_out_mit_mot, [])
list(chain.from_iterable(self.inner_out_mit_mot))
+ self.inner_out_mit_sot
+ self.inner_out_sit_sot
+ self.inner_out_nit_sot
......
import itertools
import sys
from collections import defaultdict, deque
from collections.abc import Generator
......@@ -144,12 +145,12 @@ class InplaceElemwiseOptimizer(GraphRewriter):
else:
update_outs = []
protected_inputs = [
f.protected
for f in fgraph._features
if isinstance(f, pytensor.compile.function.types.Supervisor)
]
protected_inputs = sum(protected_inputs, []) # flatten the list
Supervisor = pytensor.compile.function.types.Supervisor
protected_inputs = list(
itertools.chain.from_iterable(
f.protected for f in fgraph._features if isinstance(f, Supervisor)
)
)
protected_inputs.extend(fgraph.outputs)
for node in list(io_toposort(fgraph.inputs, fgraph.outputs)):
op = node.op
......
import itertools
import pickle
import numpy as np
......@@ -574,7 +575,8 @@ class TestFunctionGraph:
assert fg.outputs == [op1_out]
assert op3_out not in fg.clients
assert not any(
op3_out.owner in clients for clients in sum(fg.clients.values(), [])
op3_out.owner in clients
for clients in itertools.chain.from_iterable(fg.clients.values())
)
# Remove an input
......@@ -585,7 +587,8 @@ class TestFunctionGraph:
assert fg.inputs == [var2]
assert fg.outputs == []
assert not any(
op1_out.owner in clients for clients in sum(fg.clients.values(), [])
op1_out.owner in clients
for clients in itertools.chain.from_iterable(fg.clients.values())
)
def test_remove_duplicates(self):
......@@ -622,10 +625,12 @@ class TestFunctionGraph:
assert not fg.apply_nodes
assert op1_out not in fg.clients
assert not any(
op1_out.owner in clients for clients in sum(fg.clients.values(), [])
op1_out.owner in clients
for clients in itertools.chain.from_iterable(fg.clients.values())
)
assert not any(
op3_out.owner in clients for clients in sum(fg.clients.values(), [])
op3_out.owner in clients
for clients in itertools.chain.from_iterable(fg.clients.values())
)
def test_remove_node_multi_out(self):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论