提交 79ff97a5 authored 作者: Virgile Andreani's avatar Virgile Andreani 提交者: Ricardo Vieira

Fix UP007 (Typealias Unions)

上级 a76172ee
...@@ -26,7 +26,7 @@ from pytensor.misc.ordered_set import OrderedSet ...@@ -26,7 +26,7 @@ from pytensor.misc.ordered_set import OrderedSet
if TYPE_CHECKING: if TYPE_CHECKING:
from pytensor.graph.op import Op from pytensor.graph.op import Op
ApplyOrOutput = Union[Apply, Literal["output"]] ApplyOrOutput = Apply | Literal["output"]
ClientType = tuple[ApplyOrOutput, int] ClientType = tuple[ApplyOrOutput, int]
......
import warnings import warnings
from collections.abc import Iterable, Mapping, Sequence from collections.abc import Iterable, Mapping, Sequence
from functools import partial, singledispatch from functools import partial, singledispatch
from typing import Union, cast, overload from typing import cast, overload
from pytensor.graph.basic import ( from pytensor.graph.basic import (
Apply, Apply,
...@@ -14,7 +14,7 @@ from pytensor.graph.fg import FunctionGraph ...@@ -14,7 +14,7 @@ from pytensor.graph.fg import FunctionGraph
from pytensor.graph.op import Op from pytensor.graph.op import Op
ReplaceTypes = Union[Iterable[tuple[Variable, Variable]], dict[Variable, Variable]] ReplaceTypes = Iterable[tuple[Variable, Variable]] | dict[Variable, Variable]
def _format_replace(replace: ReplaceTypes | None = None) -> dict[Variable, Variable]: def _format_replace(replace: ReplaceTypes | None = None) -> dict[Variable, Variable]:
......
...@@ -15,7 +15,7 @@ from collections.abc import Callable, Iterable, Sequence ...@@ -15,7 +15,7 @@ from collections.abc import Callable, Iterable, Sequence
from collections.abc import Iterable as IterableType from collections.abc import Iterable as IterableType
from functools import _compose_mro, partial, reduce # type: ignore from functools import _compose_mro, partial, reduce # type: ignore
from itertools import chain from itertools import chain
from typing import TYPE_CHECKING, Literal, Union, cast from typing import TYPE_CHECKING, Literal, cast
import pytensor import pytensor
from pytensor.configdefaults import config from pytensor.configdefaults import config
...@@ -44,11 +44,11 @@ if TYPE_CHECKING: ...@@ -44,11 +44,11 @@ if TYPE_CHECKING:
_logger = logging.getLogger("pytensor.graph.rewriting.basic") _logger = logging.getLogger("pytensor.graph.rewriting.basic")
RemoveKeyType = Literal["remove"] RemoveKeyType = Literal["remove"]
TransformOutputType = Union[ TransformOutputType = (
bool, bool
Sequence[Variable], | Sequence[Variable]
dict[Variable | Literal["remove"], Variable | Sequence[Variable]], | dict[Variable | Literal["remove"], Variable | Sequence[Variable]]
] )
FailureCallbackType = Callable[ FailureCallbackType = Callable[
[ [
Exception, Exception,
......
...@@ -12,7 +12,7 @@ from pytensor.misc.ordered_set import OrderedSet ...@@ -12,7 +12,7 @@ from pytensor.misc.ordered_set import OrderedSet
from pytensor.utils import DefaultOrderedDict from pytensor.utils import DefaultOrderedDict
RewritesType = Union[pytensor_rewriting.GraphRewriter, pytensor_rewriting.NodeRewriter] RewritesType = pytensor_rewriting.GraphRewriter | pytensor_rewriting.NodeRewriter
class RewriteDatabase: class RewriteDatabase:
......
import warnings import warnings
from numbers import Number from numbers import Number
from textwrap import dedent from textwrap import dedent
from typing import Union, cast from typing import cast
import numpy as np import numpy as np
...@@ -24,7 +24,7 @@ from pytensor.tensor.type_other import NoneConst ...@@ -24,7 +24,7 @@ from pytensor.tensor.type_other import NoneConst
from pytensor.tensor.variable import TensorConstant, TensorVariable from pytensor.tensor.variable import TensorConstant, TensorVariable
ShapeValueType = Union[None, np.integer, int, Variable] ShapeValueType = None | np.integer | int | Variable
def register_shape_c_code(type, code, version=()): def register_shape_c_code(type, code, version=()):
......
import logging import logging
import warnings import warnings
from collections.abc import Iterable from collections.abc import Iterable
from typing import TYPE_CHECKING, Literal, Optional, Union from typing import TYPE_CHECKING, Literal, Optional
import numpy as np import numpy as np
...@@ -772,7 +772,7 @@ pytensor.compile.register_deep_copy_op_c_code( ...@@ -772,7 +772,7 @@ pytensor.compile.register_deep_copy_op_c_code(
) )
# Valid static type entries # Valid static type entries
ST = Union[int, None] ST = int | None
def tensor( def tensor(
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论