提交 205b7a84 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Stop using auto_name for transpilation

Using the `auto_name` values will result in cache misses when caching is based on the generated source code, so we're not going to use it.
上级 071d3cae
......@@ -393,10 +393,14 @@ def numba_funcify_FunctionGraph(
def create_index_func(node, objmode=False):
"""Create a Python function that assembles and uses an index on an array."""
unique_names = unique_name_generator(
["subtensor", "incsubtensor", "z"], suffix_sep="_"
)
def convert_indices(indices, entry):
if indices and isinstance(entry, Type):
rval = indices.pop(0)
return rval.auto_name
return unique_names(rval)
elif isinstance(entry, slice):
return (
f"slice({convert_indices(indices, entry.start)}, "
......@@ -413,10 +417,6 @@ def create_index_func(node, objmode=False):
)
index_start_idx = 1 + int(set_or_inc)
unique_names = unique_name_generator(
["subtensor", "incsubtensor", "z"], suffix_sep="_"
)
input_names = [unique_names(v, force_unique=True) for v in node.inputs]
op_indices = list(node.inputs[index_start_idx:])
idx_list = getattr(node.op, "idx_list", None)
......
......@@ -54,7 +54,7 @@ def numba_funcify_Scan(op, node, **kwargs):
p_outer_in_nit_sot = p_outer_in_shared + n_shared_outs
p_outer_in_non_seqs = p_outer_in_nit_sot + n_nit_sot
input_names = [f"{n.auto_name}_{i}" for i, n in enumerate(node.inputs[1:])]
input_names = [f"outer_in_{i}" for i, n in enumerate(node.inputs[1:])]
outer_in_seqs_names = input_names[:n_seqs]
outer_in_mit_mot_names = input_names[p_in_mit_mot : p_in_mit_mot + n_mit_mot]
outer_in_mit_sot_names = input_names[p_in_mit_sot : p_in_mit_sot + n_mit_sot]
......
......@@ -613,8 +613,8 @@ def compile_function_src(
def get_name_for_object(x: Any) -> str:
"""Get the name for an arbitrary object."""
if isinstance(x, Variable):
name = re.sub("[^0-9a-zA-Z]+", "_", x.name) if x.name else ""
if isinstance(x, Variable) and x.name:
name = re.sub("[^0-9a-zA-Z]+", "_", x.name)
name = (
name
if (
......@@ -622,19 +622,22 @@ def get_name_for_object(x: Any) -> str:
and not iskeyword(name)
and name not in dir(builtins)
)
else x.auto_name
else ""
)
else:
name = getattr(x, "__name__", "")
name = re.sub(r"(?<!^)(?=[A-Z])", "_", getattr(x, "__name__", "")).lower()
if not name or (not name.isidentifier() or iskeyword(name)):
name = type(x).__name__
# Try to get snake-case out of the type name
name = re.sub(r"(?<!^)(?=[A-Z])", "_", type(x).__name__).lower()
assert name.isidentifier() and not iskeyword(name)
return name
def unique_name_generator(
external_names: Optional[List[str]] = None, suffix_sep: str = ""
external_names: Optional[List[str]] = None, suffix_sep: str = "_"
) -> Callable:
"""Create a function that generates unique names."""
......
......@@ -89,23 +89,6 @@ def compare_jax_and_py(
return jax_res
def test_jax_FunctionGraph_names():
import inspect
from aesara.link.jax.dispatch import jax_funcify
x = scalar("1x")
y = scalar("_")
z = scalar()
q = scalar("def")
out_fg = FunctionGraph([x, y, z, q], [x, y, z, q], clone=False)
out_jx = jax_funcify(out_fg)
sig = inspect.signature(out_jx)
assert (x.auto_name, "_", z.auto_name, q.auto_name) == tuple(sig.parameters.keys())
assert (1, 2, 3, 4) == out_jx(1, 2, 3, 4)
def test_jax_FunctionGraph_once():
"""Make sure that an output is only computed once when it's referenced multiple times."""
from aesara.link.jax.dispatch import jax_funcify
......
......@@ -220,7 +220,6 @@ class TestWrapLinker:
def test_sort_schedule_fn():
import aesara
from aesara.graph.sched import make_depends, sort_schedule_fn
x = matrix("x")
......
......@@ -12,7 +12,7 @@ from aesara.link.utils import (
get_name_for_object,
unique_name_generator,
)
from aesara.scalar.basic import Add
from aesara.scalar.basic import Add, float64
from aesara.tensor.elemwise import Elemwise
from aesara.tensor.type import scalar, vector
from aesara.tensor.type_other import NoneConst
......@@ -42,7 +42,7 @@ def test_fgraph_to_python_names():
x = scalar("1x")
y = scalar("_")
z = scalar()
z = float64()
q = scalar("def")
r = NoneConst
......@@ -50,9 +50,13 @@ def test_fgraph_to_python_names():
out_jx = fgraph_to_python(out_fg, to_python)
sig = inspect.signature(out_jx)
assert (x.auto_name, "_", z.auto_name, q.auto_name, r.name) == tuple(
sig.parameters.keys()
)
assert (
"tensor_variable",
"_",
"scalar_variable",
"tensor_variable_1",
r.name,
) == tuple(sig.parameters.keys())
assert (1, 2, 3, 4, 5) == out_jx(1, 2, 3, 4, 5)
obj = object()
......@@ -191,7 +195,7 @@ def test_unique_name_generator():
q_name_1 = unique_names(q)
q_name_2 = unique_names(q)
assert q_name_1 == q_name_2 == q.auto_name
assert q_name_1 == q_name_2 == "tensor_variable"
unique_names = unique_name_generator()
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论