提交 205b7a84 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Stop using auto_name for transpilation

Using the `auto_name` values will result in cache misses when caching is based on the generated source code, so we're not going to use it.
上级 071d3cae
...@@ -393,10 +393,14 @@ def numba_funcify_FunctionGraph( ...@@ -393,10 +393,14 @@ def numba_funcify_FunctionGraph(
def create_index_func(node, objmode=False): def create_index_func(node, objmode=False):
"""Create a Python function that assembles and uses an index on an array.""" """Create a Python function that assembles and uses an index on an array."""
unique_names = unique_name_generator(
["subtensor", "incsubtensor", "z"], suffix_sep="_"
)
def convert_indices(indices, entry): def convert_indices(indices, entry):
if indices and isinstance(entry, Type): if indices and isinstance(entry, Type):
rval = indices.pop(0) rval = indices.pop(0)
return rval.auto_name return unique_names(rval)
elif isinstance(entry, slice): elif isinstance(entry, slice):
return ( return (
f"slice({convert_indices(indices, entry.start)}, " f"slice({convert_indices(indices, entry.start)}, "
...@@ -413,10 +417,6 @@ def create_index_func(node, objmode=False): ...@@ -413,10 +417,6 @@ def create_index_func(node, objmode=False):
) )
index_start_idx = 1 + int(set_or_inc) index_start_idx = 1 + int(set_or_inc)
unique_names = unique_name_generator(
["subtensor", "incsubtensor", "z"], suffix_sep="_"
)
input_names = [unique_names(v, force_unique=True) for v in node.inputs] input_names = [unique_names(v, force_unique=True) for v in node.inputs]
op_indices = list(node.inputs[index_start_idx:]) op_indices = list(node.inputs[index_start_idx:])
idx_list = getattr(node.op, "idx_list", None) idx_list = getattr(node.op, "idx_list", None)
......
...@@ -54,7 +54,7 @@ def numba_funcify_Scan(op, node, **kwargs): ...@@ -54,7 +54,7 @@ def numba_funcify_Scan(op, node, **kwargs):
p_outer_in_nit_sot = p_outer_in_shared + n_shared_outs p_outer_in_nit_sot = p_outer_in_shared + n_shared_outs
p_outer_in_non_seqs = p_outer_in_nit_sot + n_nit_sot p_outer_in_non_seqs = p_outer_in_nit_sot + n_nit_sot
input_names = [f"{n.auto_name}_{i}" for i, n in enumerate(node.inputs[1:])] input_names = [f"outer_in_{i}" for i, n in enumerate(node.inputs[1:])]
outer_in_seqs_names = input_names[:n_seqs] outer_in_seqs_names = input_names[:n_seqs]
outer_in_mit_mot_names = input_names[p_in_mit_mot : p_in_mit_mot + n_mit_mot] outer_in_mit_mot_names = input_names[p_in_mit_mot : p_in_mit_mot + n_mit_mot]
outer_in_mit_sot_names = input_names[p_in_mit_sot : p_in_mit_sot + n_mit_sot] outer_in_mit_sot_names = input_names[p_in_mit_sot : p_in_mit_sot + n_mit_sot]
......
...@@ -613,8 +613,8 @@ def compile_function_src( ...@@ -613,8 +613,8 @@ def compile_function_src(
def get_name_for_object(x: Any) -> str: def get_name_for_object(x: Any) -> str:
"""Get the name for an arbitrary object.""" """Get the name for an arbitrary object."""
if isinstance(x, Variable): if isinstance(x, Variable) and x.name:
name = re.sub("[^0-9a-zA-Z]+", "_", x.name) if x.name else "" name = re.sub("[^0-9a-zA-Z]+", "_", x.name)
name = ( name = (
name name
if ( if (
...@@ -622,19 +622,22 @@ def get_name_for_object(x: Any) -> str: ...@@ -622,19 +622,22 @@ def get_name_for_object(x: Any) -> str:
and not iskeyword(name) and not iskeyword(name)
and name not in dir(builtins) and name not in dir(builtins)
) )
else x.auto_name else ""
) )
else: else:
name = getattr(x, "__name__", "") name = re.sub(r"(?<!^)(?=[A-Z])", "_", getattr(x, "__name__", "")).lower()
if not name or (not name.isidentifier() or iskeyword(name)): if not name or (not name.isidentifier() or iskeyword(name)):
name = type(x).__name__ # Try to get snake-case out of the type name
name = re.sub(r"(?<!^)(?=[A-Z])", "_", type(x).__name__).lower()
assert name.isidentifier() and not iskeyword(name)
return name return name
def unique_name_generator( def unique_name_generator(
external_names: Optional[List[str]] = None, suffix_sep: str = "" external_names: Optional[List[str]] = None, suffix_sep: str = "_"
) -> Callable: ) -> Callable:
"""Create a function that generates unique names.""" """Create a function that generates unique names."""
......
...@@ -89,23 +89,6 @@ def compare_jax_and_py( ...@@ -89,23 +89,6 @@ def compare_jax_and_py(
return jax_res return jax_res
def test_jax_FunctionGraph_names():
import inspect
from aesara.link.jax.dispatch import jax_funcify
x = scalar("1x")
y = scalar("_")
z = scalar()
q = scalar("def")
out_fg = FunctionGraph([x, y, z, q], [x, y, z, q], clone=False)
out_jx = jax_funcify(out_fg)
sig = inspect.signature(out_jx)
assert (x.auto_name, "_", z.auto_name, q.auto_name) == tuple(sig.parameters.keys())
assert (1, 2, 3, 4) == out_jx(1, 2, 3, 4)
def test_jax_FunctionGraph_once(): def test_jax_FunctionGraph_once():
"""Make sure that an output is only computed once when it's referenced multiple times.""" """Make sure that an output is only computed once when it's referenced multiple times."""
from aesara.link.jax.dispatch import jax_funcify from aesara.link.jax.dispatch import jax_funcify
......
...@@ -220,7 +220,6 @@ class TestWrapLinker: ...@@ -220,7 +220,6 @@ class TestWrapLinker:
def test_sort_schedule_fn(): def test_sort_schedule_fn():
import aesara
from aesara.graph.sched import make_depends, sort_schedule_fn from aesara.graph.sched import make_depends, sort_schedule_fn
x = matrix("x") x = matrix("x")
......
...@@ -12,7 +12,7 @@ from aesara.link.utils import ( ...@@ -12,7 +12,7 @@ from aesara.link.utils import (
get_name_for_object, get_name_for_object,
unique_name_generator, unique_name_generator,
) )
from aesara.scalar.basic import Add from aesara.scalar.basic import Add, float64
from aesara.tensor.elemwise import Elemwise from aesara.tensor.elemwise import Elemwise
from aesara.tensor.type import scalar, vector from aesara.tensor.type import scalar, vector
from aesara.tensor.type_other import NoneConst from aesara.tensor.type_other import NoneConst
...@@ -42,7 +42,7 @@ def test_fgraph_to_python_names(): ...@@ -42,7 +42,7 @@ def test_fgraph_to_python_names():
x = scalar("1x") x = scalar("1x")
y = scalar("_") y = scalar("_")
z = scalar() z = float64()
q = scalar("def") q = scalar("def")
r = NoneConst r = NoneConst
...@@ -50,9 +50,13 @@ def test_fgraph_to_python_names(): ...@@ -50,9 +50,13 @@ def test_fgraph_to_python_names():
out_jx = fgraph_to_python(out_fg, to_python) out_jx = fgraph_to_python(out_fg, to_python)
sig = inspect.signature(out_jx) sig = inspect.signature(out_jx)
assert (x.auto_name, "_", z.auto_name, q.auto_name, r.name) == tuple( assert (
sig.parameters.keys() "tensor_variable",
) "_",
"scalar_variable",
"tensor_variable_1",
r.name,
) == tuple(sig.parameters.keys())
assert (1, 2, 3, 4, 5) == out_jx(1, 2, 3, 4, 5) assert (1, 2, 3, 4, 5) == out_jx(1, 2, 3, 4, 5)
obj = object() obj = object()
...@@ -191,7 +195,7 @@ def test_unique_name_generator(): ...@@ -191,7 +195,7 @@ def test_unique_name_generator():
q_name_1 = unique_names(q) q_name_1 = unique_names(q)
q_name_2 = unique_names(q) q_name_2 = unique_names(q)
assert q_name_1 == q_name_2 == q.auto_name assert q_name_1 == q_name_2 == "tensor_variable"
unique_names = unique_name_generator() unique_names = unique_name_generator()
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论