提交 c48a8b3a authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Move `create_tuple_string` to `string_codegen` module

Removes depracated `create_tuple_creator`
上级 cc39ea30
......@@ -171,40 +171,6 @@ def create_numba_signature(
return numba.types.void(*input_types)
def create_tuple_creator(f, n):
"""Construct a compile-time ``tuple``-comprehension-like loop.
See https://github.com/numba/numba/issues/2771#issuecomment-414358902
"""
warnings.warn(
"create_tuple_creator is deprecated and will be removed in a future release",
FutureWarning,
)
assert n > 0
f = numba_njit(f)
@numba_njit
def creator(args):
return (f(0, *args),)
for i in range(1, n):
@numba_njit
def creator(args, creator=creator, i=i):
return (*creator(args), f(i, *args))
return numba_njit(lambda *args: creator(args))
def create_tuple_string(x):
if len(x) == 1:
return f"({x[0]},)"
else:
return f"({', '.join(x)})"
@numba.extending.intrinsic
def direct_cast(typingctx, val, typ):
if isinstance(typ, numba.types.TypeRef):
......
......@@ -14,11 +14,11 @@ from pytensor.link.numba.cache import (
)
from pytensor.link.numba.dispatch import basic as numba_basic
from pytensor.link.numba.dispatch.basic import (
create_tuple_string,
numba_funcify_and_cache_key,
register_funcify_and_cache_key,
register_funcify_default_op_cache_key,
)
from pytensor.link.numba.dispatch.string_codegen import create_tuple_string
from pytensor.link.numba.dispatch.vectorize_codegen import (
_vectorized,
encode_literals,
......
......@@ -11,10 +11,10 @@ from pytensor.compile.mode import NUMBA, get_mode
from pytensor.link.numba.cache import compile_numba_function_src
from pytensor.link.numba.dispatch import basic as numba_basic
from pytensor.link.numba.dispatch.basic import (
create_tuple_string,
numba_funcify_and_cache_key,
register_funcify_and_cache_key,
)
from pytensor.link.numba.dispatch.string_codegen import create_tuple_string
from pytensor.scan.op import Scan
from pytensor.tensor.type import TensorType
......
def create_tuple_string(x):
if len(x) == 1:
return f"({x[0]},)"
else:
return f"({', '.join(x)})"
......@@ -15,12 +15,12 @@ from pytensor.link.numba.cache import (
compile_numba_function_src,
)
from pytensor.link.numba.dispatch.basic import (
create_tuple_string,
generate_fallback_impl,
register_funcify_and_cache_key,
register_funcify_default_op_cache_key,
)
from pytensor.link.numba.dispatch.compile_ops import numba_deepcopy
from pytensor.link.numba.dispatch.string_codegen import create_tuple_string
from pytensor.tensor import TensorType
from pytensor.tensor.subtensor import (
AdvancedIncSubtensor,
......
......@@ -6,10 +6,10 @@ import numpy as np
from pytensor.link.numba.cache import compile_numba_function_src
from pytensor.link.numba.dispatch import basic as numba_basic
from pytensor.link.numba.dispatch.basic import (
create_tuple_string,
register_funcify_and_cache_key,
register_funcify_default_op_cache_key,
)
from pytensor.link.numba.dispatch.string_codegen import create_tuple_string
from pytensor.tensor.basic import (
Alloc,
AllocEmpty,
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论