Commit dd8895df — authored by ferres, committed by Maxim Kochurov

mypy: fix graph.py

Parent commit: f62401a0
...@@ -4,7 +4,7 @@ import time ...@@ -4,7 +4,7 @@ import time
import warnings import warnings
from collections.abc import Callable, Mapping, MutableSequence, Sequence from collections.abc import Callable, Mapping, MutableSequence, Sequence
from functools import partial, reduce from functools import partial, reduce
from typing import TYPE_CHECKING, Literal, TypeVar, Union from typing import TYPE_CHECKING, Literal, TypeVar, Union, overload
import numpy as np import numpy as np
...@@ -414,6 +414,32 @@ def Lop( ...@@ -414,6 +414,32 @@ def Lop(
return as_list_or_tuple(using_list, using_tuple, ret) return as_list_or_tuple(using_list, using_tuple, ret)
# Overload 1: with return_disconnected="zero" or "disconnected", every
# element of a sequence result is a materialized Variable, so the sequence
# return type is Sequence[Variable] (no None entries).
@overload
def grad(
    cost: Variable | None,
    wrt: Variable | Sequence[Variable],
    consider_constant: Sequence[Variable] | None = ...,
    disconnected_inputs: Literal["ignore", "warn", "raise"] = ...,
    add_names: bool = ...,
    known_grads: Mapping[Variable, Variable] | None = ...,
    return_disconnected: Literal["zero", "disconnected"] = ...,
    null_gradients: Literal["raise", "return"] = ...,
) -> Variable | None | Sequence[Variable]: ...
# Overload 2: with return_disconnected="none", disconnected gradients are
# returned as None, so sequence results are typed Sequence[Variable | None].
@overload
def grad(
    cost: Variable | None,
    wrt: Variable | Sequence[Variable],
    consider_constant: Sequence[Variable] | None = ...,
    disconnected_inputs: Literal["ignore", "warn", "raise"] = ...,
    add_names: bool = ...,
    known_grads: Mapping[Variable, Variable] | None = ...,
    return_disconnected: Literal["none"] = ...,
    null_gradients: Literal["raise", "return"] = ...,
) -> Variable | None | Sequence[Variable | None]: ...
def grad( def grad(
cost: Variable | None, cost: Variable | None,
wrt: Variable | Sequence[Variable], wrt: Variable | Sequence[Variable],
...@@ -423,7 +449,7 @@ def grad( ...@@ -423,7 +449,7 @@ def grad(
known_grads: Mapping[Variable, Variable] | None = None, known_grads: Mapping[Variable, Variable] | None = None,
return_disconnected: Literal["none", "zero", "disconnected"] = "zero", return_disconnected: Literal["none", "zero", "disconnected"] = "zero",
null_gradients: Literal["raise", "return"] = "raise", null_gradients: Literal["raise", "return"] = "raise",
) -> Variable | None | Sequence[Variable | None]: ) -> Variable | None | Sequence[Variable | None] | Sequence[Variable]:
""" """
Return symbolic gradients of one cost with respect to one or more variables. Return symbolic gradients of one cost with respect to one or more variables.
......
...@@ -1313,8 +1313,9 @@ def clone_get_equiv( ...@@ -1313,8 +1313,9 @@ def clone_get_equiv(
outputs: Reversible[Variable], outputs: Reversible[Variable],
copy_inputs: bool = True, copy_inputs: bool = True,
copy_orphans: bool = True, copy_orphans: bool = True,
memo: dict[Union[Apply, Variable, "Op"], Union[Apply, Variable, "Op"]] memo: (
| None = None, dict[Union[Apply, Variable, "Op"], Union[Apply, Variable, "Op"]] | None
) = None,
clone_inner_graphs: bool = False, clone_inner_graphs: bool = False,
**kwargs, **kwargs,
) -> dict[Union[Apply, Variable, "Op"], Union[Apply, Variable, "Op"]]: ) -> dict[Union[Apply, Variable, "Op"], Union[Apply, Variable, "Op"]]:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论