提交 9af376f6 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Extract and generalize the unique name generator in fgraph_to_python

上级 bab68456
...@@ -605,6 +605,38 @@ def get_name_for_object(x: Any): ...@@ -605,6 +605,38 @@ def get_name_for_object(x: Any):
return name return name
def unique_name_generator(
external_names: Optional[List[str]] = None, suffix_sep: str = ""
) -> Callable:
"""Create a function that generates unique names."""
if external_names is None:
external_names = []
def unique_name(x, force_unique=False):
if not force_unique and x in unique_name.obj_to_names:
return unique_name.obj_to_names[x]
name = get_name_for_object(x)
name_suffix = unique_name.names_counter.get(name, "")
if name_suffix:
local_name = f"{name}{suffix_sep}{name_suffix}"
unique_name.names_counter.update((name,))
else:
local_name = name
unique_name.names_counter.update((local_name,))
unique_name.obj_to_names[x] = local_name
return local_name
unique_name.names_counter = Counter(external_names)
unique_name.obj_to_names = {}
return unique_name
def fgraph_to_python( def fgraph_to_python(
fgraph: FunctionGraph, fgraph: FunctionGraph,
op_conversion_fn: Callable, op_conversion_fn: Callable,
...@@ -663,19 +695,7 @@ def fgraph_to_python( ...@@ -663,19 +695,7 @@ def fgraph_to_python(
fgraph, order, input_storage, output_storage, storage_map fgraph, order, input_storage, output_storage, storage_map
) )
def unique_name(x, names_counter=Counter([fgraph_name]), obj_to_names={}): unique_name = unique_name_generator([fgraph_name])
if x in obj_to_names:
return obj_to_names[x]
name = get_name_for_object(x)
name_suffix = names_counter.get(name, "")
local_name = f"{name}{name_suffix}"
names_counter.update((name,))
obj_to_names[x] = local_name
return local_name
if global_env is None: if global_env is None:
global_env = {} global_env = {}
......
...@@ -6,7 +6,11 @@ from aesara import config ...@@ -6,7 +6,11 @@ from aesara import config
from aesara.graph.basic import Apply from aesara.graph.basic import Apply
from aesara.graph.fg import FunctionGraph from aesara.graph.fg import FunctionGraph
from aesara.graph.op import Op from aesara.graph.op import Op
from aesara.link.utils import fgraph_to_python, get_name_for_object from aesara.link.utils import (
fgraph_to_python,
get_name_for_object,
unique_name_generator,
)
from aesara.scalar.basic import Add from aesara.scalar.basic import Add
from aesara.tensor.elemwise import Elemwise from aesara.tensor.elemwise import Elemwise
from aesara.tensor.type import scalar, vector from aesara.tensor.type import scalar, vector
...@@ -102,3 +106,49 @@ def test_fgraph_to_python_once(): ...@@ -102,3 +106,49 @@ def test_fgraph_to_python_once():
assert len(res) == 2 assert len(res) == 2
assert op1.called == 2 assert op1.called == 2
assert op2.called == 2 assert op2.called == 2
def test_unique_name_generator():
unique_names = unique_name_generator(["blah"], suffix_sep="_")
x = vector("blah")
x_name = unique_names(x)
assert x_name == "blah_1"
y = vector("blah_1")
y_name = unique_names(y)
assert y_name == "blah_1_1"
# Make sure that the old name associations are still good
x_name = unique_names(x)
assert x_name == "blah_1"
y_name = unique_names(y)
assert y_name == "blah_1_1"
# Try a name that overlaps with the original name
z = vector("blah")
z_name = unique_names(z)
assert z_name == "blah_2"
# Try a name that overlaps with an extended name
w = vector("blah_1")
w_name = unique_names(w)
assert w_name == "blah_1_2"
q = vector()
q_name_1 = unique_names(q)
q_name_2 = unique_names(q)
assert q_name_1 == q_name_2 == q.auto_name
unique_names = unique_name_generator()
r = vector()
r_name_1 = unique_names(r)
r_name_2 = unique_names(r, force_unique=True)
assert r_name_1 != r_name_2
r_name_3 = unique_names(r)
assert r_name_2 == r_name_3
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论