提交 2d3cb775 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Allow customizable object naming in aesara.link.utils.fgraph_to_python

上级 b49bb6f8
...@@ -590,6 +590,21 @@ def compile_function_src(src, function_name, global_env=None, local_env=None): ...@@ -590,6 +590,21 @@ def compile_function_src(src, function_name, global_env=None, local_env=None):
return local_env[function_name] return local_env[function_name]
def get_name_for_object(x: Any):
"""Get the name for an arbitrary object."""
if isinstance(x, Variable):
name = re.sub("[^0-9a-zA-Z]+", "_", x.name) if x.name else ""
name = name if (name.isidentifier() and not iskeyword(name)) else x.auto_name
else:
name = getattr(x, "__name__", None)
if not name or (not name.isidentifier() or iskeyword(name)):
name = type(x).__name__
return name
def fgraph_to_python( def fgraph_to_python(
fgraph: FunctionGraph, fgraph: FunctionGraph,
op_conversion_fn: Callable, op_conversion_fn: Callable,
...@@ -602,6 +617,7 @@ def fgraph_to_python( ...@@ -602,6 +617,7 @@ def fgraph_to_python(
fgraph_name: str = "fgraph_to_python", fgraph_name: str = "fgraph_to_python",
global_env: Optional[Dict[Any, Any]] = None, global_env: Optional[Dict[Any, Any]] = None,
local_env: Optional[Dict[Any, Any]] = None, local_env: Optional[Dict[Any, Any]] = None,
get_name_for_object: Callable[[Any], str] = get_name_for_object,
**kwargs, **kwargs,
) -> FunctionType: ) -> FunctionType:
"""Convert a ``FunctionGraph`` into a regular Python function. """Convert a ``FunctionGraph`` into a regular Python function.
...@@ -634,6 +650,9 @@ def fgraph_to_python( ...@@ -634,6 +650,9 @@ def fgraph_to_python(
local_env local_env
The local environment used when the function is constructed. The local environment used when the function is constructed.
The default is ``locals()``. The default is ``locals()``.
get_name_for_object
A function used to provide names for the objects referenced within the
generated function.
**kwargs **kwargs
The remaining keywords are passed to `python_conversion_fn` The remaining keywords are passed to `python_conversion_fn`
""" """
...@@ -648,15 +667,7 @@ def fgraph_to_python( ...@@ -648,15 +667,7 @@ def fgraph_to_python(
if x in obj_to_names: if x in obj_to_names:
return obj_to_names[x] return obj_to_names[x]
if isinstance(x, Variable): name = get_name_for_object(x)
name = re.sub("[^0-9a-zA-Z]+", "_", x.name) if x.name else ""
name = (
name if (name.isidentifier() and not iskeyword(name)) else x.auto_name
)
elif isinstance(x, FunctionType):
name = x.__name__
else:
name = type(x).__name__
name_suffix = names_counter.get(name, "") name_suffix = names_counter.get(name, "")
local_name = f"{name}{name_suffix}" local_name = f"{name}{name_suffix}"
......
...@@ -6,7 +6,7 @@ from aesara import config ...@@ -6,7 +6,7 @@ from aesara import config
from aesara.graph.basic import Apply from aesara.graph.basic import Apply
from aesara.graph.fg import FunctionGraph from aesara.graph.fg import FunctionGraph
from aesara.graph.op import Op from aesara.graph.op import Op
from aesara.link.utils import fgraph_to_python from aesara.link.utils import fgraph_to_python, get_name_for_object
from aesara.scalar.basic import Add from aesara.scalar.basic import Add
from aesara.tensor.elemwise import Elemwise from aesara.tensor.elemwise import Elemwise
from aesara.tensor.type import scalar, vector from aesara.tensor.type import scalar, vector
...@@ -50,6 +50,9 @@ def test_fgraph_to_python_names(): ...@@ -50,6 +50,9 @@ def test_fgraph_to_python_names():
) )
assert (1, 2, 3, 4, 5) == out_jx(1, 2, 3, 4, 5) assert (1, 2, 3, 4, 5) == out_jx(1, 2, 3, 4, 5)
obj = object()
assert get_name_for_object(obj) == type(obj).__name__
def test_fgraph_to_python_once(): def test_fgraph_to_python_once():
"""Make sure that an output is only computed once when it's referenced multiple times.""" """Make sure that an output is only computed once when it's referenced multiple times."""
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论