提交 dd8895df authored 作者: ferres's avatar ferres 提交者: Maxim Kochurov

mypy: fix graph.py

上级 f62401a0
......@@ -4,7 +4,7 @@ import time
import warnings
from collections.abc import Callable, Mapping, MutableSequence, Sequence
from functools import partial, reduce
from typing import TYPE_CHECKING, Literal, TypeVar, Union
from typing import TYPE_CHECKING, Literal, TypeVar, Union, overload
import numpy as np
......@@ -414,6 +414,32 @@ def Lop(
return as_list_or_tuple(using_list, using_tuple, ret)
@overload
def grad(
cost: Variable | None,
wrt: Variable | Sequence[Variable],
consider_constant: Sequence[Variable] | None = ...,
disconnected_inputs: Literal["ignore", "warn", "raise"] = ...,
add_names: bool = ...,
known_grads: Mapping[Variable, Variable] | None = ...,
return_disconnected: Literal["zero", "disconnected"] = ...,
null_gradients: Literal["raise", "return"] = ...,
) -> Variable | None | Sequence[Variable]: ...
@overload
def grad(
cost: Variable | None,
wrt: Variable | Sequence[Variable],
consider_constant: Sequence[Variable] | None = ...,
disconnected_inputs: Literal["ignore", "warn", "raise"] = ...,
add_names: bool = ...,
known_grads: Mapping[Variable, Variable] | None = ...,
return_disconnected: Literal["none"] = ...,
null_gradients: Literal["raise", "return"] = ...,
) -> Variable | None | Sequence[Variable | None]: ...
def grad(
cost: Variable | None,
wrt: Variable | Sequence[Variable],
......@@ -423,7 +449,7 @@ def grad(
known_grads: Mapping[Variable, Variable] | None = None,
return_disconnected: Literal["none", "zero", "disconnected"] = "zero",
null_gradients: Literal["raise", "return"] = "raise",
) -> Variable | None | Sequence[Variable | None]:
) -> Variable | None | Sequence[Variable | None] | Sequence[Variable]:
"""
Return symbolic gradients of one cost with respect to one or more variables.
......
......@@ -1313,8 +1313,9 @@ def clone_get_equiv(
outputs: Reversible[Variable],
copy_inputs: bool = True,
copy_orphans: bool = True,
memo: dict[Union[Apply, Variable, "Op"], Union[Apply, Variable, "Op"]]
| None = None,
memo: (
dict[Union[Apply, Variable, "Op"], Union[Apply, Variable, "Op"]] | None
) = None,
clone_inner_graphs: bool = False,
**kwargs,
) -> dict[Union[Apply, Variable, "Op"], Union[Apply, Variable, "Op"]]:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论