提交 4d5aca03 authored 作者: Virgile Andreani's avatar Virgile Andreani 提交者: Ricardo Vieira

Type function.copy method parameters

上级 99589bb9
...@@ -14,6 +14,7 @@ import pytensor ...@@ -14,6 +14,7 @@ import pytensor
import pytensor.compile.profiling import pytensor.compile.profiling
from pytensor.compile.io import In, SymbolicInput, SymbolicOutput from pytensor.compile.io import In, SymbolicInput, SymbolicOutput
from pytensor.compile.ops import deep_copy_op, view_op from pytensor.compile.ops import deep_copy_op, view_op
from pytensor.compile.profiling import ProfileStats
from pytensor.configdefaults import config from pytensor.configdefaults import config
from pytensor.graph.basic import ( from pytensor.graph.basic import (
Constant, Constant,
...@@ -553,11 +554,11 @@ class Function: ...@@ -553,11 +554,11 @@ class Function:
def copy( def copy(
self, self,
share_memory=False, share_memory: bool = False,
swap=None, swap: dict | None = None,
delete_updates=False, delete_updates: bool = False,
name=None, name: str | None = None,
profile=None, profile: bool | str | ProfileStats | None = None,
): ):
""" """
Copy this function. Copied function will have separated maker and Copy this function. Copied function will have separated maker and
...@@ -584,7 +585,7 @@ class Function: ...@@ -584,7 +585,7 @@ class Function:
If provided, will be the name of the new If provided, will be the name of the new
Function. Otherwise, it will be old + " copy" Function. Otherwise, it will be old + " copy"
profile : profile : bool | str | ProfileStats | None
as pytensor.function profile parameter as pytensor.function profile parameter
Returns Returns
...@@ -723,14 +724,8 @@ class Function: ...@@ -723,14 +724,8 @@ class Function:
# reinitialize new maker and create new function # reinitialize new maker and create new function
if profile is None: if profile is None:
profile = config.profile or config.print_global_stats profile = config.profile or config.print_global_stats
# profile -> True or False
if profile is True: if profile is True:
if name: profile = pytensor.compile.profiling.ProfileStats(message=name)
message = name
else:
message = str(profile.message) + " copy"
profile = pytensor.compile.profiling.ProfileStats(message=message)
# profile -> object
elif isinstance(profile, str): elif isinstance(profile, str):
profile = pytensor.compile.profiling.ProfileStats(message=profile) profile = pytensor.compile.profiling.ProfileStats(message=profile)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论