提交 79ff97a5 authored 作者: Virgile Andreani's avatar Virgile Andreani 提交者: Ricardo Vieira

Fix UP007 (Typealias Unions)

上级 a76172ee
......@@ -26,7 +26,7 @@ from pytensor.misc.ordered_set import OrderedSet
if TYPE_CHECKING:
from pytensor.graph.op import Op
ApplyOrOutput = Union[Apply, Literal["output"]]
ApplyOrOutput = Apply | Literal["output"]
ClientType = tuple[ApplyOrOutput, int]
......
import warnings
from collections.abc import Iterable, Mapping, Sequence
from functools import partial, singledispatch
from typing import Union, cast, overload
from typing import cast, overload
from pytensor.graph.basic import (
Apply,
......@@ -14,7 +14,7 @@ from pytensor.graph.fg import FunctionGraph
from pytensor.graph.op import Op
ReplaceTypes = Union[Iterable[tuple[Variable, Variable]], dict[Variable, Variable]]
ReplaceTypes = Iterable[tuple[Variable, Variable]] | dict[Variable, Variable]
def _format_replace(replace: ReplaceTypes | None = None) -> dict[Variable, Variable]:
......
......@@ -15,7 +15,7 @@ from collections.abc import Callable, Iterable, Sequence
from collections.abc import Iterable as IterableType
from functools import _compose_mro, partial, reduce # type: ignore
from itertools import chain
from typing import TYPE_CHECKING, Literal, Union, cast
from typing import TYPE_CHECKING, Literal, cast
import pytensor
from pytensor.configdefaults import config
......@@ -44,11 +44,11 @@ if TYPE_CHECKING:
_logger = logging.getLogger("pytensor.graph.rewriting.basic")
RemoveKeyType = Literal["remove"]
TransformOutputType = Union[
bool,
Sequence[Variable],
dict[Variable | Literal["remove"], Variable | Sequence[Variable]],
]
TransformOutputType = (
bool
| Sequence[Variable]
| dict[Variable | Literal["remove"], Variable | Sequence[Variable]]
)
FailureCallbackType = Callable[
[
Exception,
......
......@@ -12,7 +12,7 @@ from pytensor.misc.ordered_set import OrderedSet
from pytensor.utils import DefaultOrderedDict
RewritesType = Union[pytensor_rewriting.GraphRewriter, pytensor_rewriting.NodeRewriter]
RewritesType = pytensor_rewriting.GraphRewriter | pytensor_rewriting.NodeRewriter
class RewriteDatabase:
......
import warnings
from numbers import Number
from textwrap import dedent
from typing import Union, cast
from typing import cast
import numpy as np
......@@ -24,7 +24,7 @@ from pytensor.tensor.type_other import NoneConst
from pytensor.tensor.variable import TensorConstant, TensorVariable
ShapeValueType = Union[None, np.integer, int, Variable]
ShapeValueType = None | np.integer | int | Variable
def register_shape_c_code(type, code, version=()):
......
import logging
import warnings
from collections.abc import Iterable
from typing import TYPE_CHECKING, Literal, Optional, Union
from typing import TYPE_CHECKING, Literal, Optional
import numpy as np
......@@ -772,7 +772,7 @@ pytensor.compile.register_deep_copy_op_c_code(
)
# Valid static type entries
ST = Union[int, None]
ST = int | None
def tensor(
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论