提交 b083fb91 authored 作者: Virgile Andreani's avatar Virgile Andreani 提交者: Virgile Andreani

Type function arguments

上级 084fbf16
import logging
import re
import traceback as tb
from collections.abc import Iterable
from pathlib import Path
import pytensor.misc.pkl_utils
from pytensor.compile.function.pfunc import pfunc
from pytensor.compile.function.types import orig_function
from pytensor.compile.mode import Mode
from pytensor.compile.profiling import ProfileStats
from pytensor.graph import Variable
__all__ = ["types", "pfunc"]
......@@ -15,18 +20,22 @@ _logger = logging.getLogger("pytensor.compile.function")
def function_dump(
filename: str | Path,
inputs,
outputs=None,
mode=None,
updates=None,
givens=None,
no_default_updates=False,
accept_inplace=False,
name=None,
rebuild_strict=True,
allow_input_downcast=None,
profile=None,
on_unused_input=None,
inputs: Iterable[Variable],
outputs: Variable | Iterable[Variable] | dict[str, Variable] | None = None,
mode: str | Mode | None = None,
updates: Iterable[tuple[Variable, Variable]]
| dict[Variable, Variable]
| None = None,
givens: Iterable[tuple[Variable, Variable]]
| dict[Variable, Variable]
| None = None,
no_default_updates: bool = False,
accept_inplace: bool = False,
name: str | None = None,
rebuild_strict: bool = True,
allow_input_downcast: bool | None = None,
profile: bool | ProfileStats | None = None,
on_unused_input: str | None = None,
extra_tag_to_remove: str | None = None,
):
"""
......@@ -60,24 +69,21 @@ def function_dump(
`['annotations', 'replacement_of', 'aggregation_scheme', 'roles']`
"""
filename = Path(filename)
d = dict(
inputs=inputs,
outputs=outputs,
mode=mode,
updates=updates,
givens=givens,
no_default_updates=no_default_updates,
accept_inplace=accept_inplace,
name=name,
rebuild_strict=rebuild_strict,
allow_input_downcast=allow_input_downcast,
profile=profile,
on_unused_input=on_unused_input,
)
with filename.open("wb") as f:
import pytensor.misc.pkl_utils
d = {
"inputs": inputs,
"outputs": outputs,
"mode": mode,
"updates": updates,
"givens": givens,
"no_default_updates": no_default_updates,
"accept_inplace": accept_inplace,
"name": name,
"rebuild_strict": rebuild_strict,
"allow_input_downcast": allow_input_downcast,
"profile": profile,
"on_unused_input": on_unused_input,
}
with Path(filename).open("wb") as f:
pickler = pytensor.misc.pkl_utils.StripPickler(
f, protocol=-1, extra_tag_to_remove=extra_tag_to_remove
)
......@@ -85,18 +91,22 @@ def function_dump(
def function(
inputs,
outputs=None,
mode=None,
updates=None,
givens=None,
no_default_updates=False,
accept_inplace=False,
name=None,
rebuild_strict=True,
allow_input_downcast=None,
profile=None,
on_unused_input=None,
inputs: Iterable[Variable],
outputs: Variable | Iterable[Variable] | dict[str, Variable] | None = None,
mode: str | Mode | None = None,
updates: Iterable[tuple[Variable, Variable]]
| dict[Variable, Variable]
| None = None,
givens: Iterable[tuple[Variable, Variable]]
| dict[Variable, Variable]
| None = None,
no_default_updates: bool = False,
accept_inplace: bool = False,
name: str | None = None,
rebuild_strict: bool = True,
allow_input_downcast: bool | None = None,
profile: bool | ProfileStats | None = None,
on_unused_input: str | None = None,
):
"""
Return a :class:`callable object <pytensor.compile.function.types.Function>`
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论