提交 9af376f6 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Extract and generalize the unique name generator in fgraph_to_python

上级 bab68456
......@@ -605,6 +605,38 @@ def get_name_for_object(x: Any):
return name
def unique_name_generator(
external_names: Optional[List[str]] = None, suffix_sep: str = ""
) -> Callable:
"""Create a function that generates unique names."""
if external_names is None:
external_names = []
def unique_name(x, force_unique=False):
if not force_unique and x in unique_name.obj_to_names:
return unique_name.obj_to_names[x]
name = get_name_for_object(x)
name_suffix = unique_name.names_counter.get(name, "")
if name_suffix:
local_name = f"{name}{suffix_sep}{name_suffix}"
unique_name.names_counter.update((name,))
else:
local_name = name
unique_name.names_counter.update((local_name,))
unique_name.obj_to_names[x] = local_name
return local_name
unique_name.names_counter = Counter(external_names)
unique_name.obj_to_names = {}
return unique_name
def fgraph_to_python(
fgraph: FunctionGraph,
op_conversion_fn: Callable,
......@@ -663,19 +695,7 @@ def fgraph_to_python(
fgraph, order, input_storage, output_storage, storage_map
)
def unique_name(x, names_counter=Counter([fgraph_name]), obj_to_names={}):
if x in obj_to_names:
return obj_to_names[x]
name = get_name_for_object(x)
name_suffix = names_counter.get(name, "")
local_name = f"{name}{name_suffix}"
names_counter.update((name,))
obj_to_names[x] = local_name
return local_name
unique_name = unique_name_generator([fgraph_name])
if global_env is None:
global_env = {}
......
......@@ -6,7 +6,11 @@ from aesara import config
from aesara.graph.basic import Apply
from aesara.graph.fg import FunctionGraph
from aesara.graph.op import Op
from aesara.link.utils import fgraph_to_python, get_name_for_object
from aesara.link.utils import (
fgraph_to_python,
get_name_for_object,
unique_name_generator,
)
from aesara.scalar.basic import Add
from aesara.tensor.elemwise import Elemwise
from aesara.tensor.type import scalar, vector
......@@ -102,3 +106,49 @@ def test_fgraph_to_python_once():
assert len(res) == 2
assert op1.called == 2
assert op2.called == 2
def test_unique_name_generator():
unique_names = unique_name_generator(["blah"], suffix_sep="_")
x = vector("blah")
x_name = unique_names(x)
assert x_name == "blah_1"
y = vector("blah_1")
y_name = unique_names(y)
assert y_name == "blah_1_1"
# Make sure that the old name associations are still good
x_name = unique_names(x)
assert x_name == "blah_1"
y_name = unique_names(y)
assert y_name == "blah_1_1"
# Try a name that overlaps with the original name
z = vector("blah")
z_name = unique_names(z)
assert z_name == "blah_2"
# Try a name that overlaps with an extended name
w = vector("blah_1")
w_name = unique_names(w)
assert w_name == "blah_1_2"
q = vector()
q_name_1 = unique_names(q)
q_name_2 = unique_names(q)
assert q_name_1 == q_name_2 == q.auto_name
unique_names = unique_name_generator()
r = vector()
r_name_1 = unique_names(r)
r_name_2 = unique_names(r, force_unique=True)
assert r_name_1 != r_name_2
r_name_3 = unique_names(r)
assert r_name_2 == r_name_3
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论