提交 351ce53e authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Remove `to_scalar` helper

上级 4de2a7e6
......@@ -3,8 +3,7 @@ from functools import singledispatch
import numba
import numpy as np
from numba import types
from numba.core.errors import NumbaWarning, TypingError
from numba.core.errors import NumbaWarning
from numba.cpython.unsafe.tuple import tuple_setitem # noqa: F401
from pytensor import In, config
......@@ -135,20 +134,6 @@ def create_numba_signature(
return numba.types.void(*input_types)
def to_scalar(x):
return np.asarray(x).item()
@numba.extending.overload(to_scalar)
def impl_to_scalar(x):
if isinstance(x, numba.types.Number | numba.types.Boolean):
return lambda x: x
elif isinstance(x, numba.types.Array):
return lambda x: x.item()
else:
raise TypingError(f"{x} must be a scalar compatible type.")
def create_tuple_creator(f, n):
"""Construct a compile-time ``tuple``-comprehension-like loop.
......
......@@ -26,7 +26,7 @@ from pytensor.tensor.extra_ops import (
def numba_funcify_Bartlett(op, **kwargs):
@numba_basic.numba_njit(inline="always")
def bartlett(x):
return np.bartlett(numba_basic.to_scalar(x))
return np.bartlett(x.item())
return bartlett
......@@ -112,12 +112,12 @@ def numba_funcify_FillDiagonalOffset(op, node, **kwargs):
@numba_basic.numba_njit
def filldiagonaloffset(a, val, offset):
height, width = a.shape
offset_item = offset.item()
if offset >= 0:
start = numba_basic.to_scalar(offset)
start = offset_item
num_of_step = min(min(width, height), width - offset)
else:
start = -numba_basic.to_scalar(offset) * a.shape[1]
start = -offset_item * a.shape[1]
num_of_step = min(min(width, height), height + offset)
step = a.shape[1] + 1
......
......@@ -210,14 +210,10 @@ def numba_funcify_type_casting(op, **kwargs):
def numba_funcify_Clip(op, **kwargs):
@numba_basic.numba_njit
def clip(x, min_val, max_val):
x = numba_basic.to_scalar(x)
min_scalar = numba_basic.to_scalar(min_val)
max_scalar = numba_basic.to_scalar(max_val)
if x < min_scalar:
return min_scalar
elif x > max_scalar:
return max_scalar
if x < min_val:
return min_val
elif x > max_val:
return max_val
else:
return x
......
......@@ -365,7 +365,7 @@ def numba_funcify_Scan(op: Scan, node, **kwargs):
storage_alloc_stmts.append(
dedent(
f"""
{storage_size_name} = to_numba_scalar({outer_in_name})
{storage_size_name} = ({outer_in_name}).item()
{storage_name} = np.empty({storage_shape}, dtype=np.{storage_dtype})
"""
).strip()
......@@ -435,10 +435,9 @@ def scan({", ".join(outer_in_names)}):
"""
global_env = {
"np": np,
"scan_inner_func": scan_inner_func,
"to_numba_scalar": numba_basic.to_scalar,
}
global_env["np"] = np
scan_op_fn = compile_function_src(scan_op_src, "scan", {**globals(), **global_env})
......
......@@ -28,18 +28,17 @@ from pytensor.tensor.basic import (
def numba_funcify_AllocEmpty(op, node, **kwargs):
global_env = {
"np": np,
"to_scalar": numba_basic.to_scalar,
"dtype": np.dtype(op.dtype),
}
unique_names = unique_name_generator(
["np", "to_scalar", "dtype", "allocempty", "scalar_shape"], suffix_sep="_"
["np", "dtype", "allocempty", "scalar_shape"], suffix_sep="_"
)
shape_var_names = [unique_names(v, force_unique=True) for v in node.inputs]
shape_var_item_names = [f"{name}_item" for name in shape_var_names]
shapes_to_items_src = indent(
"\n".join(
f"{item_name} = to_scalar({shape_name})"
f"{item_name} = {shape_name}.item()"
for item_name, shape_name in zip(
shape_var_item_names, shape_var_names, strict=True
)
......@@ -63,10 +62,10 @@ def allocempty({", ".join(shape_var_names)}):
@numba_funcify.register(Alloc)
def numba_funcify_Alloc(op, node, **kwargs):
global_env = {"np": np, "to_scalar": numba_basic.to_scalar}
global_env = {"np": np}
unique_names = unique_name_generator(
["np", "to_scalar", "alloc", "val_np", "val", "scalar_shape", "res"],
["np", "alloc", "val_np", "val", "scalar_shape", "res"],
suffix_sep="_",
)
shape_var_names = [unique_names(v, force_unique=True) for v in node.inputs[1:]]
......@@ -110,9 +109,9 @@ def numba_funcify_ARange(op, **kwargs):
@numba_basic.numba_njit(inline="always")
def arange(start, stop, step):
return np.arange(
numba_basic.to_scalar(start),
numba_basic.to_scalar(stop),
numba_basic.to_scalar(step),
start.item(),
stop.item(),
step.item(),
dtype=dtype,
)
......@@ -187,9 +186,9 @@ def numba_funcify_Eye(op, **kwargs):
@numba_basic.numba_njit(inline="always")
def eye(N, M, k):
return np.eye(
numba_basic.to_scalar(N),
numba_basic.to_scalar(M),
numba_basic.to_scalar(k),
N.item(),
M.item(),
k.item(),
dtype=dtype,
)
......@@ -200,16 +199,16 @@ def numba_funcify_Eye(op, **kwargs):
def numba_funcify_MakeVector(op, node, **kwargs):
dtype = np.dtype(op.dtype)
global_env = {"np": np, "to_scalar": numba_basic.to_scalar, "dtype": dtype}
global_env = {"np": np, "dtype": dtype}
unique_names = unique_name_generator(
["np", "to_scalar"],
["np"],
suffix_sep="_",
)
input_names = [unique_names(v, force_unique=True) for v in node.inputs]
def create_list_string(x):
args = ", ".join([f"to_scalar({i})" for i in x] + ([""] if len(x) == 1 else []))
args = ", ".join([f"{i}.item()" for i in x] + ([""] if len(x) == 1 else []))
return f"[{args}]"
makevector_def_src = f"""
......@@ -237,7 +236,7 @@ def numba_funcify_TensorFromScalar(op, **kwargs):
def numba_funcify_ScalarFromTensor(op, **kwargs):
@numba_basic.numba_njit(inline="always")
def scalar_from_tensor(x):
return numba_basic.to_scalar(x)
return x.item()
return scalar_from_tensor
......
......@@ -134,12 +134,6 @@ def eval_python_only(fn_inputs, fn_outputs, inputs, mode=numba_mode):
ll[i] = v
return tuple(ll)
def py_to_scalar(x):
if isinstance(x, np.ndarray):
return x.item()
else:
return x
def njit_noop(*args, **kwargs):
if len(args) == 1 and callable(args[0]):
return args[0]
......@@ -155,7 +149,6 @@ def eval_python_only(fn_inputs, fn_outputs, inputs, mode=numba_mode):
mock.patch(
"pytensor.link.numba.dispatch.basic.direct_cast", lambda x, dtype: x
),
mock.patch("pytensor.link.numba.dispatch.basic.to_scalar", py_to_scalar),
mock.patch(
"pytensor.link.numba.dispatch.basic.numba.np.numpy_support.from_dtype",
lambda dtype: dtype,
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论