提交 3fcf6369 authored 作者: Virgile Andreani's avatar Virgile Andreani 提交者: Ricardo Vieira

Add pyupgrade to pre-commit and apply it

上级 f254492b
......@@ -19,6 +19,11 @@ repos:
pytensor/tensor/var\.py|
)$
- id: check-merge-conflict
- repo: https://github.com/asottile/pyupgrade
rev: v3.3.1
hooks:
- id: pyupgrade
args: [--py38-plus]
- repo: https://github.com/psf/black
rev: 22.10.0
hooks:
......
......@@ -5,9 +5,7 @@ WRITEME
import logging
import warnings
from typing import Optional, Tuple, Union
from typing_extensions import Literal
from typing import Literal, Optional, Tuple, Union
from pytensor.compile.function.types import Supervisor
from pytensor.configdefaults import config
......
......@@ -8,6 +8,7 @@ from typing import (
Callable,
Dict,
List,
Literal,
Mapping,
MutableSequence,
Optional,
......@@ -18,7 +19,6 @@ from typing import (
)
import numpy as np
from typing_extensions import Literal
import pytensor
from pytensor.compile.ops import ViewOp
......
......@@ -7,6 +7,7 @@ from typing import (
Dict,
Iterable,
List,
Literal,
Optional,
Sequence,
Set,
......@@ -15,8 +16,6 @@ from typing import (
cast,
)
from typing_extensions import Literal
import pytensor
from pytensor.configdefaults import config
from pytensor.graph.basic import Apply, AtomicVariable, Variable, applys_between
......
......@@ -9,6 +9,7 @@ from typing import (
Dict,
List,
Optional,
Protocol,
Sequence,
Tuple,
TypeVar,
......@@ -16,8 +17,6 @@ from typing import (
cast,
)
from typing_extensions import Protocol
import pytensor
from pytensor.configdefaults import config
from pytensor.graph.basic import Apply, NoParams, Variable
......
......@@ -15,9 +15,7 @@ from functools import _compose_mro, partial, reduce # type: ignore
from itertools import chain
from typing import TYPE_CHECKING, Callable, Dict
from typing import Iterable as IterableType
from typing import List, Optional, Sequence, Tuple, Union, cast
from typing_extensions import Literal
from typing import List, Literal, Optional, Sequence, Tuple, Union, cast
import pytensor
from pytensor.configdefaults import config
......@@ -1185,7 +1183,7 @@ class OpToRewriterTracker:
matches.extend(match)
return matches
@functools.lru_cache()
@functools.lru_cache
def get_trackers(self, op: Op) -> List[NodeRewriter]:
"""Get all the rewrites applicable to `op`."""
return (
......
......@@ -19,7 +19,17 @@ import textwrap
import time
import warnings
from io import BytesIO, StringIO
from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Set, Tuple, cast
from typing import (
TYPE_CHECKING,
Callable,
Dict,
List,
Optional,
Protocol,
Set,
Tuple,
cast,
)
import numpy as np
from setuptools._distutils.sysconfig import (
......@@ -28,7 +38,6 @@ from setuptools._distutils.sysconfig import (
get_python_inc,
get_python_lib,
)
from typing_extensions import Protocol
# we will abuse the lockfile mechanism when reading and writing the registry
from pytensor.compile.compilelock import lock_ctx
......
from __future__ import annotations
from typing import Any, List, Optional, Tuple
from typing import Any
import numba
import numpy as np
......@@ -14,8 +14,8 @@ from numba.np import arrayobj
def compute_itershape(
ctx: BaseContext,
builder: ir.IRBuilder,
in_shapes: Tuple[ir.Instruction, ...],
broadcast_pattern: Tuple[Tuple[bool, ...], ...],
in_shapes: tuple[ir.Instruction, ...],
broadcast_pattern: tuple[tuple[bool, ...], ...],
):
one = ir.IntType(64)(1)
ndim = len(in_shapes[0])
......@@ -63,12 +63,12 @@ def compute_itershape(
def make_outputs(
ctx: numba.core.base.BaseContext,
builder: ir.IRBuilder,
iter_shape: Tuple[ir.Instruction, ...],
out_bc: Tuple[Tuple[bool, ...], ...],
dtypes: Tuple[Any, ...],
inplace: Tuple[Tuple[int, int], ...],
inputs: Tuple[Any, ...],
input_types: Tuple[Any, ...],
iter_shape: tuple[ir.Instruction, ...],
out_bc: tuple[tuple[bool, ...], ...],
dtypes: tuple[Any, ...],
inplace: tuple[tuple[int, int], ...],
inputs: tuple[Any, ...],
input_types: tuple[Any, ...],
):
arrays = []
ar_types: list[types.Array] = []
......@@ -106,13 +106,13 @@ def make_loop_call(
builder: ir.IRBuilder,
scalar_func: Any,
scalar_signature: types.FunctionType,
iter_shape: Tuple[ir.Instruction, ...],
inputs: Tuple[ir.Instruction, ...],
outputs: Tuple[ir.Instruction, ...],
input_bc: Tuple[Tuple[bool, ...], ...],
output_bc: Tuple[Tuple[bool, ...], ...],
input_types: Tuple[Any, ...],
output_types: Tuple[Any, ...],
iter_shape: tuple[ir.Instruction, ...],
inputs: tuple[ir.Instruction, ...],
outputs: tuple[ir.Instruction, ...],
input_bc: tuple[tuple[bool, ...], ...],
output_bc: tuple[tuple[bool, ...], ...],
input_types: tuple[Any, ...],
output_types: tuple[Any, ...],
):
safe = (False, False)
n_outputs = len(outputs)
......@@ -150,9 +150,7 @@ def make_loop_call(
# This part corresponds to opening the loops
loop_stack = []
loops = []
output_accumulator: List[Tuple[Optional[Any], Optional[int]]] = [
(None, None)
] * n_outputs
output_accumulator: list[tuple[Any | None, int | None]] = [(None, None)] * n_outputs
for dim, length in enumerate(iter_shape):
# Find outputs that only have accumulations left
for output in range(n_outputs):
......
......@@ -9,10 +9,20 @@ from contextlib import contextmanager
from copy import copy
from functools import reduce, singledispatch
from io import StringIO
from typing import Any, Callable, Dict, List, Optional, Sequence, TextIO, Tuple, Union
from typing import (
Any,
Callable,
Dict,
List,
Literal,
Optional,
Sequence,
TextIO,
Tuple,
Union,
)
import numpy as np
from typing_extensions import Literal
from pytensor.compile import Function, SharedVariable
from pytensor.compile.io import In, Out
......
from typing import Iterable, Optional, Union
from typing import Iterable, Literal, Optional, Union
import numpy as np
import scipy.sparse
from typing_extensions import Literal
import pytensor
from pytensor import scalar as aes
......
......@@ -143,7 +143,7 @@ def local_dimshuffle_rv_lift(fgraph, node):
# Check that Dimshuffle does not affect support dims
supp_dims = set(range(rv.ndim - rv_op.ndim_supp, rv.ndim))
shuffled_dims = {dim for i, dim in enumerate(ds_op.shuffle) if dim != i}
augmented_dims = set(d - rv_op.ndim_supp for d in ds_op.augment)
augmented_dims = {d - rv_op.ndim_supp for d in ds_op.augment}
if (shuffled_dims | augmented_dims) & supp_dims:
return False
......
......@@ -2,10 +2,9 @@ from collections.abc import Sequence
from functools import wraps
from itertools import zip_longest
from types import ModuleType
from typing import TYPE_CHECKING, Optional, Union
from typing import TYPE_CHECKING, Literal, Optional, Union
import numpy as np
from typing_extensions import Literal
from pytensor.compile.sharedvalue import shared
from pytensor.graph.basic import Constant, Variable
......
import logging
import warnings
from typing import TYPE_CHECKING, Union
from typing import TYPE_CHECKING, Literal, Union
import numpy as np
import scipy.linalg
from typing_extensions import Literal
import pytensor
import pytensor.tensor as pt
......
import logging
import warnings
from typing import TYPE_CHECKING, Iterable, Optional, Tuple, Union
from typing import TYPE_CHECKING, Iterable, Literal, Optional, Tuple, Union
import numpy as np
from typing_extensions import Literal
import pytensor
from pytensor import scalar as aes
......
......@@ -435,9 +435,9 @@ def test_inner_graph_optimized():
# Disable scan pushout, in which case the whole scan is replaced by an Elemwise
f = function([xs], seq, mode=get_mode("NUMBA").excluding("scan_pushout"))
(scan_node,) = [
(scan_node,) = (
node for node in f.maker.fgraph.apply_nodes if isinstance(node.op, Scan)
]
)
inner_scan_nodes = scan_node.op.fgraph.apply_nodes
assert len(inner_scan_nodes) == 1
(inner_scan_node,) = scan_node.op.fgraph.apply_nodes
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论