提交 351ce53e authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Remove `to_scalar` helper

上级 4de2a7e6
...@@ -3,8 +3,7 @@ from functools import singledispatch ...@@ -3,8 +3,7 @@ from functools import singledispatch
import numba import numba
import numpy as np import numpy as np
from numba import types from numba.core.errors import NumbaWarning
from numba.core.errors import NumbaWarning, TypingError
from numba.cpython.unsafe.tuple import tuple_setitem # noqa: F401 from numba.cpython.unsafe.tuple import tuple_setitem # noqa: F401
from pytensor import In, config from pytensor import In, config
...@@ -135,20 +134,6 @@ def create_numba_signature( ...@@ -135,20 +134,6 @@ def create_numba_signature(
return numba.types.void(*input_types) return numba.types.void(*input_types)
def to_scalar(x):
return np.asarray(x).item()
@numba.extending.overload(to_scalar)
def impl_to_scalar(x):
if isinstance(x, numba.types.Number | numba.types.Boolean):
return lambda x: x
elif isinstance(x, numba.types.Array):
return lambda x: x.item()
else:
raise TypingError(f"{x} must be a scalar compatible type.")
def create_tuple_creator(f, n): def create_tuple_creator(f, n):
"""Construct a compile-time ``tuple``-comprehension-like loop. """Construct a compile-time ``tuple``-comprehension-like loop.
......
...@@ -26,7 +26,7 @@ from pytensor.tensor.extra_ops import ( ...@@ -26,7 +26,7 @@ from pytensor.tensor.extra_ops import (
def numba_funcify_Bartlett(op, **kwargs): def numba_funcify_Bartlett(op, **kwargs):
@numba_basic.numba_njit(inline="always") @numba_basic.numba_njit(inline="always")
def bartlett(x): def bartlett(x):
return np.bartlett(numba_basic.to_scalar(x)) return np.bartlett(x.item())
return bartlett return bartlett
...@@ -112,12 +112,12 @@ def numba_funcify_FillDiagonalOffset(op, node, **kwargs): ...@@ -112,12 +112,12 @@ def numba_funcify_FillDiagonalOffset(op, node, **kwargs):
@numba_basic.numba_njit @numba_basic.numba_njit
def filldiagonaloffset(a, val, offset): def filldiagonaloffset(a, val, offset):
height, width = a.shape height, width = a.shape
offset_item = offset.item()
if offset >= 0: if offset >= 0:
start = numba_basic.to_scalar(offset) start = offset_item
num_of_step = min(min(width, height), width - offset) num_of_step = min(min(width, height), width - offset)
else: else:
start = -numba_basic.to_scalar(offset) * a.shape[1] start = -offset_item * a.shape[1]
num_of_step = min(min(width, height), height + offset) num_of_step = min(min(width, height), height + offset)
step = a.shape[1] + 1 step = a.shape[1] + 1
......
...@@ -210,14 +210,10 @@ def numba_funcify_type_casting(op, **kwargs): ...@@ -210,14 +210,10 @@ def numba_funcify_type_casting(op, **kwargs):
def numba_funcify_Clip(op, **kwargs): def numba_funcify_Clip(op, **kwargs):
@numba_basic.numba_njit @numba_basic.numba_njit
def clip(x, min_val, max_val): def clip(x, min_val, max_val):
x = numba_basic.to_scalar(x) if x < min_val:
min_scalar = numba_basic.to_scalar(min_val) return min_val
max_scalar = numba_basic.to_scalar(max_val) elif x > max_val:
return max_val
if x < min_scalar:
return min_scalar
elif x > max_scalar:
return max_scalar
else: else:
return x return x
......
...@@ -365,7 +365,7 @@ def numba_funcify_Scan(op: Scan, node, **kwargs): ...@@ -365,7 +365,7 @@ def numba_funcify_Scan(op: Scan, node, **kwargs):
storage_alloc_stmts.append( storage_alloc_stmts.append(
dedent( dedent(
f""" f"""
{storage_size_name} = to_numba_scalar({outer_in_name}) {storage_size_name} = ({outer_in_name}).item()
{storage_name} = np.empty({storage_shape}, dtype=np.{storage_dtype}) {storage_name} = np.empty({storage_shape}, dtype=np.{storage_dtype})
""" """
).strip() ).strip()
...@@ -435,10 +435,9 @@ def scan({", ".join(outer_in_names)}): ...@@ -435,10 +435,9 @@ def scan({", ".join(outer_in_names)}):
""" """
global_env = { global_env = {
"np": np,
"scan_inner_func": scan_inner_func, "scan_inner_func": scan_inner_func,
"to_numba_scalar": numba_basic.to_scalar,
} }
global_env["np"] = np
scan_op_fn = compile_function_src(scan_op_src, "scan", {**globals(), **global_env}) scan_op_fn = compile_function_src(scan_op_src, "scan", {**globals(), **global_env})
......
...@@ -28,18 +28,17 @@ from pytensor.tensor.basic import ( ...@@ -28,18 +28,17 @@ from pytensor.tensor.basic import (
def numba_funcify_AllocEmpty(op, node, **kwargs): def numba_funcify_AllocEmpty(op, node, **kwargs):
global_env = { global_env = {
"np": np, "np": np,
"to_scalar": numba_basic.to_scalar,
"dtype": np.dtype(op.dtype), "dtype": np.dtype(op.dtype),
} }
unique_names = unique_name_generator( unique_names = unique_name_generator(
["np", "to_scalar", "dtype", "allocempty", "scalar_shape"], suffix_sep="_" ["np", "dtype", "allocempty", "scalar_shape"], suffix_sep="_"
) )
shape_var_names = [unique_names(v, force_unique=True) for v in node.inputs] shape_var_names = [unique_names(v, force_unique=True) for v in node.inputs]
shape_var_item_names = [f"{name}_item" for name in shape_var_names] shape_var_item_names = [f"{name}_item" for name in shape_var_names]
shapes_to_items_src = indent( shapes_to_items_src = indent(
"\n".join( "\n".join(
f"{item_name} = to_scalar({shape_name})" f"{item_name} = {shape_name}.item()"
for item_name, shape_name in zip( for item_name, shape_name in zip(
shape_var_item_names, shape_var_names, strict=True shape_var_item_names, shape_var_names, strict=True
) )
...@@ -63,10 +62,10 @@ def allocempty({", ".join(shape_var_names)}): ...@@ -63,10 +62,10 @@ def allocempty({", ".join(shape_var_names)}):
@numba_funcify.register(Alloc) @numba_funcify.register(Alloc)
def numba_funcify_Alloc(op, node, **kwargs): def numba_funcify_Alloc(op, node, **kwargs):
global_env = {"np": np, "to_scalar": numba_basic.to_scalar} global_env = {"np": np}
unique_names = unique_name_generator( unique_names = unique_name_generator(
["np", "to_scalar", "alloc", "val_np", "val", "scalar_shape", "res"], ["np", "alloc", "val_np", "val", "scalar_shape", "res"],
suffix_sep="_", suffix_sep="_",
) )
shape_var_names = [unique_names(v, force_unique=True) for v in node.inputs[1:]] shape_var_names = [unique_names(v, force_unique=True) for v in node.inputs[1:]]
...@@ -110,9 +109,9 @@ def numba_funcify_ARange(op, **kwargs): ...@@ -110,9 +109,9 @@ def numba_funcify_ARange(op, **kwargs):
@numba_basic.numba_njit(inline="always") @numba_basic.numba_njit(inline="always")
def arange(start, stop, step): def arange(start, stop, step):
return np.arange( return np.arange(
numba_basic.to_scalar(start), start.item(),
numba_basic.to_scalar(stop), stop.item(),
numba_basic.to_scalar(step), step.item(),
dtype=dtype, dtype=dtype,
) )
...@@ -187,9 +186,9 @@ def numba_funcify_Eye(op, **kwargs): ...@@ -187,9 +186,9 @@ def numba_funcify_Eye(op, **kwargs):
@numba_basic.numba_njit(inline="always") @numba_basic.numba_njit(inline="always")
def eye(N, M, k): def eye(N, M, k):
return np.eye( return np.eye(
numba_basic.to_scalar(N), N.item(),
numba_basic.to_scalar(M), M.item(),
numba_basic.to_scalar(k), k.item(),
dtype=dtype, dtype=dtype,
) )
...@@ -200,16 +199,16 @@ def numba_funcify_Eye(op, **kwargs): ...@@ -200,16 +199,16 @@ def numba_funcify_Eye(op, **kwargs):
def numba_funcify_MakeVector(op, node, **kwargs): def numba_funcify_MakeVector(op, node, **kwargs):
dtype = np.dtype(op.dtype) dtype = np.dtype(op.dtype)
global_env = {"np": np, "to_scalar": numba_basic.to_scalar, "dtype": dtype} global_env = {"np": np, "dtype": dtype}
unique_names = unique_name_generator( unique_names = unique_name_generator(
["np", "to_scalar"], ["np"],
suffix_sep="_", suffix_sep="_",
) )
input_names = [unique_names(v, force_unique=True) for v in node.inputs] input_names = [unique_names(v, force_unique=True) for v in node.inputs]
def create_list_string(x): def create_list_string(x):
args = ", ".join([f"to_scalar({i})" for i in x] + ([""] if len(x) == 1 else [])) args = ", ".join([f"{i}.item()" for i in x] + ([""] if len(x) == 1 else []))
return f"[{args}]" return f"[{args}]"
makevector_def_src = f""" makevector_def_src = f"""
...@@ -237,7 +236,7 @@ def numba_funcify_TensorFromScalar(op, **kwargs): ...@@ -237,7 +236,7 @@ def numba_funcify_TensorFromScalar(op, **kwargs):
def numba_funcify_ScalarFromTensor(op, **kwargs): def numba_funcify_ScalarFromTensor(op, **kwargs):
@numba_basic.numba_njit(inline="always") @numba_basic.numba_njit(inline="always")
def scalar_from_tensor(x): def scalar_from_tensor(x):
return numba_basic.to_scalar(x) return x.item()
return scalar_from_tensor return scalar_from_tensor
......
...@@ -134,12 +134,6 @@ def eval_python_only(fn_inputs, fn_outputs, inputs, mode=numba_mode): ...@@ -134,12 +134,6 @@ def eval_python_only(fn_inputs, fn_outputs, inputs, mode=numba_mode):
ll[i] = v ll[i] = v
return tuple(ll) return tuple(ll)
def py_to_scalar(x):
if isinstance(x, np.ndarray):
return x.item()
else:
return x
def njit_noop(*args, **kwargs): def njit_noop(*args, **kwargs):
if len(args) == 1 and callable(args[0]): if len(args) == 1 and callable(args[0]):
return args[0] return args[0]
...@@ -155,7 +149,6 @@ def eval_python_only(fn_inputs, fn_outputs, inputs, mode=numba_mode): ...@@ -155,7 +149,6 @@ def eval_python_only(fn_inputs, fn_outputs, inputs, mode=numba_mode):
mock.patch( mock.patch(
"pytensor.link.numba.dispatch.basic.direct_cast", lambda x, dtype: x "pytensor.link.numba.dispatch.basic.direct_cast", lambda x, dtype: x
), ),
mock.patch("pytensor.link.numba.dispatch.basic.to_scalar", py_to_scalar),
mock.patch( mock.patch(
"pytensor.link.numba.dispatch.basic.numba.np.numpy_support.from_dtype", "pytensor.link.numba.dispatch.basic.numba.np.numpy_support.from_dtype",
lambda dtype: dtype, lambda dtype: dtype,
......
Markdown 格式
0%
您将 0 人添加到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论