提交 aad6fb75 作者: ricardoV94 提交者: Ricardo Vieira

Deprecate `pytensor.get_underlying_scalar_constant`

上级 a120dc27
......@@ -24,6 +24,7 @@ __docformat__ = "restructuredtext en"
# pytensor code, since this code may want to log some messages.
import logging
import sys
import warnings
from functools import singledispatch
from pathlib import Path
from typing import Any, NoReturn, Optional
......@@ -148,13 +149,13 @@ def get_underlying_scalar_constant(v):
If `v` is not some view of constant data, then raise a
`NotScalarConstantError`.
"""
# Is it necessary to test for presence of pytensor.sparse at runtime?
sparse = globals().get("sparse")
if sparse and isinstance(v.type, sparse.SparseTensorType):
if v.owner is not None and isinstance(v.owner.op, sparse.CSM):
data = v.owner.inputs[0]
return tensor.get_underlying_scalar_constant_value(data)
return tensor.get_underlying_scalar_constant_value(v)
warnings.warn(
"get_underlying_scalar_constant is deprecated. Use tensor.get_underlying_scalar_constant_value instead.",
FutureWarning,
)
from pytensor.tensor.basic import get_underlying_scalar_constant_value
return get_underlying_scalar_constant_value(v)
# isort: off
......
......@@ -1329,7 +1329,7 @@ def _populate_grad_dict(var_to_app_to_idx, grad_dict, wrt, cost_name=None):
f" {i}. Since this input is only connected "
"to integer-valued outputs, it should "
"evaluate to zeros, but it evaluates to"
f"{pytensor.get_underlying_scalar_constant(term)}."
f"{pytensor.get_underlying_scalar_constant_value(term)}."
)
raise ValueError(msg)
......@@ -2157,6 +2157,9 @@ def _is_zero(x):
'maybe' means that x is an expression that is complicated enough
that we can't tell that it simplifies to 0.
"""
from pytensor.tensor import get_underlying_scalar_constant_value
from pytensor.tensor.exceptions import NotScalarConstantError
if not hasattr(x, "type"):
return np.all(x == 0.0)
if isinstance(x.type, NullType):
......@@ -2166,9 +2169,9 @@ def _is_zero(x):
no_constant_value = True
try:
constant_value = pytensor.get_underlying_scalar_constant(x)
constant_value = get_underlying_scalar_constant_value(x)
no_constant_value = False
except pytensor.tensor.exceptions.NotScalarConstantError:
except NotScalarConstantError:
pass
if no_constant_value:
......
......@@ -320,6 +320,8 @@ def get_underlying_scalar_constant_value(
"""
from pytensor.compile.ops import DeepCopyOp, OutputGuard
from pytensor.sparse import CSM
from pytensor.tensor.subtensor import Subtensor
v = orig_v
while True:
......@@ -350,16 +352,16 @@ def get_underlying_scalar_constant_value(
raise NotScalarConstantError()
if not only_process_constants and getattr(v, "owner", None) and max_recur > 0:
op = v.owner.op
max_recur -= 1
if isinstance(
v.owner.op,
Alloc | DimShuffle | Unbroadcast | OutputGuard | DeepCopyOp,
op, Alloc | DimShuffle | Unbroadcast | OutputGuard | DeepCopyOp
):
# OutputGuard is only used in debugmode but we
# keep it here to avoid problems with old pickles
v = v.owner.inputs[0]
continue
elif isinstance(v.owner.op, Shape_i):
elif isinstance(op, Shape_i):
i = v.owner.op.i
inp = v.owner.inputs[0]
if isinstance(inp, Constant):
......@@ -373,10 +375,10 @@ def get_underlying_scalar_constant_value(
# mess with the stabilization optimization and be too slow.
# We put all the scalar Ops used by get_canonical_form_slice()
# to allow it to determine the broadcast pattern correctly.
elif isinstance(v.owner.op, ScalarFromTensor | TensorFromScalar):
elif isinstance(op, ScalarFromTensor | TensorFromScalar):
v = v.owner.inputs[0]
continue
elif isinstance(v.owner.op, CheckAndRaise):
elif isinstance(op, CheckAndRaise):
# check if all conditions are constant and true
conds = [
get_underlying_scalar_constant_value(c, max_recur=max_recur)
......@@ -385,7 +387,7 @@ def get_underlying_scalar_constant_value(
if builtins.all(0 == c.ndim and c != 0 for c in conds):
v = v.owner.inputs[0]
continue
elif isinstance(v.owner.op, ps.ScalarOp):
elif isinstance(op, ps.ScalarOp):
if isinstance(v.owner.op, ps.Second):
# We don't need both input to be constant for second
shp, val = v.owner.inputs
......@@ -402,7 +404,7 @@ def get_underlying_scalar_constant_value(
# In fast_compile, we don't enable local_fill_to_alloc, so
# we need to investigate Second as Alloc. So elemwise
# don't disable the check for Second.
elif isinstance(v.owner.op, Elemwise):
elif isinstance(op, Elemwise):
if isinstance(v.owner.op.scalar_op, ps.Second):
# We don't need both input to be constant for second
shp, val = v.owner.inputs
......@@ -418,10 +420,7 @@ def get_underlying_scalar_constant_value(
ret = [[None]]
v.owner.op.perform(v.owner, const, ret)
return np.asarray(ret[0][0].copy())
elif (
isinstance(v.owner.op, pytensor.tensor.subtensor.Subtensor)
and v.ndim == 0
):
elif isinstance(op, Subtensor) and v.ndim == 0:
if isinstance(v.owner.inputs[0], TensorConstant):
from pytensor.tensor.subtensor import get_constant_idx
......@@ -545,6 +544,14 @@ def get_underlying_scalar_constant_value(
if isinstance(grandparent, Constant):
return np.asarray(np.shape(grandparent.data)[idx])
elif isinstance(op, CSM):
data = get_underlying_scalar_constant_value(
v.owner.inputs, elemwise=elemwise, max_recur=max_recur
)
# Sparse variable can only be constant if zero (or I guess if homogeneously dense)
if data == 0:
return data
break
raise NotScalarConstantError()
......@@ -4071,7 +4078,7 @@ class Choose(Op):
static_out_shape = ()
for s in out_shape:
try:
s_val = pytensor.get_underlying_scalar_constant(s)
s_val = get_underlying_scalar_constant_value(s)
except (NotScalarConstantError, AttributeError):
s_val = None
......
......@@ -19,7 +19,7 @@ from pytensor.graph.replace import vectorize_node
from pytensor.link.basic import PerformLinker
from pytensor.link.c.basic import CLinker, OpWiseCLinker
from pytensor.tensor import as_tensor_variable
from pytensor.tensor.basic import second
from pytensor.tensor.basic import get_scalar_constant_value, second
from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise
from pytensor.tensor.math import Any, Sum, exp
from pytensor.tensor.math import all as pt_all
......@@ -807,8 +807,8 @@ class TestElemwise(unittest_tools.InferShapeTester):
assert len(res_shape) == 1
assert len(res_shape[0]) == 2
assert pytensor.get_underlying_scalar_constant(res_shape[0][0]) == 1
assert pytensor.get_underlying_scalar_constant(res_shape[0][1]) == 1
assert get_scalar_constant_value(res_shape[0][0]) == 1
assert get_scalar_constant_value(res_shape[0][1]) == 1
def test_infer_shape_multi_output(self):
class CustomElemwise(Elemwise):
......
Markdown 格式
0%
您将添加 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论