Unverified 提交 fda240fd authored 作者: Shreyas Singh's avatar Shreyas Singh 提交者: GitHub

`get_scalar_constant_value` now raises for non-scalar inputs (#248)

* Rename old get_scalar_constant_value to get_underlying_scalar_constant
上级 feccc417
...@@ -577,7 +577,7 @@ them perfectly, but a `dscalar` otherwise. ...@@ -577,7 +577,7 @@ them perfectly, but a `dscalar` otherwise.
.. method:: round(mode="half_away_from_zero") .. method:: round(mode="half_away_from_zero")
:noindex: :noindex:
.. method:: trace() .. method:: trace()
.. method:: get_scalar_constant_value() .. method:: get_underlying_scalar_constant_value()
.. method:: zeros_like(model, dtype=None) .. method:: zeros_like(model, dtype=None)
All the above methods are equivalent to NumPy for PyTensor on the current tensor. All the above methods are equivalent to NumPy for PyTensor on the current tensor.
......
...@@ -137,7 +137,7 @@ from pytensor.updates import OrderedUpdates ...@@ -137,7 +137,7 @@ from pytensor.updates import OrderedUpdates
# isort: on # isort: on
def get_scalar_constant_value(v): def get_underlying_scalar_constant(v):
"""Return the constant scalar (i.e. 0-D) value underlying variable `v`. """Return the constant scalar (i.e. 0-D) value underlying variable `v`.
If `v` is the output of dim-shuffles, fills, allocs, cast, etc. If `v` is the output of dim-shuffles, fills, allocs, cast, etc.
...@@ -153,8 +153,8 @@ def get_scalar_constant_value(v): ...@@ -153,8 +153,8 @@ def get_scalar_constant_value(v):
if sparse and isinstance(v.type, sparse.SparseTensorType): if sparse and isinstance(v.type, sparse.SparseTensorType):
if v.owner is not None and isinstance(v.owner.op, sparse.CSM): if v.owner is not None and isinstance(v.owner.op, sparse.CSM):
data = v.owner.inputs[0] data = v.owner.inputs[0]
return tensor.get_scalar_constant_value(data) return tensor.get_underlying_scalar_constant_value(data)
return tensor.get_scalar_constant_value(v) return tensor.get_underlying_scalar_constant_value(v)
# isort: off # isort: off
......
...@@ -1325,7 +1325,7 @@ def _populate_grad_dict(var_to_app_to_idx, grad_dict, wrt, cost_name=None): ...@@ -1325,7 +1325,7 @@ def _populate_grad_dict(var_to_app_to_idx, grad_dict, wrt, cost_name=None):
f" {i}. Since this input is only connected " f" {i}. Since this input is only connected "
"to integer-valued outputs, it should " "to integer-valued outputs, it should "
"evaluate to zeros, but it evaluates to" "evaluate to zeros, but it evaluates to"
f"{pytensor.get_scalar_constant_value(term)}." f"{pytensor.get_underlying_scalar_constant(term)}."
) )
raise ValueError(msg) raise ValueError(msg)
...@@ -2086,7 +2086,7 @@ def _is_zero(x): ...@@ -2086,7 +2086,7 @@ def _is_zero(x):
no_constant_value = True no_constant_value = True
try: try:
constant_value = pytensor.get_scalar_constant_value(x) constant_value = pytensor.get_underlying_scalar_constant(x)
no_constant_value = False no_constant_value = False
except pytensor.tensor.exceptions.NotScalarConstantError: except pytensor.tensor.exceptions.NotScalarConstantError:
pass pass
......
...@@ -18,7 +18,7 @@ from pytensor.tensor.basic import ( ...@@ -18,7 +18,7 @@ from pytensor.tensor.basic import (
ScalarFromTensor, ScalarFromTensor,
Split, Split,
TensorFromScalar, TensorFromScalar,
get_scalar_constant_value, get_underlying_scalar_constant_value,
) )
from pytensor.tensor.exceptions import NotScalarConstantError from pytensor.tensor.exceptions import NotScalarConstantError
...@@ -106,7 +106,7 @@ def jax_funcify_Join(op, **kwargs): ...@@ -106,7 +106,7 @@ def jax_funcify_Join(op, **kwargs):
def jax_funcify_Split(op: Split, node, **kwargs): def jax_funcify_Split(op: Split, node, **kwargs):
_, axis, splits = node.inputs _, axis, splits = node.inputs
try: try:
constant_axis = get_scalar_constant_value(axis) constant_axis = get_underlying_scalar_constant_value(axis)
except NotScalarConstantError: except NotScalarConstantError:
constant_axis = None constant_axis = None
warnings.warn( warnings.warn(
...@@ -116,7 +116,7 @@ def jax_funcify_Split(op: Split, node, **kwargs): ...@@ -116,7 +116,7 @@ def jax_funcify_Split(op: Split, node, **kwargs):
try: try:
constant_splits = np.array( constant_splits = np.array(
[ [
get_scalar_constant_value(splits[i]) get_underlying_scalar_constant_value(splits[i])
for i in range(get_vector_length(splits)) for i in range(get_vector_length(splits))
] ]
) )
......
...@@ -12,7 +12,7 @@ from pytensor.graph.replace import clone_replace ...@@ -12,7 +12,7 @@ from pytensor.graph.replace import clone_replace
from pytensor.graph.utils import MissingInputError, TestValueError from pytensor.graph.utils import MissingInputError, TestValueError
from pytensor.scan.op import Scan, ScanInfo from pytensor.scan.op import Scan, ScanInfo
from pytensor.scan.utils import expand_empty, safe_new, until from pytensor.scan.utils import expand_empty, safe_new, until
from pytensor.tensor.basic import get_scalar_constant_value from pytensor.tensor.basic import get_underlying_scalar_constant_value
from pytensor.tensor.exceptions import NotScalarConstantError from pytensor.tensor.exceptions import NotScalarConstantError
from pytensor.tensor.math import minimum from pytensor.tensor.math import minimum
from pytensor.tensor.shape import shape_padleft, unbroadcast from pytensor.tensor.shape import shape_padleft, unbroadcast
...@@ -147,7 +147,7 @@ def isNaN_or_Inf_or_None(x): ...@@ -147,7 +147,7 @@ def isNaN_or_Inf_or_None(x):
isStr = False isStr = False
if not isNaN and not isInf: if not isNaN and not isInf:
try: try:
val = get_scalar_constant_value(x) val = get_underlying_scalar_constant_value(x)
isInf = np.isinf(val) isInf = np.isinf(val)
isNaN = np.isnan(val) isNaN = np.isnan(val)
except Exception: except Exception:
...@@ -476,7 +476,7 @@ def scan( ...@@ -476,7 +476,7 @@ def scan(
n_fixed_steps = int(n_steps) n_fixed_steps = int(n_steps)
else: else:
try: try:
n_fixed_steps = at.get_scalar_constant_value(n_steps) n_fixed_steps = at.get_underlying_scalar_constant_value(n_steps)
except NotScalarConstantError: except NotScalarConstantError:
n_fixed_steps = None n_fixed_steps = None
......
...@@ -49,7 +49,11 @@ from pytensor.scan.utils import ( ...@@ -49,7 +49,11 @@ from pytensor.scan.utils import (
safe_new, safe_new,
scan_can_remove_outs, scan_can_remove_outs,
) )
from pytensor.tensor.basic import Alloc, AllocEmpty, get_scalar_constant_value from pytensor.tensor.basic import (
Alloc,
AllocEmpty,
get_underlying_scalar_constant_value,
)
from pytensor.tensor.elemwise import DimShuffle, Elemwise from pytensor.tensor.elemwise import DimShuffle, Elemwise
from pytensor.tensor.exceptions import NotScalarConstantError from pytensor.tensor.exceptions import NotScalarConstantError
from pytensor.tensor.math import Dot, dot, maximum, minimum from pytensor.tensor.math import Dot, dot, maximum, minimum
...@@ -1956,13 +1960,13 @@ class ScanMerge(GraphRewriter): ...@@ -1956,13 +1960,13 @@ class ScanMerge(GraphRewriter):
nsteps = node.inputs[0] nsteps = node.inputs[0]
try: try:
nsteps = int(get_scalar_constant_value(nsteps)) nsteps = int(get_underlying_scalar_constant_value(nsteps))
except NotScalarConstantError: except NotScalarConstantError:
pass pass
rep_nsteps = rep.inputs[0] rep_nsteps = rep.inputs[0]
try: try:
rep_nsteps = int(get_scalar_constant_value(rep_nsteps)) rep_nsteps = int(get_underlying_scalar_constant_value(rep_nsteps))
except NotScalarConstantError: except NotScalarConstantError:
pass pass
......
...@@ -256,6 +256,26 @@ _scalar_constant_value_elemwise_ops = ( ...@@ -256,6 +256,26 @@ _scalar_constant_value_elemwise_ops = (
def get_scalar_constant_value( def get_scalar_constant_value(
v, elemwise=True, only_process_constants=False, max_recur=10
):
"""
Checks whether 'v' is a scalar (ndim = 0).
If 'v' is a scalar then this function fetches the underlying constant by calling
'get_underlying_scalar_constant_value()'.
If 'v' is not a scalar, it raises a NotScalarConstantError.
"""
if isinstance(v, (Variable, np.ndarray)):
if v.ndim != 0:
raise NotScalarConstantError()
return get_underlying_scalar_constant_value(
v, elemwise, only_process_constants, max_recur
)
def get_underlying_scalar_constant_value(
orig_v, elemwise=True, only_process_constants=False, max_recur=10 orig_v, elemwise=True, only_process_constants=False, max_recur=10
): ):
"""Return the constant scalar(0-D) value underlying variable `v`. """Return the constant scalar(0-D) value underlying variable `v`.
...@@ -358,7 +378,7 @@ def get_scalar_constant_value( ...@@ -358,7 +378,7 @@ def get_scalar_constant_value(
elif isinstance(v.owner.op, CheckAndRaise): elif isinstance(v.owner.op, CheckAndRaise):
# check if all conditions are constant and true # check if all conditions are constant and true
conds = [ conds = [
get_scalar_constant_value(c, max_recur=max_recur) get_underlying_scalar_constant_value(c, max_recur=max_recur)
for c in v.owner.inputs[1:] for c in v.owner.inputs[1:]
] ]
if builtins.all(0 == c.ndim and c != 0 for c in conds): if builtins.all(0 == c.ndim and c != 0 for c in conds):
...@@ -372,7 +392,7 @@ def get_scalar_constant_value( ...@@ -372,7 +392,7 @@ def get_scalar_constant_value(
continue continue
if isinstance(v.owner.op, _scalar_constant_value_elemwise_ops): if isinstance(v.owner.op, _scalar_constant_value_elemwise_ops):
const = [ const = [
get_scalar_constant_value(i, max_recur=max_recur) get_underlying_scalar_constant_value(i, max_recur=max_recur)
for i in v.owner.inputs for i in v.owner.inputs
] ]
ret = [[None]] ret = [[None]]
...@@ -391,7 +411,7 @@ def get_scalar_constant_value( ...@@ -391,7 +411,7 @@ def get_scalar_constant_value(
v.owner.op.scalar_op, _scalar_constant_value_elemwise_ops v.owner.op.scalar_op, _scalar_constant_value_elemwise_ops
): ):
const = [ const = [
get_scalar_constant_value(i, max_recur=max_recur) get_underlying_scalar_constant_value(i, max_recur=max_recur)
for i in v.owner.inputs for i in v.owner.inputs
] ]
ret = [[None]] ret = [[None]]
...@@ -437,7 +457,7 @@ def get_scalar_constant_value( ...@@ -437,7 +457,7 @@ def get_scalar_constant_value(
): ):
idx = v.owner.op.idx_list[0] idx = v.owner.op.idx_list[0]
if isinstance(idx, Type): if isinstance(idx, Type):
idx = get_scalar_constant_value( idx = get_underlying_scalar_constant_value(
v.owner.inputs[1], max_recur=max_recur v.owner.inputs[1], max_recur=max_recur
) )
try: try:
...@@ -471,14 +491,14 @@ def get_scalar_constant_value( ...@@ -471,14 +491,14 @@ def get_scalar_constant_value(
): ):
idx = v.owner.op.idx_list[0] idx = v.owner.op.idx_list[0]
if isinstance(idx, Type): if isinstance(idx, Type):
idx = get_scalar_constant_value( idx = get_underlying_scalar_constant_value(
v.owner.inputs[1], max_recur=max_recur v.owner.inputs[1], max_recur=max_recur
) )
# Python 2.4 does not support indexing with numpy.integer # Python 2.4 does not support indexing with numpy.integer
# So we cast it. # So we cast it.
idx = int(idx) idx = int(idx)
ret = v.owner.inputs[0].owner.inputs[idx] ret = v.owner.inputs[0].owner.inputs[idx]
ret = get_scalar_constant_value(ret, max_recur=max_recur) ret = get_underlying_scalar_constant_value(ret, max_recur=max_recur)
# MakeVector can cast implicitly its input in some case. # MakeVector can cast implicitly its input in some case.
return _asarray(ret, dtype=v.type.dtype) return _asarray(ret, dtype=v.type.dtype)
...@@ -493,7 +513,7 @@ def get_scalar_constant_value( ...@@ -493,7 +513,7 @@ def get_scalar_constant_value(
idx_list = op.idx_list idx_list = op.idx_list
idx = idx_list[0] idx = idx_list[0]
if isinstance(idx, Type): if isinstance(idx, Type):
idx = get_scalar_constant_value( idx = get_underlying_scalar_constant_value(
owner.inputs[1], max_recur=max_recur owner.inputs[1], max_recur=max_recur
) )
grandparent = leftmost_parent.owner.inputs[0] grandparent = leftmost_parent.owner.inputs[0]
...@@ -508,7 +528,7 @@ def get_scalar_constant_value( ...@@ -508,7 +528,7 @@ def get_scalar_constant_value(
if not (idx < ndim): if not (idx < ndim):
msg = ( msg = (
"get_scalar_constant_value detected " "get_underlying_scalar_constant_value detected "
f"deterministic IndexError: x.shape[{int(idx)}] " f"deterministic IndexError: x.shape[{int(idx)}] "
f"when x.ndim={int(ndim)}." f"when x.ndim={int(ndim)}."
) )
...@@ -1570,7 +1590,7 @@ pprint.assign(alloc, printing.FunctionPrinter(["alloc"])) ...@@ -1570,7 +1590,7 @@ pprint.assign(alloc, printing.FunctionPrinter(["alloc"]))
@_get_vector_length.register(Alloc) @_get_vector_length.register(Alloc)
def _get_vector_length_Alloc(var_inst, var): def _get_vector_length_Alloc(var_inst, var):
try: try:
return get_scalar_constant_value(var.owner.inputs[1]) return get_underlying_scalar_constant_value(var.owner.inputs[1])
except NotScalarConstantError: except NotScalarConstantError:
raise ValueError(f"Length of {var} cannot be determined") raise ValueError(f"Length of {var} cannot be determined")
...@@ -1821,17 +1841,17 @@ default = Default() ...@@ -1821,17 +1841,17 @@ default = Default()
def extract_constant(x, elemwise=True, only_process_constants=False): def extract_constant(x, elemwise=True, only_process_constants=False):
""" """
This function is basically a call to tensor.get_scalar_constant_value. This function is basically a call to tensor.get_underlying_scalar_constant_value.
The main difference is the behaviour in case of failure. While The main difference is the behaviour in case of failure. While
get_scalar_constant_value raises an TypeError, this function returns x, get_underlying_scalar_constant_value raises an TypeError, this function returns x,
as a tensor if possible. If x is a ScalarVariable from a as a tensor if possible. If x is a ScalarVariable from a
scalar_from_tensor, we remove the conversion. If x is just a scalar_from_tensor, we remove the conversion. If x is just a
ScalarVariable, we convert it to a tensor with tensor_from_scalar. ScalarVariable, we convert it to a tensor with tensor_from_scalar.
""" """
try: try:
x = get_scalar_constant_value(x, elemwise, only_process_constants) x = get_underlying_scalar_constant_value(x, elemwise, only_process_constants)
except NotScalarConstantError: except NotScalarConstantError:
pass pass
if isinstance(x, aes.ScalarVariable) or isinstance( if isinstance(x, aes.ScalarVariable) or isinstance(
...@@ -2201,7 +2221,7 @@ class Join(COp): ...@@ -2201,7 +2221,7 @@ class Join(COp):
if not isinstance(axis, int): if not isinstance(axis, int):
try: try:
axis = int(get_scalar_constant_value(axis)) axis = int(get_underlying_scalar_constant_value(axis))
except NotScalarConstantError: except NotScalarConstantError:
pass pass
...@@ -2450,7 +2470,7 @@ pprint.assign(Join, printing.FunctionPrinter(["join"])) ...@@ -2450,7 +2470,7 @@ pprint.assign(Join, printing.FunctionPrinter(["join"]))
def _get_vector_length_Join(op, var): def _get_vector_length_Join(op, var):
axis, *arrays = var.owner.inputs axis, *arrays = var.owner.inputs
try: try:
axis = get_scalar_constant_value(axis) axis = get_underlying_scalar_constant_value(axis)
assert axis == 0 and builtins.all(a.ndim == 1 for a in arrays) assert axis == 0 and builtins.all(a.ndim == 1 for a in arrays)
return builtins.sum(get_vector_length(a) for a in arrays) return builtins.sum(get_vector_length(a) for a in arrays)
except NotScalarConstantError: except NotScalarConstantError:
...@@ -2862,7 +2882,7 @@ class ARange(Op): ...@@ -2862,7 +2882,7 @@ class ARange(Op):
def is_constant_value(var, value): def is_constant_value(var, value):
try: try:
v = get_scalar_constant_value(var) v = get_underlying_scalar_constant_value(var)
return np.all(v == value) return np.all(v == value)
except NotScalarConstantError: except NotScalarConstantError:
pass pass
...@@ -3774,7 +3794,7 @@ class Choose(Op): ...@@ -3774,7 +3794,7 @@ class Choose(Op):
static_out_shape = () static_out_shape = ()
for s in out_shape: for s in out_shape:
try: try:
s_val = pytensor.get_scalar_constant_value(s) s_val = pytensor.get_underlying_scalar_constant(s)
except (NotScalarConstantError, AttributeError): except (NotScalarConstantError, AttributeError):
s_val = None s_val = None
...@@ -4095,6 +4115,7 @@ __all__ = [ ...@@ -4095,6 +4115,7 @@ __all__ = [
"scalar_from_tensor", "scalar_from_tensor",
"tensor_from_scalar", "tensor_from_scalar",
"get_scalar_constant_value", "get_scalar_constant_value",
"get_underlying_scalar_constant_value",
"constant", "constant",
"as_tensor_variable", "as_tensor_variable",
"as_tensor", "as_tensor",
......
...@@ -1834,7 +1834,7 @@ def local_gemm_to_ger(fgraph, node): ...@@ -1834,7 +1834,7 @@ def local_gemm_to_ger(fgraph, node):
xv = x.dimshuffle(0) xv = x.dimshuffle(0)
yv = y.dimshuffle(1) yv = y.dimshuffle(1)
try: try:
bval = at.get_scalar_constant_value(b) bval = at.get_underlying_scalar_constant_value(b)
except NotScalarConstantError: except NotScalarConstantError:
# b isn't a constant, GEMM is doing useful pre-scaling # b isn't a constant, GEMM is doing useful pre-scaling
return return
......
...@@ -24,7 +24,10 @@ from pytensor.configdefaults import config ...@@ -24,7 +24,10 @@ from pytensor.configdefaults import config
from pytensor.graph.basic import Apply, Variable from pytensor.graph.basic import Apply, Variable
from pytensor.graph.op import Op from pytensor.graph.op import Op
from pytensor.raise_op import Assert from pytensor.raise_op import Assert
from pytensor.tensor.basic import as_tensor_variable, get_scalar_constant_value from pytensor.tensor.basic import (
as_tensor_variable,
get_underlying_scalar_constant_value,
)
from pytensor.tensor.exceptions import NotScalarConstantError from pytensor.tensor.exceptions import NotScalarConstantError
from pytensor.tensor.var import TensorConstant, TensorVariable from pytensor.tensor.var import TensorConstant, TensorVariable
...@@ -495,8 +498,8 @@ def check_conv_gradinputs_shape( ...@@ -495,8 +498,8 @@ def check_conv_gradinputs_shape(
if given is None or computed is None: if given is None or computed is None:
return True return True
try: try:
given = get_scalar_constant_value(given) given = get_underlying_scalar_constant_value(given)
computed = get_scalar_constant_value(computed) computed = get_underlying_scalar_constant_value(computed)
return int(given) == int(computed) return int(given) == int(computed)
except NotScalarConstantError: except NotScalarConstantError:
# no answer possible, accept for now # no answer possible, accept for now
...@@ -532,7 +535,7 @@ def assert_conv_shape(shape): ...@@ -532,7 +535,7 @@ def assert_conv_shape(shape):
out_shape = [] out_shape = []
for i, n in enumerate(shape): for i, n in enumerate(shape):
try: try:
const_n = get_scalar_constant_value(n) const_n = get_underlying_scalar_constant_value(n)
if i < 2: if i < 2:
if const_n < 0: if const_n < 0:
raise ValueError( raise ValueError(
...@@ -2200,7 +2203,9 @@ class BaseAbstractConv(Op): ...@@ -2200,7 +2203,9 @@ class BaseAbstractConv(Op):
if imshp_i is not None: if imshp_i is not None:
# Components of imshp should be constant or ints # Components of imshp should be constant or ints
try: try:
get_scalar_constant_value(imshp_i, only_process_constants=True) get_underlying_scalar_constant_value(
imshp_i, only_process_constants=True
)
except NotScalarConstantError: except NotScalarConstantError:
raise ValueError( raise ValueError(
"imshp should be None or a tuple of constant int values" "imshp should be None or a tuple of constant int values"
...@@ -2213,7 +2218,9 @@ class BaseAbstractConv(Op): ...@@ -2213,7 +2218,9 @@ class BaseAbstractConv(Op):
if kshp_i is not None: if kshp_i is not None:
# Components of kshp should be constant or ints # Components of kshp should be constant or ints
try: try:
get_scalar_constant_value(kshp_i, only_process_constants=True) get_underlying_scalar_constant_value(
kshp_i, only_process_constants=True
)
except NotScalarConstantError: except NotScalarConstantError:
raise ValueError( raise ValueError(
"kshp should be None or a tuple of constant int values" "kshp should be None or a tuple of constant int values"
......
...@@ -759,7 +759,7 @@ class Elemwise(OpenMPOp): ...@@ -759,7 +759,7 @@ class Elemwise(OpenMPOp):
ufunc = self.ufunc ufunc = self.ufunc
elif not hasattr(node.tag, "ufunc"): elif not hasattr(node.tag, "ufunc"):
# It happen that make_thunk isn't called, like in # It happen that make_thunk isn't called, like in
# get_scalar_constant_value # get_underlying_scalar_constant_value
self.prepare_node(node, None, None, "py") self.prepare_node(node, None, None, "py")
# prepare_node will add ufunc to self or the tag # prepare_node will add ufunc to self or the tag
# depending if we can reuse it or not. So we need to # depending if we can reuse it or not. So we need to
......
...@@ -4,7 +4,7 @@ class ShapeError(Exception): ...@@ -4,7 +4,7 @@ class ShapeError(Exception):
class NotScalarConstantError(Exception): class NotScalarConstantError(Exception):
""" """
Raised by get_scalar_constant_value if called on something that is Raised by get_underlying_scalar_constant_value if called on something that is
not a scalar constant. not a scalar constant.
""" """
......
...@@ -671,7 +671,7 @@ class Repeat(Op): ...@@ -671,7 +671,7 @@ class Repeat(Op):
out_shape = [None] out_shape = [None]
else: else:
try: try:
const_reps = at.get_scalar_constant_value(repeats) const_reps = at.get_underlying_scalar_constant_value(repeats)
except NotScalarConstantError: except NotScalarConstantError:
const_reps = None const_reps = None
if const_reps == 1: if const_reps == 1:
......
...@@ -12,7 +12,7 @@ from pytensor.scalar import ScalarVariable ...@@ -12,7 +12,7 @@ from pytensor.scalar import ScalarVariable
from pytensor.tensor.basic import ( from pytensor.tensor.basic import (
as_tensor_variable, as_tensor_variable,
constant, constant,
get_scalar_constant_value, get_underlying_scalar_constant_value,
get_vector_length, get_vector_length,
infer_static_shape, infer_static_shape,
) )
...@@ -277,7 +277,7 @@ class RandomVariable(Op): ...@@ -277,7 +277,7 @@ class RandomVariable(Op):
try: try:
size_len = get_vector_length(size) size_len = get_vector_length(size)
except ValueError: except ValueError:
size_len = get_scalar_constant_value(size_shape[0]) size_len = get_underlying_scalar_constant_value(size_shape[0])
size = tuple(size[n] for n in range(size_len)) size = tuple(size[n] for n in range(size_len))
......
...@@ -32,7 +32,7 @@ from pytensor.tensor.basic import ( ...@@ -32,7 +32,7 @@ from pytensor.tensor.basic import (
cast, cast,
extract_constant, extract_constant,
fill, fill,
get_scalar_constant_value, get_underlying_scalar_constant_value,
join, join,
ones_like, ones_like,
switch, switch,
...@@ -802,7 +802,7 @@ def local_remove_useless_assert(fgraph, node): ...@@ -802,7 +802,7 @@ def local_remove_useless_assert(fgraph, node):
n_conds = len(node.inputs[1:]) n_conds = len(node.inputs[1:])
for c in node.inputs[1:]: for c in node.inputs[1:]:
try: try:
const = get_scalar_constant_value(c) const = get_underlying_scalar_constant_value(c)
if 0 != const.ndim or const == 0: if 0 != const.ndim or const == 0:
# Should we raise an error here? How to be sure it # Should we raise an error here? How to be sure it
...@@ -895,7 +895,7 @@ def local_join_empty(fgraph, node): ...@@ -895,7 +895,7 @@ def local_join_empty(fgraph, node):
return return
new_inputs = [] new_inputs = []
try: try:
join_idx = get_scalar_constant_value( join_idx = get_underlying_scalar_constant_value(
node.inputs[0], only_process_constants=True node.inputs[0], only_process_constants=True
) )
except NotScalarConstantError: except NotScalarConstantError:
......
...@@ -22,7 +22,12 @@ from pytensor.graph.rewriting.basic import ( ...@@ -22,7 +22,12 @@ from pytensor.graph.rewriting.basic import (
) )
from pytensor.graph.rewriting.db import SequenceDB from pytensor.graph.rewriting.db import SequenceDB
from pytensor.graph.utils import InconsistencyError, MethodNotDefined from pytensor.graph.utils import InconsistencyError, MethodNotDefined
from pytensor.tensor.basic import MakeVector, alloc, cast, get_scalar_constant_value from pytensor.tensor.basic import (
MakeVector,
alloc,
cast,
get_underlying_scalar_constant_value,
)
from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise
from pytensor.tensor.exceptions import NotScalarConstantError from pytensor.tensor.exceptions import NotScalarConstantError
from pytensor.tensor.rewriting.basic import register_canonicalize, register_specialize from pytensor.tensor.rewriting.basic import register_canonicalize, register_specialize
...@@ -495,7 +500,7 @@ def local_upcast_elemwise_constant_inputs(fgraph, node): ...@@ -495,7 +500,7 @@ def local_upcast_elemwise_constant_inputs(fgraph, node):
else: else:
try: try:
# works only for scalars # works only for scalars
cval_i = get_scalar_constant_value( cval_i = get_underlying_scalar_constant_value(
i, only_process_constants=True i, only_process_constants=True
) )
if all(i.broadcastable): if all(i.broadcastable):
......
...@@ -31,7 +31,7 @@ from pytensor.tensor.basic import ( ...@@ -31,7 +31,7 @@ from pytensor.tensor.basic import (
constant, constant,
extract_constant, extract_constant,
fill, fill,
get_scalar_constant_value, get_underlying_scalar_constant_value,
ones_like, ones_like,
switch, switch,
zeros_like, zeros_like,
...@@ -112,7 +112,7 @@ def scalarconsts_rest(inputs, elemwise=True, only_process_constants=False): ...@@ -112,7 +112,7 @@ def scalarconsts_rest(inputs, elemwise=True, only_process_constants=False):
nonconsts = [] nonconsts = []
for i in inputs: for i in inputs:
try: try:
v = get_scalar_constant_value( v = get_underlying_scalar_constant_value(
i, elemwise=elemwise, only_process_constants=only_process_constants i, elemwise=elemwise, only_process_constants=only_process_constants
) )
consts.append(v) consts.append(v)
...@@ -165,13 +165,13 @@ def local_0_dot_x(fgraph, node): ...@@ -165,13 +165,13 @@ def local_0_dot_x(fgraph, node):
y = node.inputs[1] y = node.inputs[1]
replace = False replace = False
try: try:
if get_scalar_constant_value(x, only_process_constants=True) == 0: if get_underlying_scalar_constant_value(x, only_process_constants=True) == 0:
replace = True replace = True
except NotScalarConstantError: except NotScalarConstantError:
pass pass
try: try:
if get_scalar_constant_value(y, only_process_constants=True) == 0: if get_underlying_scalar_constant_value(y, only_process_constants=True) == 0:
replace = True replace = True
except NotScalarConstantError: except NotScalarConstantError:
pass pass
...@@ -585,7 +585,7 @@ def local_mul_switch_sink(fgraph, node): ...@@ -585,7 +585,7 @@ def local_mul_switch_sink(fgraph, node):
switch_node = i.owner switch_node = i.owner
try: try:
if ( if (
get_scalar_constant_value( get_underlying_scalar_constant_value(
switch_node.inputs[1], only_process_constants=True switch_node.inputs[1], only_process_constants=True
) )
== 0.0 == 0.0
...@@ -613,7 +613,7 @@ def local_mul_switch_sink(fgraph, node): ...@@ -613,7 +613,7 @@ def local_mul_switch_sink(fgraph, node):
pass pass
try: try:
if ( if (
get_scalar_constant_value( get_underlying_scalar_constant_value(
switch_node.inputs[2], only_process_constants=True switch_node.inputs[2], only_process_constants=True
) )
== 0.0 == 0.0
...@@ -665,7 +665,7 @@ def local_div_switch_sink(fgraph, node): ...@@ -665,7 +665,7 @@ def local_div_switch_sink(fgraph, node):
switch_node = node.inputs[0].owner switch_node = node.inputs[0].owner
try: try:
if ( if (
get_scalar_constant_value( get_underlying_scalar_constant_value(
switch_node.inputs[1], only_process_constants=True switch_node.inputs[1], only_process_constants=True
) )
== 0.0 == 0.0
...@@ -691,7 +691,7 @@ def local_div_switch_sink(fgraph, node): ...@@ -691,7 +691,7 @@ def local_div_switch_sink(fgraph, node):
pass pass
try: try:
if ( if (
get_scalar_constant_value( get_underlying_scalar_constant_value(
switch_node.inputs[2], only_process_constants=True switch_node.inputs[2], only_process_constants=True
) )
== 0.0 == 0.0
...@@ -1493,7 +1493,9 @@ def local_useless_elemwise_comparison(fgraph, node): ...@@ -1493,7 +1493,9 @@ def local_useless_elemwise_comparison(fgraph, node):
and investigate(node.inputs[0].owner) and investigate(node.inputs[0].owner)
): ):
try: try:
cst = get_scalar_constant_value(node.inputs[1], only_process_constants=True) cst = get_underlying_scalar_constant_value(
node.inputs[1], only_process_constants=True
)
res = zeros_like(node.inputs[0], dtype=dtype, opt=True) res = zeros_like(node.inputs[0], dtype=dtype, opt=True)
...@@ -1733,7 +1735,7 @@ def local_reduce_join(fgraph, node): ...@@ -1733,7 +1735,7 @@ def local_reduce_join(fgraph, node):
# We add the new check late to don't add extra warning. # We add the new check late to don't add extra warning.
try: try:
join_axis = get_scalar_constant_value( join_axis = get_underlying_scalar_constant_value(
join_node.inputs[0], only_process_constants=True join_node.inputs[0], only_process_constants=True
) )
...@@ -1816,7 +1818,9 @@ def local_opt_alloc(fgraph, node): ...@@ -1816,7 +1818,9 @@ def local_opt_alloc(fgraph, node):
inp = node_inps.owner.inputs[0] inp = node_inps.owner.inputs[0]
shapes = node_inps.owner.inputs[1:] shapes = node_inps.owner.inputs[1:]
try: try:
val = get_scalar_constant_value(inp, only_process_constants=True) val = get_underlying_scalar_constant_value(
inp, only_process_constants=True
)
assert val.size == 1 assert val.size == 1
val = val.reshape(1)[0] val = val.reshape(1)[0]
# check which type of op # check which type of op
...@@ -1948,7 +1952,7 @@ def local_mul_zero(fgraph, node): ...@@ -1948,7 +1952,7 @@ def local_mul_zero(fgraph, node):
for i in node.inputs: for i in node.inputs:
try: try:
value = get_scalar_constant_value(i) value = get_underlying_scalar_constant_value(i)
except NotScalarConstantError: except NotScalarConstantError:
continue continue
# print 'MUL by value', value, node.inputs # print 'MUL by value', value, node.inputs
...@@ -2230,7 +2234,7 @@ def local_add_specialize(fgraph, node): ...@@ -2230,7 +2234,7 @@ def local_add_specialize(fgraph, node):
new_inputs = [] new_inputs = []
for inp in node.inputs: for inp in node.inputs:
try: try:
y = get_scalar_constant_value(inp) y = get_underlying_scalar_constant_value(inp)
except NotScalarConstantError: except NotScalarConstantError:
y = inp y = inp
if np.all(y == 0.0): if np.all(y == 0.0):
...@@ -2329,7 +2333,9 @@ def local_abs_merge(fgraph, node): ...@@ -2329,7 +2333,9 @@ def local_abs_merge(fgraph, node):
inputs.append(i.owner.inputs[0]) inputs.append(i.owner.inputs[0])
elif isinstance(i, Constant): elif isinstance(i, Constant):
try: try:
const = get_scalar_constant_value(i, only_process_constants=True) const = get_underlying_scalar_constant_value(
i, only_process_constants=True
)
except NotScalarConstantError: except NotScalarConstantError:
return False return False
if not (const >= 0).all(): if not (const >= 0).all():
...@@ -2878,7 +2884,7 @@ def local_grad_log_erfc_neg(fgraph, node): ...@@ -2878,7 +2884,7 @@ def local_grad_log_erfc_neg(fgraph, node):
mul_neg = mul(*mul_inputs) mul_neg = mul(*mul_inputs)
try: try:
cst2 = get_scalar_constant_value( cst2 = get_underlying_scalar_constant_value(
mul_neg.owner.inputs[0], only_process_constants=True mul_neg.owner.inputs[0], only_process_constants=True
) )
except NotScalarConstantError: except NotScalarConstantError:
...@@ -2912,7 +2918,7 @@ def local_grad_log_erfc_neg(fgraph, node): ...@@ -2912,7 +2918,7 @@ def local_grad_log_erfc_neg(fgraph, node):
x = erfc_x x = erfc_x
try: try:
cst = get_scalar_constant_value( cst = get_underlying_scalar_constant_value(
erfc_x.owner.inputs[0], only_process_constants=True erfc_x.owner.inputs[0], only_process_constants=True
) )
except NotScalarConstantError: except NotScalarConstantError:
...@@ -2979,7 +2985,7 @@ def _is_1(expr): ...@@ -2979,7 +2985,7 @@ def _is_1(expr):
""" """
try: try:
v = get_scalar_constant_value(expr) v = get_underlying_scalar_constant_value(expr)
return np.allclose(v, 1) return np.allclose(v, 1)
except NotScalarConstantError: except NotScalarConstantError:
return False return False
...@@ -3147,7 +3153,7 @@ def is_neg(var): ...@@ -3147,7 +3153,7 @@ def is_neg(var):
if var_node.op == mul and len(var_node.inputs) >= 2: if var_node.op == mul and len(var_node.inputs) >= 2:
for idx, mul_input in enumerate(var_node.inputs): for idx, mul_input in enumerate(var_node.inputs):
try: try:
constant = get_scalar_constant_value(mul_input) constant = get_underlying_scalar_constant_value(mul_input)
is_minus_1 = np.allclose(constant, -1) is_minus_1 = np.allclose(constant, -1)
except NotScalarConstantError: except NotScalarConstantError:
is_minus_1 = False is_minus_1 = False
......
...@@ -24,7 +24,7 @@ from pytensor.tensor.basic import ( ...@@ -24,7 +24,7 @@ from pytensor.tensor.basic import (
cast, cast,
constant, constant,
extract_constant, extract_constant,
get_scalar_constant_value, get_underlying_scalar_constant_value,
stack, stack,
) )
from pytensor.tensor.elemwise import DimShuffle, Elemwise from pytensor.tensor.elemwise import DimShuffle, Elemwise
...@@ -226,7 +226,7 @@ class ShapeFeature(Feature): ...@@ -226,7 +226,7 @@ class ShapeFeature(Feature):
# Do not call make_node for test_value # Do not call make_node for test_value
s = Shape_i(i)(r) s = Shape_i(i)(r)
try: try:
s = get_scalar_constant_value(s) s = get_underlying_scalar_constant_value(s)
except NotScalarConstantError: except NotScalarConstantError:
pass pass
return s return s
...@@ -310,7 +310,7 @@ class ShapeFeature(Feature): ...@@ -310,7 +310,7 @@ class ShapeFeature(Feature):
assert len(idx) == 1 assert len(idx) == 1
idx = idx[0] idx = idx[0]
try: try:
i = get_scalar_constant_value(idx) i = get_underlying_scalar_constant_value(idx)
except NotScalarConstantError: except NotScalarConstantError:
pass pass
else: else:
......
...@@ -25,7 +25,7 @@ from pytensor.tensor.basic import ( ...@@ -25,7 +25,7 @@ from pytensor.tensor.basic import (
cast, cast,
concatenate, concatenate,
extract_constant, extract_constant,
get_scalar_constant_value, get_underlying_scalar_constant_value,
switch, switch,
) )
from pytensor.tensor.elemwise import Elemwise from pytensor.tensor.elemwise import Elemwise
...@@ -756,7 +756,9 @@ def local_subtensor_make_vector(fgraph, node): ...@@ -756,7 +756,9 @@ def local_subtensor_make_vector(fgraph, node):
elif isinstance(idx, Variable): elif isinstance(idx, Variable):
if idx.ndim == 0: if idx.ndim == 0:
try: try:
v = get_scalar_constant_value(idx, only_process_constants=True) v = get_underlying_scalar_constant_value(
idx, only_process_constants=True
)
try: try:
ret = [x.owner.inputs[v]] ret = [x.owner.inputs[v]]
except IndexError: except IndexError:
...@@ -808,7 +810,7 @@ def local_useless_inc_subtensor(fgraph, node): ...@@ -808,7 +810,7 @@ def local_useless_inc_subtensor(fgraph, node):
# This is an increment operation, so the array being incremented must # This is an increment operation, so the array being incremented must
# consist of all zeros in order for the entire operation to be useless # consist of all zeros in order for the entire operation to be useless
try: try:
c = get_scalar_constant_value(x) c = get_underlying_scalar_constant_value(x)
if c != 0: if c != 0:
return return
except NotScalarConstantError: except NotScalarConstantError:
...@@ -927,7 +929,7 @@ def local_useless_subtensor(fgraph, node): ...@@ -927,7 +929,7 @@ def local_useless_subtensor(fgraph, node):
if isinstance(idx.stop, (int, np.integer)): if isinstance(idx.stop, (int, np.integer)):
length_pos_data = sys.maxsize length_pos_data = sys.maxsize
try: try:
length_pos_data = get_scalar_constant_value( length_pos_data = get_underlying_scalar_constant_value(
length_pos, only_process_constants=True length_pos, only_process_constants=True
) )
except NotScalarConstantError: except NotScalarConstantError:
...@@ -992,7 +994,7 @@ def local_useless_AdvancedSubtensor1(fgraph, node): ...@@ -992,7 +994,7 @@ def local_useless_AdvancedSubtensor1(fgraph, node):
# get length of the indexed tensor along the first axis # get length of the indexed tensor along the first axis
try: try:
length = get_scalar_constant_value( length = get_underlying_scalar_constant_value(
shape_of[node.inputs[0]][0], only_process_constants=True shape_of[node.inputs[0]][0], only_process_constants=True
) )
except NotScalarConstantError: except NotScalarConstantError:
...@@ -1329,7 +1331,7 @@ def local_incsubtensor_of_zeros(fgraph, node): ...@@ -1329,7 +1331,7 @@ def local_incsubtensor_of_zeros(fgraph, node):
try: try:
# Don't use only_process_constants=True. We need to # Don't use only_process_constants=True. We need to
# investigate Alloc of 0s but with non constant shape. # investigate Alloc of 0s but with non constant shape.
if get_scalar_constant_value(y, elemwise=False) == 0: if get_underlying_scalar_constant_value(y, elemwise=False) == 0:
# No need to copy over the stacktrace, # No need to copy over the stacktrace,
# because x should already have a stacktrace # because x should already have a stacktrace
return [x] return [x]
...@@ -1375,12 +1377,12 @@ def local_setsubtensor_of_constants(fgraph, node): ...@@ -1375,12 +1377,12 @@ def local_setsubtensor_of_constants(fgraph, node):
# Don't use only_process_constants=True. We need to # Don't use only_process_constants=True. We need to
# investigate Alloc of 0s but with non constant shape. # investigate Alloc of 0s but with non constant shape.
try: try:
replace_x = get_scalar_constant_value(x, elemwise=False) replace_x = get_underlying_scalar_constant_value(x, elemwise=False)
except NotScalarConstantError: except NotScalarConstantError:
return return
try: try:
replace_y = get_scalar_constant_value(y, elemwise=False) replace_y = get_underlying_scalar_constant_value(y, elemwise=False)
except NotScalarConstantError: except NotScalarConstantError:
return return
...@@ -1668,7 +1670,7 @@ def local_join_subtensors(fgraph, node): ...@@ -1668,7 +1670,7 @@ def local_join_subtensors(fgraph, node):
axis, tensors = node.inputs[0], node.inputs[1:] axis, tensors = node.inputs[0], node.inputs[1:]
try: try:
axis = get_scalar_constant_value(axis) axis = get_underlying_scalar_constant_value(axis)
except NotScalarConstantError: except NotScalarConstantError:
return return
...@@ -1729,7 +1731,12 @@ def local_join_subtensors(fgraph, node): ...@@ -1729,7 +1731,12 @@ def local_join_subtensors(fgraph, node):
if step is None: if step is None:
continue continue
try: try:
if get_scalar_constant_value(step, only_process_constants=True) != 1: if (
get_underlying_scalar_constant_value(
step, only_process_constants=True
)
!= 1
):
return None return None
except NotScalarConstantError: except NotScalarConstantError:
return None return None
......
...@@ -397,7 +397,7 @@ class SpecifyShape(COp): ...@@ -397,7 +397,7 @@ class SpecifyShape(COp):
_f16_ok = True _f16_ok = True
def make_node(self, x, *shape): def make_node(self, x, *shape):
from pytensor.tensor.basic import get_scalar_constant_value from pytensor.tensor.basic import get_underlying_scalar_constant_value
x = at.as_tensor_variable(x) x = at.as_tensor_variable(x)
...@@ -426,7 +426,7 @@ class SpecifyShape(COp): ...@@ -426,7 +426,7 @@ class SpecifyShape(COp):
type_shape[i] = xts type_shape[i] = xts
else: else:
try: try:
type_s = get_scalar_constant_value(s) type_s = get_underlying_scalar_constant_value(s)
if type_s is not None: if type_s is not None:
type_shape[i] = int(type_s) type_shape[i] = int(type_s)
except NotScalarConstantError: except NotScalarConstantError:
...@@ -457,9 +457,9 @@ class SpecifyShape(COp): ...@@ -457,9 +457,9 @@ class SpecifyShape(COp):
for dim in range(node.inputs[0].type.ndim): for dim in range(node.inputs[0].type.ndim):
s = shape[dim] s = shape[dim]
try: try:
s = at.get_scalar_constant_value(s) s = at.get_underlying_scalar_constant_value(s)
# We assume that `None` shapes are always retrieved by # We assume that `None` shapes are always retrieved by
# `get_scalar_constant_value`, and only in that case do we default to # `get_underlying_scalar_constant_value`, and only in that case do we default to
# the shape of the input variable # the shape of the input variable
if s is None: if s is None:
s = xshape[dim] s = xshape[dim]
...@@ -581,7 +581,7 @@ def specify_shape( ...@@ -581,7 +581,7 @@ def specify_shape(
@_get_vector_length.register(SpecifyShape) @_get_vector_length.register(SpecifyShape)
def _get_vector_length_SpecifyShape(op, var): def _get_vector_length_SpecifyShape(op, var):
try: try:
return at.get_scalar_constant_value(var.owner.inputs[1]).item() return at.get_underlying_scalar_constant_value(var.owner.inputs[1]).item()
except NotScalarConstantError: except NotScalarConstantError:
raise ValueError(f"Length of {var} cannot be determined") raise ValueError(f"Length of {var} cannot be determined")
...@@ -635,7 +635,7 @@ class Reshape(COp): ...@@ -635,7 +635,7 @@ class Reshape(COp):
y = shp_list[index] y = shp_list[index]
y = at.as_tensor_variable(y) y = at.as_tensor_variable(y)
try: try:
s_val = at.get_scalar_constant_value(y).item() s_val = at.get_underlying_scalar_constant_value(y).item()
if s_val >= 0: if s_val >= 0:
out_shape[index] = s_val out_shape[index] = s_val
except NotScalarConstantError: except NotScalarConstantError:
......
...@@ -20,7 +20,7 @@ from pytensor.misc.safe_asarray import _asarray ...@@ -20,7 +20,7 @@ from pytensor.misc.safe_asarray import _asarray
from pytensor.printing import Printer, pprint, set_precedence from pytensor.printing import Printer, pprint, set_precedence
from pytensor.scalar.basic import ScalarConstant from pytensor.scalar.basic import ScalarConstant
from pytensor.tensor import _get_vector_length, as_tensor_variable, get_vector_length from pytensor.tensor import _get_vector_length, as_tensor_variable, get_vector_length
from pytensor.tensor.basic import alloc, get_scalar_constant_value from pytensor.tensor.basic import alloc, get_underlying_scalar_constant_value
from pytensor.tensor.elemwise import DimShuffle from pytensor.tensor.elemwise import DimShuffle
from pytensor.tensor.exceptions import ( from pytensor.tensor.exceptions import (
AdvancedIndexingError, AdvancedIndexingError,
...@@ -656,7 +656,7 @@ def get_constant_idx( ...@@ -656,7 +656,7 @@ def get_constant_idx(
return slice(conv(val.start), conv(val.stop), conv(val.step)) return slice(conv(val.start), conv(val.stop), conv(val.step))
else: else:
try: try:
return get_scalar_constant_value( return get_underlying_scalar_constant_value(
val, val,
only_process_constants=only_process_constants, only_process_constants=only_process_constants,
elemwise=elemwise, elemwise=elemwise,
...@@ -733,7 +733,7 @@ class Subtensor(COp): ...@@ -733,7 +733,7 @@ class Subtensor(COp):
if s == 1: if s == 1:
start = p.start start = p.start
try: try:
start = get_scalar_constant_value(start) start = get_underlying_scalar_constant_value(start)
except NotScalarConstantError: except NotScalarConstantError:
pass pass
if start is None or start == 0: if start is None or start == 0:
...@@ -2808,17 +2808,17 @@ def _get_vector_length_Subtensor(op, var): ...@@ -2808,17 +2808,17 @@ def _get_vector_length_Subtensor(op, var):
start = ( start = (
None None
if indices[0].start is None if indices[0].start is None
else get_scalar_constant_value(indices[0].start) else get_underlying_scalar_constant_value(indices[0].start)
) )
stop = ( stop = (
None None
if indices[0].stop is None if indices[0].stop is None
else get_scalar_constant_value(indices[0].stop) else get_underlying_scalar_constant_value(indices[0].stop)
) )
step = ( step = (
None None
if indices[0].step is None if indices[0].step is None
else get_scalar_constant_value(indices[0].step) else get_underlying_scalar_constant_value(indices[0].step)
) )
if start == stop: if start == stop:
......
...@@ -756,8 +756,8 @@ class _tensor_py_operators: ...@@ -756,8 +756,8 @@ class _tensor_py_operators:
# This value is set so that PyTensor arrays will trump NumPy operators. # This value is set so that PyTensor arrays will trump NumPy operators.
__array_priority__ = 1000 __array_priority__ = 1000
def get_scalar_constant_value(self): def get_underlying_scalar_constant(self):
return at.basic.get_scalar_constant_value(self) return at.basic.get_underlying_scalar_constant_value(self)
def zeros_like(model, dtype=None): def zeros_like(model, dtype=None):
return at.basic.zeros_like(model, dtype=dtype) return at.basic.zeros_like(model, dtype=dtype)
......
...@@ -1043,7 +1043,7 @@ class TestConversion: ...@@ -1043,7 +1043,7 @@ class TestConversion:
from pytensor.tensor.exceptions import NotScalarConstantError from pytensor.tensor.exceptions import NotScalarConstantError
with pytest.raises(NotScalarConstantError): with pytest.raises(NotScalarConstantError):
at.get_scalar_constant_value(s, only_process_constants=True) at.get_underlying_scalar_constant_value(s, only_process_constants=True)
# TODO: # TODO:
# def test_sparse_as_tensor_variable(self): # def test_sparse_as_tensor_variable(self):
......
...@@ -52,6 +52,7 @@ from pytensor.tensor.basic import ( ...@@ -52,6 +52,7 @@ from pytensor.tensor.basic import (
flatten, flatten,
full_like, full_like,
get_scalar_constant_value, get_scalar_constant_value,
get_underlying_scalar_constant_value,
get_vector_length, get_vector_length,
horizontal_stack, horizontal_stack,
identity_like, identity_like,
...@@ -3263,52 +3264,52 @@ def test_dimshuffle_duplicate(): ...@@ -3263,52 +3264,52 @@ def test_dimshuffle_duplicate():
DimShuffle((False,), (0, 0))(x) DimShuffle((False,), (0, 0))(x)
class TestGetScalarConstantValue: class TestGetUnderlyingScalarConstantValue:
def test_basic(self): def test_basic(self):
with pytest.raises(NotScalarConstantError): with pytest.raises(NotScalarConstantError):
get_scalar_constant_value(aes.int64()) get_underlying_scalar_constant_value(aes.int64())
res = get_scalar_constant_value(at.as_tensor(10)) res = get_underlying_scalar_constant_value(at.as_tensor(10))
assert res == 10 assert res == 10
assert isinstance(res, np.ndarray) assert isinstance(res, np.ndarray)
res = get_scalar_constant_value(np.array(10)) res = get_underlying_scalar_constant_value(np.array(10))
assert res == 10 assert res == 10
assert isinstance(res, np.ndarray) assert isinstance(res, np.ndarray)
a = at.stack([1, 2, 3]) a = at.stack([1, 2, 3])
assert get_scalar_constant_value(a[0]) == 1 assert get_underlying_scalar_constant_value(a[0]) == 1
assert get_scalar_constant_value(a[1]) == 2 assert get_underlying_scalar_constant_value(a[1]) == 2
assert get_scalar_constant_value(a[2]) == 3 assert get_underlying_scalar_constant_value(a[2]) == 3
b = iscalar() b = iscalar()
a = at.stack([b, 2, 3]) a = at.stack([b, 2, 3])
with pytest.raises(NotScalarConstantError): with pytest.raises(NotScalarConstantError):
get_scalar_constant_value(a[0]) get_underlying_scalar_constant_value(a[0])
assert get_scalar_constant_value(a[1]) == 2 assert get_underlying_scalar_constant_value(a[1]) == 2
assert get_scalar_constant_value(a[2]) == 3 assert get_underlying_scalar_constant_value(a[2]) == 3
# For now get_scalar_constant_value goes through only MakeVector and Join of # For now get_underlying_scalar_constant_value goes through only MakeVector and Join of
# scalars. # scalars.
v = ivector() v = ivector()
a = at.stack([v, [2], [3]]) a = at.stack([v, [2], [3]])
with pytest.raises(NotScalarConstantError): with pytest.raises(NotScalarConstantError):
get_scalar_constant_value(a[0]) get_underlying_scalar_constant_value(a[0])
with pytest.raises(NotScalarConstantError): with pytest.raises(NotScalarConstantError):
get_scalar_constant_value(a[1]) get_underlying_scalar_constant_value(a[1])
with pytest.raises(NotScalarConstantError): with pytest.raises(NotScalarConstantError):
get_scalar_constant_value(a[2]) get_underlying_scalar_constant_value(a[2])
# Test the case SubTensor(Shape(v)) when the dimensions # Test the case SubTensor(Shape(v)) when the dimensions
# is broadcastable. # is broadcastable.
v = row() v = row()
assert get_scalar_constant_value(v.shape[0]) == 1 assert get_underlying_scalar_constant_value(v.shape[0]) == 1
res = at.get_scalar_constant_value(at.as_tensor([10, 20]).shape[0]) res = at.get_underlying_scalar_constant_value(at.as_tensor([10, 20]).shape[0])
assert isinstance(res, np.ndarray) assert isinstance(res, np.ndarray)
assert 2 == res assert 2 == res
res = at.get_scalar_constant_value( res = at.get_underlying_scalar_constant_value(
9 + at.as_tensor([1.0]).shape[0], 9 + at.as_tensor([1.0]).shape[0],
elemwise=True, elemwise=True,
only_process_constants=False, only_process_constants=False,
...@@ -3320,63 +3321,63 @@ class TestGetScalarConstantValue: ...@@ -3320,63 +3321,63 @@ class TestGetScalarConstantValue:
@pytest.mark.xfail(reason="Incomplete implementation") @pytest.mark.xfail(reason="Incomplete implementation")
def test_DimShufle(self): def test_DimShufle(self):
a = as_tensor_variable(1.0)[None][0] a = as_tensor_variable(1.0)[None][0]
assert get_scalar_constant_value(a) == 1 assert get_underlying_scalar_constant_value(a) == 1
def test_subtensor_of_constant(self): def test_subtensor_of_constant(self):
c = constant(random(5)) c = constant(random(5))
for i in range(c.value.shape[0]): for i in range(c.value.shape[0]):
assert get_scalar_constant_value(c[i]) == c.value[i] assert get_underlying_scalar_constant_value(c[i]) == c.value[i]
c = constant(random(5, 5)) c = constant(random(5, 5))
for i in range(c.value.shape[0]): for i in range(c.value.shape[0]):
for j in range(c.value.shape[1]): for j in range(c.value.shape[1]):
assert get_scalar_constant_value(c[i, j]) == c.value[i, j] assert get_underlying_scalar_constant_value(c[i, j]) == c.value[i, j]
def test_numpy_array(self): def test_numpy_array(self):
# Regression test for crash when called on a numpy array. # Regression test for crash when called on a numpy array.
assert get_scalar_constant_value(np.array(3)) == 3 assert get_underlying_scalar_constant_value(np.array(3)) == 3
with pytest.raises(NotScalarConstantError): with pytest.raises(NotScalarConstantError):
get_scalar_constant_value(np.array([0, 1])) get_underlying_scalar_constant_value(np.array([0, 1]))
with pytest.raises(NotScalarConstantError): with pytest.raises(NotScalarConstantError):
get_scalar_constant_value(np.array([])) get_underlying_scalar_constant_value(np.array([]))
def test_make_vector(self): def test_make_vector(self):
mv = make_vector(1, 2, 3) mv = make_vector(1, 2, 3)
with pytest.raises(NotScalarConstantError): with pytest.raises(NotScalarConstantError):
get_scalar_constant_value(mv) get_underlying_scalar_constant_value(mv)
assert get_scalar_constant_value(mv[0]) == 1 assert get_underlying_scalar_constant_value(mv[0]) == 1
assert get_scalar_constant_value(mv[1]) == 2 assert get_underlying_scalar_constant_value(mv[1]) == 2
assert get_scalar_constant_value(mv[2]) == 3 assert get_underlying_scalar_constant_value(mv[2]) == 3
assert get_scalar_constant_value(mv[np.int32(0)]) == 1 assert get_underlying_scalar_constant_value(mv[np.int32(0)]) == 1
assert get_scalar_constant_value(mv[np.int64(1)]) == 2 assert get_underlying_scalar_constant_value(mv[np.int64(1)]) == 2
assert get_scalar_constant_value(mv[np.uint(2)]) == 3 assert get_underlying_scalar_constant_value(mv[np.uint(2)]) == 3
t = aes.ScalarType("int64") t = aes.ScalarType("int64")
with pytest.raises(NotScalarConstantError): with pytest.raises(NotScalarConstantError):
get_scalar_constant_value(mv[t()]) get_underlying_scalar_constant_value(mv[t()])
def test_shape_i(self): def test_shape_i(self):
c = constant(np.random.random((3, 4))) c = constant(np.random.random((3, 4)))
s = Shape_i(0)(c) s = Shape_i(0)(c)
assert get_scalar_constant_value(s) == 3 assert get_underlying_scalar_constant_value(s) == 3
s = Shape_i(1)(c) s = Shape_i(1)(c)
assert get_scalar_constant_value(s) == 4 assert get_underlying_scalar_constant_value(s) == 4
d = pytensor.shared(np.random.standard_normal((1, 1)), shape=(1, 1)) d = pytensor.shared(np.random.standard_normal((1, 1)), shape=(1, 1))
f = ScalarFromTensor()(Shape_i(0)(d)) f = ScalarFromTensor()(Shape_i(0)(d))
assert get_scalar_constant_value(f) == 1 assert get_underlying_scalar_constant_value(f) == 1
def test_elemwise(self): def test_elemwise(self):
# We test only for a few elemwise, the list of all supported # We test only for a few elemwise, the list of all supported
# elemwise are in the fct. # elemwise are in the fct.
c = constant(np.random.random()) c = constant(np.random.random())
s = c + 1 s = c + 1
assert np.allclose(get_scalar_constant_value(s), c.data + 1) assert np.allclose(get_underlying_scalar_constant_value(s), c.data + 1)
s = c - 1 s = c - 1
assert np.allclose(get_scalar_constant_value(s), c.data - 1) assert np.allclose(get_underlying_scalar_constant_value(s), c.data - 1)
s = c * 1.2 s = c * 1.2
assert np.allclose(get_scalar_constant_value(s), c.data * 1.2) assert np.allclose(get_underlying_scalar_constant_value(s), c.data * 1.2)
s = c < 0.5 s = c < 0.5
assert np.allclose(get_scalar_constant_value(s), int(c.data < 0.5)) assert np.allclose(get_underlying_scalar_constant_value(s), int(c.data < 0.5))
s = at.second(c, 0.4) s = at.second(c, 0.4)
assert np.allclose(get_scalar_constant_value(s), 0.4) assert np.allclose(get_underlying_scalar_constant_value(s), 0.4)
def test_assert(self): def test_assert(self):
# Make sure we still get the constant value if it is wrapped in # Make sure we still get the constant value if it is wrapped in
...@@ -3386,25 +3387,25 @@ class TestGetScalarConstantValue: ...@@ -3386,25 +3387,25 @@ class TestGetScalarConstantValue:
# condition is always True # condition is always True
a = Assert()(c, c > 1) a = Assert()(c, c > 1)
assert get_scalar_constant_value(a) == 2 assert get_underlying_scalar_constant_value(a) == 2
with config.change_flags(compute_test_value="off"): with config.change_flags(compute_test_value="off"):
# condition is always False # condition is always False
a = Assert()(c, c > 2) a = Assert()(c, c > 2)
with pytest.raises(NotScalarConstantError): with pytest.raises(NotScalarConstantError):
get_scalar_constant_value(a) get_underlying_scalar_constant_value(a)
# condition is not constant # condition is not constant
a = Assert()(c, c > x) a = Assert()(c, c > x)
with pytest.raises(NotScalarConstantError): with pytest.raises(NotScalarConstantError):
get_scalar_constant_value(a) get_underlying_scalar_constant_value(a)
def test_second(self): def test_second(self):
# Second should apply when the value is constant but not the shape # Second should apply when the value is constant but not the shape
c = constant(np.random.random()) c = constant(np.random.random())
shp = vector() shp = vector()
s = at.second(shp, c) s = at.second(shp, c)
assert get_scalar_constant_value(s) == c.data assert get_underlying_scalar_constant_value(s) == c.data
def test_copy(self): def test_copy(self):
# Make sure we do not return the internal storage of a constant, # Make sure we do not return the internal storage of a constant,
...@@ -3418,17 +3419,27 @@ class TestGetScalarConstantValue: ...@@ -3418,17 +3419,27 @@ class TestGetScalarConstantValue:
@pytest.mark.parametrize("only_process_constants", (True, False)) @pytest.mark.parametrize("only_process_constants", (True, False))
def test_None_and_NoneConst(self, only_process_constants): def test_None_and_NoneConst(self, only_process_constants):
with pytest.raises(NotScalarConstantError): with pytest.raises(NotScalarConstantError):
get_scalar_constant_value( get_underlying_scalar_constant_value(
None, only_process_constants=only_process_constants None, only_process_constants=only_process_constants
) )
assert ( assert (
get_scalar_constant_value( get_underlying_scalar_constant_value(
NoneConst, only_process_constants=only_process_constants NoneConst, only_process_constants=only_process_constants
) )
is None is None
) )
@pytest.mark.parametrize(
["valid_inp", "invalid_inp"],
((np.array(4), np.zeros(5)), (at.constant(4), at.constant(3, ndim=1))),
)
def test_get_scalar_constant_value(valid_inp, invalid_inp):
with pytest.raises(NotScalarConstantError):
get_scalar_constant_value(invalid_inp)
assert get_scalar_constant_value(valid_inp) == 4
def test_complex_mod_failure(): def test_complex_mod_failure():
# Make sure % fails on complex numbers. # Make sure % fails on complex numbers.
x = vector(dtype="complex64") x = vector(dtype="complex64")
......
...@@ -823,8 +823,8 @@ class TestElemwise(unittest_tools.InferShapeTester): ...@@ -823,8 +823,8 @@ class TestElemwise(unittest_tools.InferShapeTester):
assert len(res_shape) == 1 assert len(res_shape) == 1
assert len(res_shape[0]) == 2 assert len(res_shape[0]) == 2
assert pytensor.get_scalar_constant_value(res_shape[0][0]) == 1 assert pytensor.get_underlying_scalar_constant(res_shape[0][0]) == 1
assert pytensor.get_scalar_constant_value(res_shape[0][1]) == 1 assert pytensor.get_underlying_scalar_constant(res_shape[0][1]) == 1
def test_multi_output(self): def test_multi_output(self):
class CustomElemwise(Elemwise): class CustomElemwise(Elemwise):
......
...@@ -27,7 +27,7 @@ from pytensor.tensor.basic import ( ...@@ -27,7 +27,7 @@ from pytensor.tensor.basic import (
as_tensor_variable, as_tensor_variable,
constant, constant,
eye, eye,
get_scalar_constant_value, get_underlying_scalar_constant_value,
switch, switch,
) )
from pytensor.tensor.elemwise import CAReduce, Elemwise from pytensor.tensor.elemwise import CAReduce, Elemwise
...@@ -894,7 +894,7 @@ class TestMaxAndArgmax: ...@@ -894,7 +894,7 @@ class TestMaxAndArgmax:
x = matrix() x = matrix()
cost = argmax(x, axis=0).sum() cost = argmax(x, axis=0).sum()
gx = grad(cost, x) gx = grad(cost, x)
val = get_scalar_constant_value(gx) val = get_underlying_scalar_constant_value(gx)
assert val == 0.0 assert val == 0.0
def test_grad(self): def test_grad(self):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论