提交 2d3cb775 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Allow customizable object naming in aesara.link.utils.fgraph_to_python

上级 b49bb6f8
......@@ -590,6 +590,21 @@ def compile_function_src(src, function_name, global_env=None, local_env=None):
return local_env[function_name]
def get_name_for_object(x: Any):
"""Get the name for an arbitrary object."""
if isinstance(x, Variable):
name = re.sub("[^0-9a-zA-Z]+", "_", x.name) if x.name else ""
name = name if (name.isidentifier() and not iskeyword(name)) else x.auto_name
else:
name = getattr(x, "__name__", None)
if not name or (not name.isidentifier() or iskeyword(name)):
name = type(x).__name__
return name
def fgraph_to_python(
fgraph: FunctionGraph,
op_conversion_fn: Callable,
......@@ -602,6 +617,7 @@ def fgraph_to_python(
fgraph_name: str = "fgraph_to_python",
global_env: Optional[Dict[Any, Any]] = None,
local_env: Optional[Dict[Any, Any]] = None,
get_name_for_object: Callable[[Any], str] = get_name_for_object,
**kwargs,
) -> FunctionType:
"""Convert a ``FunctionGraph`` into a regular Python function.
......@@ -634,6 +650,9 @@ def fgraph_to_python(
local_env
The local environment used when the function is constructed.
The default is ``locals()``.
get_name_for_object
A function used to provide names for the objects referenced within the
generated function.
**kwargs
The remaining keywords are passed to `python_conversion_fn`
"""
......@@ -648,15 +667,7 @@ def fgraph_to_python(
if x in obj_to_names:
return obj_to_names[x]
if isinstance(x, Variable):
name = re.sub("[^0-9a-zA-Z]+", "_", x.name) if x.name else ""
name = (
name if (name.isidentifier() and not iskeyword(name)) else x.auto_name
)
elif isinstance(x, FunctionType):
name = x.__name__
else:
name = type(x).__name__
name = get_name_for_object(x)
name_suffix = names_counter.get(name, "")
local_name = f"{name}{name_suffix}"
......
......@@ -6,7 +6,7 @@ from aesara import config
from aesara.graph.basic import Apply
from aesara.graph.fg import FunctionGraph
from aesara.graph.op import Op
from aesara.link.utils import fgraph_to_python
from aesara.link.utils import fgraph_to_python, get_name_for_object
from aesara.scalar.basic import Add
from aesara.tensor.elemwise import Elemwise
from aesara.tensor.type import scalar, vector
......@@ -50,6 +50,9 @@ def test_fgraph_to_python_names():
)
assert (1, 2, 3, 4, 5) == out_jx(1, 2, 3, 4, 5)
obj = object()
assert get_name_for_object(obj) == type(obj).__name__
def test_fgraph_to_python_once():
"""Make sure that an output is only computed once when it's referenced multiple times."""
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论