提交 aad6fb75 authored 作者: ricardoV94's avatar ricardoV94 提交者: Ricardo Vieira

Deprecate `pytensor.get_underlying_scalar_constant`

上级 a120dc27
...@@ -24,6 +24,7 @@ __docformat__ = "restructuredtext en" ...@@ -24,6 +24,7 @@ __docformat__ = "restructuredtext en"
# pytensor code, since this code may want to log some messages. # pytensor code, since this code may want to log some messages.
import logging import logging
import sys import sys
import warnings
from functools import singledispatch from functools import singledispatch
from pathlib import Path from pathlib import Path
from typing import Any, NoReturn, Optional from typing import Any, NoReturn, Optional
...@@ -148,13 +149,13 @@ def get_underlying_scalar_constant(v): ...@@ -148,13 +149,13 @@ def get_underlying_scalar_constant(v):
If `v` is not some view of constant data, then raise a If `v` is not some view of constant data, then raise a
`NotScalarConstantError`. `NotScalarConstantError`.
""" """
# Is it necessary to test for presence of pytensor.sparse at runtime? warnings.warn(
sparse = globals().get("sparse") "get_underlying_scalar_constant is deprecated. Use tensor.get_underlying_scalar_constant_value instead.",
if sparse and isinstance(v.type, sparse.SparseTensorType): FutureWarning,
if v.owner is not None and isinstance(v.owner.op, sparse.CSM): )
data = v.owner.inputs[0] from pytensor.tensor.basic import get_underlying_scalar_constant_value
return tensor.get_underlying_scalar_constant_value(data)
return tensor.get_underlying_scalar_constant_value(v) return get_underlying_scalar_constant_value(v)
# isort: off # isort: off
......
...@@ -1329,7 +1329,7 @@ def _populate_grad_dict(var_to_app_to_idx, grad_dict, wrt, cost_name=None): ...@@ -1329,7 +1329,7 @@ def _populate_grad_dict(var_to_app_to_idx, grad_dict, wrt, cost_name=None):
f" {i}. Since this input is only connected " f" {i}. Since this input is only connected "
"to integer-valued outputs, it should " "to integer-valued outputs, it should "
"evaluate to zeros, but it evaluates to" "evaluate to zeros, but it evaluates to"
f"{pytensor.get_underlying_scalar_constant(term)}." f"{pytensor.get_underlying_scalar_constant_value(term)}."
) )
raise ValueError(msg) raise ValueError(msg)
...@@ -2157,6 +2157,9 @@ def _is_zero(x): ...@@ -2157,6 +2157,9 @@ def _is_zero(x):
'maybe' means that x is an expression that is complicated enough 'maybe' means that x is an expression that is complicated enough
that we can't tell that it simplifies to 0. that we can't tell that it simplifies to 0.
""" """
from pytensor.tensor import get_underlying_scalar_constant_value
from pytensor.tensor.exceptions import NotScalarConstantError
if not hasattr(x, "type"): if not hasattr(x, "type"):
return np.all(x == 0.0) return np.all(x == 0.0)
if isinstance(x.type, NullType): if isinstance(x.type, NullType):
...@@ -2166,9 +2169,9 @@ def _is_zero(x): ...@@ -2166,9 +2169,9 @@ def _is_zero(x):
no_constant_value = True no_constant_value = True
try: try:
constant_value = pytensor.get_underlying_scalar_constant(x) constant_value = get_underlying_scalar_constant_value(x)
no_constant_value = False no_constant_value = False
except pytensor.tensor.exceptions.NotScalarConstantError: except NotScalarConstantError:
pass pass
if no_constant_value: if no_constant_value:
......
...@@ -320,6 +320,8 @@ def get_underlying_scalar_constant_value( ...@@ -320,6 +320,8 @@ def get_underlying_scalar_constant_value(
""" """
from pytensor.compile.ops import DeepCopyOp, OutputGuard from pytensor.compile.ops import DeepCopyOp, OutputGuard
from pytensor.sparse import CSM
from pytensor.tensor.subtensor import Subtensor
v = orig_v v = orig_v
while True: while True:
...@@ -350,16 +352,16 @@ def get_underlying_scalar_constant_value( ...@@ -350,16 +352,16 @@ def get_underlying_scalar_constant_value(
raise NotScalarConstantError() raise NotScalarConstantError()
if not only_process_constants and getattr(v, "owner", None) and max_recur > 0: if not only_process_constants and getattr(v, "owner", None) and max_recur > 0:
op = v.owner.op
max_recur -= 1 max_recur -= 1
if isinstance( if isinstance(
v.owner.op, op, Alloc | DimShuffle | Unbroadcast | OutputGuard | DeepCopyOp
Alloc | DimShuffle | Unbroadcast | OutputGuard | DeepCopyOp,
): ):
# OutputGuard is only used in debugmode but we # OutputGuard is only used in debugmode but we
# keep it here to avoid problems with old pickles # keep it here to avoid problems with old pickles
v = v.owner.inputs[0] v = v.owner.inputs[0]
continue continue
elif isinstance(v.owner.op, Shape_i): elif isinstance(op, Shape_i):
i = v.owner.op.i i = v.owner.op.i
inp = v.owner.inputs[0] inp = v.owner.inputs[0]
if isinstance(inp, Constant): if isinstance(inp, Constant):
...@@ -373,10 +375,10 @@ def get_underlying_scalar_constant_value( ...@@ -373,10 +375,10 @@ def get_underlying_scalar_constant_value(
# mess with the stabilization optimization and be too slow. # mess with the stabilization optimization and be too slow.
# We put all the scalar Ops used by get_canonical_form_slice() # We put all the scalar Ops used by get_canonical_form_slice()
# to allow it to determine the broadcast pattern correctly. # to allow it to determine the broadcast pattern correctly.
elif isinstance(v.owner.op, ScalarFromTensor | TensorFromScalar): elif isinstance(op, ScalarFromTensor | TensorFromScalar):
v = v.owner.inputs[0] v = v.owner.inputs[0]
continue continue
elif isinstance(v.owner.op, CheckAndRaise): elif isinstance(op, CheckAndRaise):
# check if all conditions are constant and true # check if all conditions are constant and true
conds = [ conds = [
get_underlying_scalar_constant_value(c, max_recur=max_recur) get_underlying_scalar_constant_value(c, max_recur=max_recur)
...@@ -385,7 +387,7 @@ def get_underlying_scalar_constant_value( ...@@ -385,7 +387,7 @@ def get_underlying_scalar_constant_value(
if builtins.all(0 == c.ndim and c != 0 for c in conds): if builtins.all(0 == c.ndim and c != 0 for c in conds):
v = v.owner.inputs[0] v = v.owner.inputs[0]
continue continue
elif isinstance(v.owner.op, ps.ScalarOp): elif isinstance(op, ps.ScalarOp):
if isinstance(v.owner.op, ps.Second): if isinstance(v.owner.op, ps.Second):
# We don't need both input to be constant for second # We don't need both input to be constant for second
shp, val = v.owner.inputs shp, val = v.owner.inputs
...@@ -402,7 +404,7 @@ def get_underlying_scalar_constant_value( ...@@ -402,7 +404,7 @@ def get_underlying_scalar_constant_value(
# In fast_compile, we don't enable local_fill_to_alloc, so # In fast_compile, we don't enable local_fill_to_alloc, so
# we need to investigate Second as Alloc. So elemwise # we need to investigate Second as Alloc. So elemwise
# don't disable the check for Second. # don't disable the check for Second.
elif isinstance(v.owner.op, Elemwise): elif isinstance(op, Elemwise):
if isinstance(v.owner.op.scalar_op, ps.Second): if isinstance(v.owner.op.scalar_op, ps.Second):
# We don't need both input to be constant for second # We don't need both input to be constant for second
shp, val = v.owner.inputs shp, val = v.owner.inputs
...@@ -418,10 +420,7 @@ def get_underlying_scalar_constant_value( ...@@ -418,10 +420,7 @@ def get_underlying_scalar_constant_value(
ret = [[None]] ret = [[None]]
v.owner.op.perform(v.owner, const, ret) v.owner.op.perform(v.owner, const, ret)
return np.asarray(ret[0][0].copy()) return np.asarray(ret[0][0].copy())
elif ( elif isinstance(op, Subtensor) and v.ndim == 0:
isinstance(v.owner.op, pytensor.tensor.subtensor.Subtensor)
and v.ndim == 0
):
if isinstance(v.owner.inputs[0], TensorConstant): if isinstance(v.owner.inputs[0], TensorConstant):
from pytensor.tensor.subtensor import get_constant_idx from pytensor.tensor.subtensor import get_constant_idx
...@@ -545,6 +544,14 @@ def get_underlying_scalar_constant_value( ...@@ -545,6 +544,14 @@ def get_underlying_scalar_constant_value(
if isinstance(grandparent, Constant): if isinstance(grandparent, Constant):
return np.asarray(np.shape(grandparent.data)[idx]) return np.asarray(np.shape(grandparent.data)[idx])
elif isinstance(op, CSM):
data = get_underlying_scalar_constant_value(
v.owner.inputs, elemwise=elemwise, max_recur=max_recur
)
# Sparse variable can only be constant if zero (or I guess if homogeneously dense)
if data == 0:
return data
break
raise NotScalarConstantError() raise NotScalarConstantError()
...@@ -4071,7 +4078,7 @@ class Choose(Op): ...@@ -4071,7 +4078,7 @@ class Choose(Op):
static_out_shape = () static_out_shape = ()
for s in out_shape: for s in out_shape:
try: try:
s_val = pytensor.get_underlying_scalar_constant(s) s_val = get_underlying_scalar_constant_value(s)
except (NotScalarConstantError, AttributeError): except (NotScalarConstantError, AttributeError):
s_val = None s_val = None
......
...@@ -19,7 +19,7 @@ from pytensor.graph.replace import vectorize_node ...@@ -19,7 +19,7 @@ from pytensor.graph.replace import vectorize_node
from pytensor.link.basic import PerformLinker from pytensor.link.basic import PerformLinker
from pytensor.link.c.basic import CLinker, OpWiseCLinker from pytensor.link.c.basic import CLinker, OpWiseCLinker
from pytensor.tensor import as_tensor_variable from pytensor.tensor import as_tensor_variable
from pytensor.tensor.basic import second from pytensor.tensor.basic import get_scalar_constant_value, second
from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise
from pytensor.tensor.math import Any, Sum, exp from pytensor.tensor.math import Any, Sum, exp
from pytensor.tensor.math import all as pt_all from pytensor.tensor.math import all as pt_all
...@@ -807,8 +807,8 @@ class TestElemwise(unittest_tools.InferShapeTester): ...@@ -807,8 +807,8 @@ class TestElemwise(unittest_tools.InferShapeTester):
assert len(res_shape) == 1 assert len(res_shape) == 1
assert len(res_shape[0]) == 2 assert len(res_shape[0]) == 2
assert pytensor.get_underlying_scalar_constant(res_shape[0][0]) == 1 assert get_scalar_constant_value(res_shape[0][0]) == 1
assert pytensor.get_underlying_scalar_constant(res_shape[0][1]) == 1 assert get_scalar_constant_value(res_shape[0][1]) == 1
def test_infer_shape_multi_output(self): def test_infer_shape_multi_output(self):
class CustomElemwise(Elemwise): class CustomElemwise(Elemwise):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论