Unverified 提交 fda240fd authored 作者: Shreyas Singh's avatar Shreyas Singh 提交者: GitHub

`get_scalar_constant_value` now raises for non-scalar inputs (#248)

* Rename old get_scalar_constant_value to get_underlying_scalar_constant
上级 feccc417
......@@ -577,7 +577,7 @@ them perfectly, but a `dscalar` otherwise.
.. method:: round(mode="half_away_from_zero")
:noindex:
.. method:: trace()
.. method:: get_scalar_constant_value()
.. method:: get_underlying_scalar_constant_value()
.. method:: zeros_like(model, dtype=None)
All the above methods are equivalent to NumPy for PyTensor on the current tensor.
......
......@@ -137,7 +137,7 @@ from pytensor.updates import OrderedUpdates
# isort: on
def get_scalar_constant_value(v):
def get_underlying_scalar_constant(v):
"""Return the constant scalar (i.e. 0-D) value underlying variable `v`.
If `v` is the output of dim-shuffles, fills, allocs, cast, etc.
......@@ -153,8 +153,8 @@ def get_scalar_constant_value(v):
if sparse and isinstance(v.type, sparse.SparseTensorType):
if v.owner is not None and isinstance(v.owner.op, sparse.CSM):
data = v.owner.inputs[0]
return tensor.get_scalar_constant_value(data)
return tensor.get_scalar_constant_value(v)
return tensor.get_underlying_scalar_constant_value(data)
return tensor.get_underlying_scalar_constant_value(v)
# isort: off
......
......@@ -1325,7 +1325,7 @@ def _populate_grad_dict(var_to_app_to_idx, grad_dict, wrt, cost_name=None):
f" {i}. Since this input is only connected "
"to integer-valued outputs, it should "
"evaluate to zeros, but it evaluates to"
f"{pytensor.get_scalar_constant_value(term)}."
f"{pytensor.get_underlying_scalar_constant(term)}."
)
raise ValueError(msg)
......@@ -2086,7 +2086,7 @@ def _is_zero(x):
no_constant_value = True
try:
constant_value = pytensor.get_scalar_constant_value(x)
constant_value = pytensor.get_underlying_scalar_constant(x)
no_constant_value = False
except pytensor.tensor.exceptions.NotScalarConstantError:
pass
......
......@@ -18,7 +18,7 @@ from pytensor.tensor.basic import (
ScalarFromTensor,
Split,
TensorFromScalar,
get_scalar_constant_value,
get_underlying_scalar_constant_value,
)
from pytensor.tensor.exceptions import NotScalarConstantError
......@@ -106,7 +106,7 @@ def jax_funcify_Join(op, **kwargs):
def jax_funcify_Split(op: Split, node, **kwargs):
_, axis, splits = node.inputs
try:
constant_axis = get_scalar_constant_value(axis)
constant_axis = get_underlying_scalar_constant_value(axis)
except NotScalarConstantError:
constant_axis = None
warnings.warn(
......@@ -116,7 +116,7 @@ def jax_funcify_Split(op: Split, node, **kwargs):
try:
constant_splits = np.array(
[
get_scalar_constant_value(splits[i])
get_underlying_scalar_constant_value(splits[i])
for i in range(get_vector_length(splits))
]
)
......
......@@ -12,7 +12,7 @@ from pytensor.graph.replace import clone_replace
from pytensor.graph.utils import MissingInputError, TestValueError
from pytensor.scan.op import Scan, ScanInfo
from pytensor.scan.utils import expand_empty, safe_new, until
from pytensor.tensor.basic import get_scalar_constant_value
from pytensor.tensor.basic import get_underlying_scalar_constant_value
from pytensor.tensor.exceptions import NotScalarConstantError
from pytensor.tensor.math import minimum
from pytensor.tensor.shape import shape_padleft, unbroadcast
......@@ -147,7 +147,7 @@ def isNaN_or_Inf_or_None(x):
isStr = False
if not isNaN and not isInf:
try:
val = get_scalar_constant_value(x)
val = get_underlying_scalar_constant_value(x)
isInf = np.isinf(val)
isNaN = np.isnan(val)
except Exception:
......@@ -476,7 +476,7 @@ def scan(
n_fixed_steps = int(n_steps)
else:
try:
n_fixed_steps = at.get_scalar_constant_value(n_steps)
n_fixed_steps = at.get_underlying_scalar_constant_value(n_steps)
except NotScalarConstantError:
n_fixed_steps = None
......
......@@ -49,7 +49,11 @@ from pytensor.scan.utils import (
safe_new,
scan_can_remove_outs,
)
from pytensor.tensor.basic import Alloc, AllocEmpty, get_scalar_constant_value
from pytensor.tensor.basic import (
Alloc,
AllocEmpty,
get_underlying_scalar_constant_value,
)
from pytensor.tensor.elemwise import DimShuffle, Elemwise
from pytensor.tensor.exceptions import NotScalarConstantError
from pytensor.tensor.math import Dot, dot, maximum, minimum
......@@ -1956,13 +1960,13 @@ class ScanMerge(GraphRewriter):
nsteps = node.inputs[0]
try:
nsteps = int(get_scalar_constant_value(nsteps))
nsteps = int(get_underlying_scalar_constant_value(nsteps))
except NotScalarConstantError:
pass
rep_nsteps = rep.inputs[0]
try:
rep_nsteps = int(get_scalar_constant_value(rep_nsteps))
rep_nsteps = int(get_underlying_scalar_constant_value(rep_nsteps))
except NotScalarConstantError:
pass
......
......@@ -256,6 +256,26 @@ _scalar_constant_value_elemwise_ops = (
def get_scalar_constant_value(
v, elemwise=True, only_process_constants=False, max_recur=10
):
"""
Checks whether 'v' is a scalar (ndim = 0).
If 'v' is a scalar then this function fetches the underlying constant by calling
'get_underlying_scalar_constant_value()'.
If 'v' is not a scalar, it raises a NotScalarConstantError.
"""
if isinstance(v, (Variable, np.ndarray)):
if v.ndim != 0:
raise NotScalarConstantError()
return get_underlying_scalar_constant_value(
v, elemwise, only_process_constants, max_recur
)
def get_underlying_scalar_constant_value(
orig_v, elemwise=True, only_process_constants=False, max_recur=10
):
"""Return the constant scalar(0-D) value underlying variable `v`.
......@@ -358,7 +378,7 @@ def get_scalar_constant_value(
elif isinstance(v.owner.op, CheckAndRaise):
# check if all conditions are constant and true
conds = [
get_scalar_constant_value(c, max_recur=max_recur)
get_underlying_scalar_constant_value(c, max_recur=max_recur)
for c in v.owner.inputs[1:]
]
if builtins.all(0 == c.ndim and c != 0 for c in conds):
......@@ -372,7 +392,7 @@ def get_scalar_constant_value(
continue
if isinstance(v.owner.op, _scalar_constant_value_elemwise_ops):
const = [
get_scalar_constant_value(i, max_recur=max_recur)
get_underlying_scalar_constant_value(i, max_recur=max_recur)
for i in v.owner.inputs
]
ret = [[None]]
......@@ -391,7 +411,7 @@ def get_scalar_constant_value(
v.owner.op.scalar_op, _scalar_constant_value_elemwise_ops
):
const = [
get_scalar_constant_value(i, max_recur=max_recur)
get_underlying_scalar_constant_value(i, max_recur=max_recur)
for i in v.owner.inputs
]
ret = [[None]]
......@@ -437,7 +457,7 @@ def get_scalar_constant_value(
):
idx = v.owner.op.idx_list[0]
if isinstance(idx, Type):
idx = get_scalar_constant_value(
idx = get_underlying_scalar_constant_value(
v.owner.inputs[1], max_recur=max_recur
)
try:
......@@ -471,14 +491,14 @@ def get_scalar_constant_value(
):
idx = v.owner.op.idx_list[0]
if isinstance(idx, Type):
idx = get_scalar_constant_value(
idx = get_underlying_scalar_constant_value(
v.owner.inputs[1], max_recur=max_recur
)
# Python 2.4 does not support indexing with numpy.integer
# So we cast it.
idx = int(idx)
ret = v.owner.inputs[0].owner.inputs[idx]
ret = get_scalar_constant_value(ret, max_recur=max_recur)
ret = get_underlying_scalar_constant_value(ret, max_recur=max_recur)
# MakeVector can cast implicitly its input in some case.
return _asarray(ret, dtype=v.type.dtype)
......@@ -493,7 +513,7 @@ def get_scalar_constant_value(
idx_list = op.idx_list
idx = idx_list[0]
if isinstance(idx, Type):
idx = get_scalar_constant_value(
idx = get_underlying_scalar_constant_value(
owner.inputs[1], max_recur=max_recur
)
grandparent = leftmost_parent.owner.inputs[0]
......@@ -508,7 +528,7 @@ def get_scalar_constant_value(
if not (idx < ndim):
msg = (
"get_scalar_constant_value detected "
"get_underlying_scalar_constant_value detected "
f"deterministic IndexError: x.shape[{int(idx)}] "
f"when x.ndim={int(ndim)}."
)
......@@ -1570,7 +1590,7 @@ pprint.assign(alloc, printing.FunctionPrinter(["alloc"]))
@_get_vector_length.register(Alloc)
def _get_vector_length_Alloc(var_inst, var):
try:
return get_scalar_constant_value(var.owner.inputs[1])
return get_underlying_scalar_constant_value(var.owner.inputs[1])
except NotScalarConstantError:
raise ValueError(f"Length of {var} cannot be determined")
......@@ -1821,17 +1841,17 @@ default = Default()
def extract_constant(x, elemwise=True, only_process_constants=False):
"""
This function is basically a call to tensor.get_scalar_constant_value.
This function is basically a call to tensor.get_underlying_scalar_constant_value.
The main difference is the behaviour in case of failure. While
get_scalar_constant_value raises an TypeError, this function returns x,
get_underlying_scalar_constant_value raises an TypeError, this function returns x,
as a tensor if possible. If x is a ScalarVariable from a
scalar_from_tensor, we remove the conversion. If x is just a
ScalarVariable, we convert it to a tensor with tensor_from_scalar.
"""
try:
x = get_scalar_constant_value(x, elemwise, only_process_constants)
x = get_underlying_scalar_constant_value(x, elemwise, only_process_constants)
except NotScalarConstantError:
pass
if isinstance(x, aes.ScalarVariable) or isinstance(
......@@ -2201,7 +2221,7 @@ class Join(COp):
if not isinstance(axis, int):
try:
axis = int(get_scalar_constant_value(axis))
axis = int(get_underlying_scalar_constant_value(axis))
except NotScalarConstantError:
pass
......@@ -2450,7 +2470,7 @@ pprint.assign(Join, printing.FunctionPrinter(["join"]))
def _get_vector_length_Join(op, var):
axis, *arrays = var.owner.inputs
try:
axis = get_scalar_constant_value(axis)
axis = get_underlying_scalar_constant_value(axis)
assert axis == 0 and builtins.all(a.ndim == 1 for a in arrays)
return builtins.sum(get_vector_length(a) for a in arrays)
except NotScalarConstantError:
......@@ -2862,7 +2882,7 @@ class ARange(Op):
def is_constant_value(var, value):
try:
v = get_scalar_constant_value(var)
v = get_underlying_scalar_constant_value(var)
return np.all(v == value)
except NotScalarConstantError:
pass
......@@ -3774,7 +3794,7 @@ class Choose(Op):
static_out_shape = ()
for s in out_shape:
try:
s_val = pytensor.get_scalar_constant_value(s)
s_val = pytensor.get_underlying_scalar_constant(s)
except (NotScalarConstantError, AttributeError):
s_val = None
......@@ -4095,6 +4115,7 @@ __all__ = [
"scalar_from_tensor",
"tensor_from_scalar",
"get_scalar_constant_value",
"get_underlying_scalar_constant_value",
"constant",
"as_tensor_variable",
"as_tensor",
......
......@@ -1834,7 +1834,7 @@ def local_gemm_to_ger(fgraph, node):
xv = x.dimshuffle(0)
yv = y.dimshuffle(1)
try:
bval = at.get_scalar_constant_value(b)
bval = at.get_underlying_scalar_constant_value(b)
except NotScalarConstantError:
# b isn't a constant, GEMM is doing useful pre-scaling
return
......
......@@ -24,7 +24,10 @@ from pytensor.configdefaults import config
from pytensor.graph.basic import Apply, Variable
from pytensor.graph.op import Op
from pytensor.raise_op import Assert
from pytensor.tensor.basic import as_tensor_variable, get_scalar_constant_value
from pytensor.tensor.basic import (
as_tensor_variable,
get_underlying_scalar_constant_value,
)
from pytensor.tensor.exceptions import NotScalarConstantError
from pytensor.tensor.var import TensorConstant, TensorVariable
......@@ -495,8 +498,8 @@ def check_conv_gradinputs_shape(
if given is None or computed is None:
return True
try:
given = get_scalar_constant_value(given)
computed = get_scalar_constant_value(computed)
given = get_underlying_scalar_constant_value(given)
computed = get_underlying_scalar_constant_value(computed)
return int(given) == int(computed)
except NotScalarConstantError:
# no answer possible, accept for now
......@@ -532,7 +535,7 @@ def assert_conv_shape(shape):
out_shape = []
for i, n in enumerate(shape):
try:
const_n = get_scalar_constant_value(n)
const_n = get_underlying_scalar_constant_value(n)
if i < 2:
if const_n < 0:
raise ValueError(
......@@ -2200,7 +2203,9 @@ class BaseAbstractConv(Op):
if imshp_i is not None:
# Components of imshp should be constant or ints
try:
get_scalar_constant_value(imshp_i, only_process_constants=True)
get_underlying_scalar_constant_value(
imshp_i, only_process_constants=True
)
except NotScalarConstantError:
raise ValueError(
"imshp should be None or a tuple of constant int values"
......@@ -2213,7 +2218,9 @@ class BaseAbstractConv(Op):
if kshp_i is not None:
# Components of kshp should be constant or ints
try:
get_scalar_constant_value(kshp_i, only_process_constants=True)
get_underlying_scalar_constant_value(
kshp_i, only_process_constants=True
)
except NotScalarConstantError:
raise ValueError(
"kshp should be None or a tuple of constant int values"
......
......@@ -759,7 +759,7 @@ class Elemwise(OpenMPOp):
ufunc = self.ufunc
elif not hasattr(node.tag, "ufunc"):
# It happen that make_thunk isn't called, like in
# get_scalar_constant_value
# get_underlying_scalar_constant_value
self.prepare_node(node, None, None, "py")
# prepare_node will add ufunc to self or the tag
# depending if we can reuse it or not. So we need to
......
......@@ -4,7 +4,7 @@ class ShapeError(Exception):
class NotScalarConstantError(Exception):
"""
Raised by get_scalar_constant_value if called on something that is
Raised by get_underlying_scalar_constant_value if called on something that is
not a scalar constant.
"""
......
......@@ -671,7 +671,7 @@ class Repeat(Op):
out_shape = [None]
else:
try:
const_reps = at.get_scalar_constant_value(repeats)
const_reps = at.get_underlying_scalar_constant_value(repeats)
except NotScalarConstantError:
const_reps = None
if const_reps == 1:
......
......@@ -12,7 +12,7 @@ from pytensor.scalar import ScalarVariable
from pytensor.tensor.basic import (
as_tensor_variable,
constant,
get_scalar_constant_value,
get_underlying_scalar_constant_value,
get_vector_length,
infer_static_shape,
)
......@@ -277,7 +277,7 @@ class RandomVariable(Op):
try:
size_len = get_vector_length(size)
except ValueError:
size_len = get_scalar_constant_value(size_shape[0])
size_len = get_underlying_scalar_constant_value(size_shape[0])
size = tuple(size[n] for n in range(size_len))
......
......@@ -32,7 +32,7 @@ from pytensor.tensor.basic import (
cast,
extract_constant,
fill,
get_scalar_constant_value,
get_underlying_scalar_constant_value,
join,
ones_like,
switch,
......@@ -802,7 +802,7 @@ def local_remove_useless_assert(fgraph, node):
n_conds = len(node.inputs[1:])
for c in node.inputs[1:]:
try:
const = get_scalar_constant_value(c)
const = get_underlying_scalar_constant_value(c)
if 0 != const.ndim or const == 0:
# Should we raise an error here? How to be sure it
......@@ -895,7 +895,7 @@ def local_join_empty(fgraph, node):
return
new_inputs = []
try:
join_idx = get_scalar_constant_value(
join_idx = get_underlying_scalar_constant_value(
node.inputs[0], only_process_constants=True
)
except NotScalarConstantError:
......
......@@ -22,7 +22,12 @@ from pytensor.graph.rewriting.basic import (
)
from pytensor.graph.rewriting.db import SequenceDB
from pytensor.graph.utils import InconsistencyError, MethodNotDefined
from pytensor.tensor.basic import MakeVector, alloc, cast, get_scalar_constant_value
from pytensor.tensor.basic import (
MakeVector,
alloc,
cast,
get_underlying_scalar_constant_value,
)
from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise
from pytensor.tensor.exceptions import NotScalarConstantError
from pytensor.tensor.rewriting.basic import register_canonicalize, register_specialize
......@@ -495,7 +500,7 @@ def local_upcast_elemwise_constant_inputs(fgraph, node):
else:
try:
# works only for scalars
cval_i = get_scalar_constant_value(
cval_i = get_underlying_scalar_constant_value(
i, only_process_constants=True
)
if all(i.broadcastable):
......
......@@ -31,7 +31,7 @@ from pytensor.tensor.basic import (
constant,
extract_constant,
fill,
get_scalar_constant_value,
get_underlying_scalar_constant_value,
ones_like,
switch,
zeros_like,
......@@ -112,7 +112,7 @@ def scalarconsts_rest(inputs, elemwise=True, only_process_constants=False):
nonconsts = []
for i in inputs:
try:
v = get_scalar_constant_value(
v = get_underlying_scalar_constant_value(
i, elemwise=elemwise, only_process_constants=only_process_constants
)
consts.append(v)
......@@ -165,13 +165,13 @@ def local_0_dot_x(fgraph, node):
y = node.inputs[1]
replace = False
try:
if get_scalar_constant_value(x, only_process_constants=True) == 0:
if get_underlying_scalar_constant_value(x, only_process_constants=True) == 0:
replace = True
except NotScalarConstantError:
pass
try:
if get_scalar_constant_value(y, only_process_constants=True) == 0:
if get_underlying_scalar_constant_value(y, only_process_constants=True) == 0:
replace = True
except NotScalarConstantError:
pass
......@@ -585,7 +585,7 @@ def local_mul_switch_sink(fgraph, node):
switch_node = i.owner
try:
if (
get_scalar_constant_value(
get_underlying_scalar_constant_value(
switch_node.inputs[1], only_process_constants=True
)
== 0.0
......@@ -613,7 +613,7 @@ def local_mul_switch_sink(fgraph, node):
pass
try:
if (
get_scalar_constant_value(
get_underlying_scalar_constant_value(
switch_node.inputs[2], only_process_constants=True
)
== 0.0
......@@ -665,7 +665,7 @@ def local_div_switch_sink(fgraph, node):
switch_node = node.inputs[0].owner
try:
if (
get_scalar_constant_value(
get_underlying_scalar_constant_value(
switch_node.inputs[1], only_process_constants=True
)
== 0.0
......@@ -691,7 +691,7 @@ def local_div_switch_sink(fgraph, node):
pass
try:
if (
get_scalar_constant_value(
get_underlying_scalar_constant_value(
switch_node.inputs[2], only_process_constants=True
)
== 0.0
......@@ -1493,7 +1493,9 @@ def local_useless_elemwise_comparison(fgraph, node):
and investigate(node.inputs[0].owner)
):
try:
cst = get_scalar_constant_value(node.inputs[1], only_process_constants=True)
cst = get_underlying_scalar_constant_value(
node.inputs[1], only_process_constants=True
)
res = zeros_like(node.inputs[0], dtype=dtype, opt=True)
......@@ -1733,7 +1735,7 @@ def local_reduce_join(fgraph, node):
# We add the new check late to don't add extra warning.
try:
join_axis = get_scalar_constant_value(
join_axis = get_underlying_scalar_constant_value(
join_node.inputs[0], only_process_constants=True
)
......@@ -1816,7 +1818,9 @@ def local_opt_alloc(fgraph, node):
inp = node_inps.owner.inputs[0]
shapes = node_inps.owner.inputs[1:]
try:
val = get_scalar_constant_value(inp, only_process_constants=True)
val = get_underlying_scalar_constant_value(
inp, only_process_constants=True
)
assert val.size == 1
val = val.reshape(1)[0]
# check which type of op
......@@ -1948,7 +1952,7 @@ def local_mul_zero(fgraph, node):
for i in node.inputs:
try:
value = get_scalar_constant_value(i)
value = get_underlying_scalar_constant_value(i)
except NotScalarConstantError:
continue
# print 'MUL by value', value, node.inputs
......@@ -2230,7 +2234,7 @@ def local_add_specialize(fgraph, node):
new_inputs = []
for inp in node.inputs:
try:
y = get_scalar_constant_value(inp)
y = get_underlying_scalar_constant_value(inp)
except NotScalarConstantError:
y = inp
if np.all(y == 0.0):
......@@ -2329,7 +2333,9 @@ def local_abs_merge(fgraph, node):
inputs.append(i.owner.inputs[0])
elif isinstance(i, Constant):
try:
const = get_scalar_constant_value(i, only_process_constants=True)
const = get_underlying_scalar_constant_value(
i, only_process_constants=True
)
except NotScalarConstantError:
return False
if not (const >= 0).all():
......@@ -2878,7 +2884,7 @@ def local_grad_log_erfc_neg(fgraph, node):
mul_neg = mul(*mul_inputs)
try:
cst2 = get_scalar_constant_value(
cst2 = get_underlying_scalar_constant_value(
mul_neg.owner.inputs[0], only_process_constants=True
)
except NotScalarConstantError:
......@@ -2912,7 +2918,7 @@ def local_grad_log_erfc_neg(fgraph, node):
x = erfc_x
try:
cst = get_scalar_constant_value(
cst = get_underlying_scalar_constant_value(
erfc_x.owner.inputs[0], only_process_constants=True
)
except NotScalarConstantError:
......@@ -2979,7 +2985,7 @@ def _is_1(expr):
"""
try:
v = get_scalar_constant_value(expr)
v = get_underlying_scalar_constant_value(expr)
return np.allclose(v, 1)
except NotScalarConstantError:
return False
......@@ -3147,7 +3153,7 @@ def is_neg(var):
if var_node.op == mul and len(var_node.inputs) >= 2:
for idx, mul_input in enumerate(var_node.inputs):
try:
constant = get_scalar_constant_value(mul_input)
constant = get_underlying_scalar_constant_value(mul_input)
is_minus_1 = np.allclose(constant, -1)
except NotScalarConstantError:
is_minus_1 = False
......
......@@ -24,7 +24,7 @@ from pytensor.tensor.basic import (
cast,
constant,
extract_constant,
get_scalar_constant_value,
get_underlying_scalar_constant_value,
stack,
)
from pytensor.tensor.elemwise import DimShuffle, Elemwise
......@@ -226,7 +226,7 @@ class ShapeFeature(Feature):
# Do not call make_node for test_value
s = Shape_i(i)(r)
try:
s = get_scalar_constant_value(s)
s = get_underlying_scalar_constant_value(s)
except NotScalarConstantError:
pass
return s
......@@ -310,7 +310,7 @@ class ShapeFeature(Feature):
assert len(idx) == 1
idx = idx[0]
try:
i = get_scalar_constant_value(idx)
i = get_underlying_scalar_constant_value(idx)
except NotScalarConstantError:
pass
else:
......
......@@ -25,7 +25,7 @@ from pytensor.tensor.basic import (
cast,
concatenate,
extract_constant,
get_scalar_constant_value,
get_underlying_scalar_constant_value,
switch,
)
from pytensor.tensor.elemwise import Elemwise
......@@ -756,7 +756,9 @@ def local_subtensor_make_vector(fgraph, node):
elif isinstance(idx, Variable):
if idx.ndim == 0:
try:
v = get_scalar_constant_value(idx, only_process_constants=True)
v = get_underlying_scalar_constant_value(
idx, only_process_constants=True
)
try:
ret = [x.owner.inputs[v]]
except IndexError:
......@@ -808,7 +810,7 @@ def local_useless_inc_subtensor(fgraph, node):
# This is an increment operation, so the array being incremented must
# consist of all zeros in order for the entire operation to be useless
try:
c = get_scalar_constant_value(x)
c = get_underlying_scalar_constant_value(x)
if c != 0:
return
except NotScalarConstantError:
......@@ -927,7 +929,7 @@ def local_useless_subtensor(fgraph, node):
if isinstance(idx.stop, (int, np.integer)):
length_pos_data = sys.maxsize
try:
length_pos_data = get_scalar_constant_value(
length_pos_data = get_underlying_scalar_constant_value(
length_pos, only_process_constants=True
)
except NotScalarConstantError:
......@@ -992,7 +994,7 @@ def local_useless_AdvancedSubtensor1(fgraph, node):
# get length of the indexed tensor along the first axis
try:
length = get_scalar_constant_value(
length = get_underlying_scalar_constant_value(
shape_of[node.inputs[0]][0], only_process_constants=True
)
except NotScalarConstantError:
......@@ -1329,7 +1331,7 @@ def local_incsubtensor_of_zeros(fgraph, node):
try:
# Don't use only_process_constants=True. We need to
# investigate Alloc of 0s but with non constant shape.
if get_scalar_constant_value(y, elemwise=False) == 0:
if get_underlying_scalar_constant_value(y, elemwise=False) == 0:
# No need to copy over the stacktrace,
# because x should already have a stacktrace
return [x]
......@@ -1375,12 +1377,12 @@ def local_setsubtensor_of_constants(fgraph, node):
# Don't use only_process_constants=True. We need to
# investigate Alloc of 0s but with non constant shape.
try:
replace_x = get_scalar_constant_value(x, elemwise=False)
replace_x = get_underlying_scalar_constant_value(x, elemwise=False)
except NotScalarConstantError:
return
try:
replace_y = get_scalar_constant_value(y, elemwise=False)
replace_y = get_underlying_scalar_constant_value(y, elemwise=False)
except NotScalarConstantError:
return
......@@ -1668,7 +1670,7 @@ def local_join_subtensors(fgraph, node):
axis, tensors = node.inputs[0], node.inputs[1:]
try:
axis = get_scalar_constant_value(axis)
axis = get_underlying_scalar_constant_value(axis)
except NotScalarConstantError:
return
......@@ -1729,7 +1731,12 @@ def local_join_subtensors(fgraph, node):
if step is None:
continue
try:
if get_scalar_constant_value(step, only_process_constants=True) != 1:
if (
get_underlying_scalar_constant_value(
step, only_process_constants=True
)
!= 1
):
return None
except NotScalarConstantError:
return None
......
......@@ -397,7 +397,7 @@ class SpecifyShape(COp):
_f16_ok = True
def make_node(self, x, *shape):
from pytensor.tensor.basic import get_scalar_constant_value
from pytensor.tensor.basic import get_underlying_scalar_constant_value
x = at.as_tensor_variable(x)
......@@ -426,7 +426,7 @@ class SpecifyShape(COp):
type_shape[i] = xts
else:
try:
type_s = get_scalar_constant_value(s)
type_s = get_underlying_scalar_constant_value(s)
if type_s is not None:
type_shape[i] = int(type_s)
except NotScalarConstantError:
......@@ -457,9 +457,9 @@ class SpecifyShape(COp):
for dim in range(node.inputs[0].type.ndim):
s = shape[dim]
try:
s = at.get_scalar_constant_value(s)
s = at.get_underlying_scalar_constant_value(s)
# We assume that `None` shapes are always retrieved by
# `get_scalar_constant_value`, and only in that case do we default to
# `get_underlying_scalar_constant_value`, and only in that case do we default to
# the shape of the input variable
if s is None:
s = xshape[dim]
......@@ -581,7 +581,7 @@ def specify_shape(
@_get_vector_length.register(SpecifyShape)
def _get_vector_length_SpecifyShape(op, var):
try:
return at.get_scalar_constant_value(var.owner.inputs[1]).item()
return at.get_underlying_scalar_constant_value(var.owner.inputs[1]).item()
except NotScalarConstantError:
raise ValueError(f"Length of {var} cannot be determined")
......@@ -635,7 +635,7 @@ class Reshape(COp):
y = shp_list[index]
y = at.as_tensor_variable(y)
try:
s_val = at.get_scalar_constant_value(y).item()
s_val = at.get_underlying_scalar_constant_value(y).item()
if s_val >= 0:
out_shape[index] = s_val
except NotScalarConstantError:
......
......@@ -20,7 +20,7 @@ from pytensor.misc.safe_asarray import _asarray
from pytensor.printing import Printer, pprint, set_precedence
from pytensor.scalar.basic import ScalarConstant
from pytensor.tensor import _get_vector_length, as_tensor_variable, get_vector_length
from pytensor.tensor.basic import alloc, get_scalar_constant_value
from pytensor.tensor.basic import alloc, get_underlying_scalar_constant_value
from pytensor.tensor.elemwise import DimShuffle
from pytensor.tensor.exceptions import (
AdvancedIndexingError,
......@@ -656,7 +656,7 @@ def get_constant_idx(
return slice(conv(val.start), conv(val.stop), conv(val.step))
else:
try:
return get_scalar_constant_value(
return get_underlying_scalar_constant_value(
val,
only_process_constants=only_process_constants,
elemwise=elemwise,
......@@ -733,7 +733,7 @@ class Subtensor(COp):
if s == 1:
start = p.start
try:
start = get_scalar_constant_value(start)
start = get_underlying_scalar_constant_value(start)
except NotScalarConstantError:
pass
if start is None or start == 0:
......@@ -2808,17 +2808,17 @@ def _get_vector_length_Subtensor(op, var):
start = (
None
if indices[0].start is None
else get_scalar_constant_value(indices[0].start)
else get_underlying_scalar_constant_value(indices[0].start)
)
stop = (
None
if indices[0].stop is None
else get_scalar_constant_value(indices[0].stop)
else get_underlying_scalar_constant_value(indices[0].stop)
)
step = (
None
if indices[0].step is None
else get_scalar_constant_value(indices[0].step)
else get_underlying_scalar_constant_value(indices[0].step)
)
if start == stop:
......
......@@ -756,8 +756,8 @@ class _tensor_py_operators:
# This value is set so that PyTensor arrays will trump NumPy operators.
__array_priority__ = 1000
def get_scalar_constant_value(self):
return at.basic.get_scalar_constant_value(self)
def get_underlying_scalar_constant(self):
return at.basic.get_underlying_scalar_constant_value(self)
def zeros_like(model, dtype=None):
return at.basic.zeros_like(model, dtype=dtype)
......
......@@ -1043,7 +1043,7 @@ class TestConversion:
from pytensor.tensor.exceptions import NotScalarConstantError
with pytest.raises(NotScalarConstantError):
at.get_scalar_constant_value(s, only_process_constants=True)
at.get_underlying_scalar_constant_value(s, only_process_constants=True)
# TODO:
# def test_sparse_as_tensor_variable(self):
......
......@@ -52,6 +52,7 @@ from pytensor.tensor.basic import (
flatten,
full_like,
get_scalar_constant_value,
get_underlying_scalar_constant_value,
get_vector_length,
horizontal_stack,
identity_like,
......@@ -3263,52 +3264,52 @@ def test_dimshuffle_duplicate():
DimShuffle((False,), (0, 0))(x)
class TestGetScalarConstantValue:
class TestGetUnderlyingScalarConstantValue:
def test_basic(self):
with pytest.raises(NotScalarConstantError):
get_scalar_constant_value(aes.int64())
get_underlying_scalar_constant_value(aes.int64())
res = get_scalar_constant_value(at.as_tensor(10))
res = get_underlying_scalar_constant_value(at.as_tensor(10))
assert res == 10
assert isinstance(res, np.ndarray)
res = get_scalar_constant_value(np.array(10))
res = get_underlying_scalar_constant_value(np.array(10))
assert res == 10
assert isinstance(res, np.ndarray)
a = at.stack([1, 2, 3])
assert get_scalar_constant_value(a[0]) == 1
assert get_scalar_constant_value(a[1]) == 2
assert get_scalar_constant_value(a[2]) == 3
assert get_underlying_scalar_constant_value(a[0]) == 1
assert get_underlying_scalar_constant_value(a[1]) == 2
assert get_underlying_scalar_constant_value(a[2]) == 3
b = iscalar()
a = at.stack([b, 2, 3])
with pytest.raises(NotScalarConstantError):
get_scalar_constant_value(a[0])
assert get_scalar_constant_value(a[1]) == 2
assert get_scalar_constant_value(a[2]) == 3
get_underlying_scalar_constant_value(a[0])
assert get_underlying_scalar_constant_value(a[1]) == 2
assert get_underlying_scalar_constant_value(a[2]) == 3
# For now get_scalar_constant_value goes through only MakeVector and Join of
# For now get_underlying_scalar_constant_value goes through only MakeVector and Join of
# scalars.
v = ivector()
a = at.stack([v, [2], [3]])
with pytest.raises(NotScalarConstantError):
get_scalar_constant_value(a[0])
get_underlying_scalar_constant_value(a[0])
with pytest.raises(NotScalarConstantError):
get_scalar_constant_value(a[1])
get_underlying_scalar_constant_value(a[1])
with pytest.raises(NotScalarConstantError):
get_scalar_constant_value(a[2])
get_underlying_scalar_constant_value(a[2])
# Test the case SubTensor(Shape(v)) when the dimensions
# is broadcastable.
v = row()
assert get_scalar_constant_value(v.shape[0]) == 1
assert get_underlying_scalar_constant_value(v.shape[0]) == 1
res = at.get_scalar_constant_value(at.as_tensor([10, 20]).shape[0])
res = at.get_underlying_scalar_constant_value(at.as_tensor([10, 20]).shape[0])
assert isinstance(res, np.ndarray)
assert 2 == res
res = at.get_scalar_constant_value(
res = at.get_underlying_scalar_constant_value(
9 + at.as_tensor([1.0]).shape[0],
elemwise=True,
only_process_constants=False,
......@@ -3320,63 +3321,63 @@ class TestGetScalarConstantValue:
@pytest.mark.xfail(reason="Incomplete implementation")
def test_DimShufle(self):
a = as_tensor_variable(1.0)[None][0]
assert get_scalar_constant_value(a) == 1
assert get_underlying_scalar_constant_value(a) == 1
def test_subtensor_of_constant(self):
c = constant(random(5))
for i in range(c.value.shape[0]):
assert get_scalar_constant_value(c[i]) == c.value[i]
assert get_underlying_scalar_constant_value(c[i]) == c.value[i]
c = constant(random(5, 5))
for i in range(c.value.shape[0]):
for j in range(c.value.shape[1]):
assert get_scalar_constant_value(c[i, j]) == c.value[i, j]
assert get_underlying_scalar_constant_value(c[i, j]) == c.value[i, j]
def test_numpy_array(self):
# Regression test for crash when called on a numpy array.
assert get_scalar_constant_value(np.array(3)) == 3
assert get_underlying_scalar_constant_value(np.array(3)) == 3
with pytest.raises(NotScalarConstantError):
get_scalar_constant_value(np.array([0, 1]))
get_underlying_scalar_constant_value(np.array([0, 1]))
with pytest.raises(NotScalarConstantError):
get_scalar_constant_value(np.array([]))
get_underlying_scalar_constant_value(np.array([]))
def test_make_vector(self):
mv = make_vector(1, 2, 3)
with pytest.raises(NotScalarConstantError):
get_scalar_constant_value(mv)
assert get_scalar_constant_value(mv[0]) == 1
assert get_scalar_constant_value(mv[1]) == 2
assert get_scalar_constant_value(mv[2]) == 3
assert get_scalar_constant_value(mv[np.int32(0)]) == 1
assert get_scalar_constant_value(mv[np.int64(1)]) == 2
assert get_scalar_constant_value(mv[np.uint(2)]) == 3
get_underlying_scalar_constant_value(mv)
assert get_underlying_scalar_constant_value(mv[0]) == 1
assert get_underlying_scalar_constant_value(mv[1]) == 2
assert get_underlying_scalar_constant_value(mv[2]) == 3
assert get_underlying_scalar_constant_value(mv[np.int32(0)]) == 1
assert get_underlying_scalar_constant_value(mv[np.int64(1)]) == 2
assert get_underlying_scalar_constant_value(mv[np.uint(2)]) == 3
t = aes.ScalarType("int64")
with pytest.raises(NotScalarConstantError):
get_scalar_constant_value(mv[t()])
get_underlying_scalar_constant_value(mv[t()])
def test_shape_i(self):
c = constant(np.random.random((3, 4)))
s = Shape_i(0)(c)
assert get_scalar_constant_value(s) == 3
assert get_underlying_scalar_constant_value(s) == 3
s = Shape_i(1)(c)
assert get_scalar_constant_value(s) == 4
assert get_underlying_scalar_constant_value(s) == 4
d = pytensor.shared(np.random.standard_normal((1, 1)), shape=(1, 1))
f = ScalarFromTensor()(Shape_i(0)(d))
assert get_scalar_constant_value(f) == 1
assert get_underlying_scalar_constant_value(f) == 1
def test_elemwise(self):
# We test only for a few elemwise, the list of all supported
# elemwise are in the fct.
c = constant(np.random.random())
s = c + 1
assert np.allclose(get_scalar_constant_value(s), c.data + 1)
assert np.allclose(get_underlying_scalar_constant_value(s), c.data + 1)
s = c - 1
assert np.allclose(get_scalar_constant_value(s), c.data - 1)
assert np.allclose(get_underlying_scalar_constant_value(s), c.data - 1)
s = c * 1.2
assert np.allclose(get_scalar_constant_value(s), c.data * 1.2)
assert np.allclose(get_underlying_scalar_constant_value(s), c.data * 1.2)
s = c < 0.5
assert np.allclose(get_scalar_constant_value(s), int(c.data < 0.5))
assert np.allclose(get_underlying_scalar_constant_value(s), int(c.data < 0.5))
s = at.second(c, 0.4)
assert np.allclose(get_scalar_constant_value(s), 0.4)
assert np.allclose(get_underlying_scalar_constant_value(s), 0.4)
def test_assert(self):
# Make sure we still get the constant value if it is wrapped in
......@@ -3386,25 +3387,25 @@ class TestGetScalarConstantValue:
# condition is always True
a = Assert()(c, c > 1)
assert get_scalar_constant_value(a) == 2
assert get_underlying_scalar_constant_value(a) == 2
with config.change_flags(compute_test_value="off"):
# condition is always False
a = Assert()(c, c > 2)
with pytest.raises(NotScalarConstantError):
get_scalar_constant_value(a)
get_underlying_scalar_constant_value(a)
# condition is not constant
a = Assert()(c, c > x)
with pytest.raises(NotScalarConstantError):
get_scalar_constant_value(a)
get_underlying_scalar_constant_value(a)
def test_second(self):
# Second should apply when the value is constant but not the shape
c = constant(np.random.random())
shp = vector()
s = at.second(shp, c)
assert get_scalar_constant_value(s) == c.data
assert get_underlying_scalar_constant_value(s) == c.data
def test_copy(self):
# Make sure we do not return the internal storage of a constant,
......@@ -3418,17 +3419,27 @@ class TestGetScalarConstantValue:
@pytest.mark.parametrize("only_process_constants", (True, False))
def test_None_and_NoneConst(self, only_process_constants):
with pytest.raises(NotScalarConstantError):
get_scalar_constant_value(
get_underlying_scalar_constant_value(
None, only_process_constants=only_process_constants
)
assert (
get_scalar_constant_value(
get_underlying_scalar_constant_value(
NoneConst, only_process_constants=only_process_constants
)
is None
)
@pytest.mark.parametrize(
["valid_inp", "invalid_inp"],
((np.array(4), np.zeros(5)), (at.constant(4), at.constant(3, ndim=1))),
)
def test_get_scalar_constant_value(valid_inp, invalid_inp):
with pytest.raises(NotScalarConstantError):
get_scalar_constant_value(invalid_inp)
assert get_scalar_constant_value(valid_inp) == 4
def test_complex_mod_failure():
# Make sure % fails on complex numbers.
x = vector(dtype="complex64")
......
......@@ -823,8 +823,8 @@ class TestElemwise(unittest_tools.InferShapeTester):
assert len(res_shape) == 1
assert len(res_shape[0]) == 2
assert pytensor.get_scalar_constant_value(res_shape[0][0]) == 1
assert pytensor.get_scalar_constant_value(res_shape[0][1]) == 1
assert pytensor.get_underlying_scalar_constant(res_shape[0][0]) == 1
assert pytensor.get_underlying_scalar_constant(res_shape[0][1]) == 1
def test_multi_output(self):
class CustomElemwise(Elemwise):
......
......@@ -27,7 +27,7 @@ from pytensor.tensor.basic import (
as_tensor_variable,
constant,
eye,
get_scalar_constant_value,
get_underlying_scalar_constant_value,
switch,
)
from pytensor.tensor.elemwise import CAReduce, Elemwise
......@@ -894,7 +894,7 @@ class TestMaxAndArgmax:
x = matrix()
cost = argmax(x, axis=0).sum()
gx = grad(cost, x)
val = get_scalar_constant_value(gx)
val = get_underlying_scalar_constant_value(gx)
assert val == 0.0
def test_grad(self):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论