Commit 55f3cd0c — authored by ricardoV94, committed by Ricardo Vieira

Use more strict `get_scalar_constant_value` when the input must be a scalar

Parent commit: 32aadc8c
......@@ -18,7 +18,7 @@ from pytensor.tensor.basic import (
Split,
TensorFromScalar,
Tri,
get_underlying_scalar_constant_value,
get_scalar_constant_value,
)
from pytensor.tensor.exceptions import NotScalarConstantError
from pytensor.tensor.shape import Shape_i
......@@ -103,7 +103,7 @@ def jax_funcify_Join(op, **kwargs):
def jax_funcify_Split(op: Split, node, **kwargs):
_, axis, splits = node.inputs
try:
constant_axis = get_underlying_scalar_constant_value(axis)
constant_axis = get_scalar_constant_value(axis)
except NotScalarConstantError:
constant_axis = None
warnings.warn(
......@@ -113,7 +113,7 @@ def jax_funcify_Split(op: Split, node, **kwargs):
try:
constant_splits = np.array(
[
get_underlying_scalar_constant_value(splits[i])
get_scalar_constant_value(splits[i])
for i in range(get_vector_length(splits))
]
)
......
......@@ -484,7 +484,7 @@ def scan(
n_fixed_steps = int(n_steps)
else:
try:
n_fixed_steps = pt.get_underlying_scalar_constant_value(n_steps)
n_fixed_steps = pt.get_scalar_constant_value(n_steps)
except NotScalarConstantError:
n_fixed_steps = None
......
......@@ -55,7 +55,6 @@ from pytensor.tensor.basic import (
Alloc,
AllocEmpty,
get_scalar_constant_value,
get_underlying_scalar_constant_value,
)
from pytensor.tensor.elemwise import DimShuffle, Elemwise
from pytensor.tensor.exceptions import NotScalarConstantError
......@@ -1976,13 +1975,13 @@ class ScanMerge(GraphRewriter):
nsteps = node.inputs[0]
try:
nsteps = int(get_underlying_scalar_constant_value(nsteps))
nsteps = int(get_scalar_constant_value(nsteps))
except NotScalarConstantError:
pass
rep_nsteps = rep_node.inputs[0]
try:
rep_nsteps = int(get_underlying_scalar_constant_value(rep_nsteps))
rep_nsteps = int(get_scalar_constant_value(rep_nsteps))
except NotScalarConstantError:
pass
......
......@@ -1808,7 +1808,7 @@ pprint.assign(alloc, printing.FunctionPrinter(["alloc"]))
@_get_vector_length.register(Alloc)
def _get_vector_length_Alloc(var_inst, var):
try:
return get_underlying_scalar_constant_value(var.owner.inputs[1])
return get_scalar_constant_value(var.owner.inputs[1])
except NotScalarConstantError:
raise ValueError(f"Length of {var} cannot be determined")
......@@ -2509,7 +2509,7 @@ class Join(COp):
if not isinstance(axis, int):
try:
axis = int(get_underlying_scalar_constant_value(axis))
axis = int(get_scalar_constant_value(axis))
except NotScalarConstantError:
pass
......@@ -2753,7 +2753,7 @@ pprint.assign(Join, printing.FunctionPrinter(["join"]))
def _get_vector_length_Join(op, var):
axis, *arrays = var.owner.inputs
try:
axis = get_underlying_scalar_constant_value(axis)
axis = get_scalar_constant_value(axis)
assert axis == 0 and builtins.all(a.ndim == 1 for a in arrays)
return builtins.sum(get_vector_length(a) for a in arrays)
except NotScalarConstantError:
......@@ -4146,7 +4146,7 @@ class Choose(Op):
static_out_shape = ()
for s in out_shape:
try:
s_val = get_underlying_scalar_constant_value(s)
s_val = get_scalar_constant_value(s)
except (NotScalarConstantError, AttributeError):
s_val = None
......
......@@ -25,7 +25,7 @@ from pytensor.graph.op import Op
from pytensor.raise_op import Assert
from pytensor.tensor.basic import (
as_tensor_variable,
get_underlying_scalar_constant_value,
get_scalar_constant_value,
)
from pytensor.tensor.exceptions import NotScalarConstantError
from pytensor.tensor.variable import TensorConstant, TensorVariable
......@@ -497,8 +497,8 @@ def check_conv_gradinputs_shape(
if given is None or computed is None:
return True
try:
given = get_underlying_scalar_constant_value(given)
computed = get_underlying_scalar_constant_value(computed)
given = get_scalar_constant_value(given)
computed = get_scalar_constant_value(computed)
return int(given) == int(computed)
except NotScalarConstantError:
# no answer possible, accept for now
......@@ -534,7 +534,7 @@ def assert_conv_shape(shape):
out_shape = []
for i, n in enumerate(shape):
try:
const_n = get_underlying_scalar_constant_value(n)
const_n = get_scalar_constant_value(n)
if i < 2:
if const_n < 0:
raise ValueError(
......@@ -2203,9 +2203,7 @@ class BaseAbstractConv(Op):
if imshp_i is not None:
# Components of imshp should be constant or ints
try:
get_underlying_scalar_constant_value(
imshp_i, only_process_constants=True
)
get_scalar_constant_value(imshp_i, only_process_constants=True)
except NotScalarConstantError:
raise ValueError(
"imshp should be None or a tuple of constant int values"
......@@ -2218,9 +2216,7 @@ class BaseAbstractConv(Op):
if kshp_i is not None:
# Components of kshp should be constant or ints
try:
get_underlying_scalar_constant_value(
kshp_i, only_process_constants=True
)
get_scalar_constant_value(kshp_i, only_process_constants=True)
except NotScalarConstantError:
raise ValueError(
"kshp should be None or a tuple of constant int values"
......
......@@ -678,7 +678,7 @@ class Repeat(Op):
out_shape = [None]
else:
try:
const_reps = ptb.get_underlying_scalar_constant_value(repeats)
const_reps = ptb.get_scalar_constant_value(repeats)
except NotScalarConstantError:
const_reps = None
if const_reps == 1:
......
......@@ -57,7 +57,6 @@ from pytensor.tensor.basic import (
cast,
fill,
get_scalar_constant_value,
get_underlying_scalar_constant_value,
join,
ones_like,
register_infer_shape,
......@@ -739,7 +738,7 @@ def local_remove_useless_assert(fgraph, node):
n_conds = len(node.inputs[1:])
for c in node.inputs[1:]:
try:
const = get_underlying_scalar_constant_value(c)
const = get_scalar_constant_value(c)
if 0 != const.ndim or const == 0:
# Should we raise an error here? How to be sure it
......@@ -834,7 +833,7 @@ def local_join_empty(fgraph, node):
return
new_inputs = []
try:
join_idx = get_underlying_scalar_constant_value(
join_idx = get_scalar_constant_value(
node.inputs[0], only_process_constants=True
)
except NotScalarConstantError:
......
......@@ -153,18 +153,16 @@ def local_0_dot_x(fgraph, node):
x = node.inputs[0]
y = node.inputs[1]
replace = False
try:
if get_underlying_scalar_constant_value(x, only_process_constants=True) == 0:
replace = True
except NotScalarConstantError:
pass
try:
if get_underlying_scalar_constant_value(y, only_process_constants=True) == 0:
replace = True
except NotScalarConstantError:
pass
replace = (
get_underlying_scalar_constant_value(
x, only_process_constants=True, raise_not_constant=False
)
== 0
or get_underlying_scalar_constant_value(
y, only_process_constants=True, raise_not_constant=False
)
== 0
)
if replace:
constant_zero = constant(0, dtype=node.outputs[0].type.dtype)
......@@ -2111,7 +2109,7 @@ def local_add_remove_zeros(fgraph, node):
y = get_underlying_scalar_constant_value(inp)
except NotScalarConstantError:
y = inp
if np.all(y == 0.0):
if y == 0.0:
continue
new_inputs.append(inp)
......@@ -2209,7 +2207,7 @@ def local_abs_merge(fgraph, node):
)
except NotScalarConstantError:
return False
if not (const >= 0).all():
if not const >= 0:
return False
inputs.append(i)
else:
......@@ -2861,7 +2859,7 @@ def _is_1(expr):
"""
try:
v = get_underlying_scalar_constant_value(expr)
return np.allclose(v, 1)
return np.isclose(v, 1)
except NotScalarConstantError:
return False
......@@ -3029,7 +3027,7 @@ def is_neg(var):
for idx, mul_input in enumerate(var_node.inputs):
try:
constant = get_underlying_scalar_constant_value(mul_input)
is_minus_1 = np.allclose(constant, -1)
is_minus_1 = np.isclose(constant, -1)
except NotScalarConstantError:
is_minus_1 = False
if is_minus_1:
......
......@@ -23,7 +23,6 @@ from pytensor.tensor.basic import (
cast,
constant,
get_scalar_constant_value,
get_underlying_scalar_constant_value,
register_infer_shape,
stack,
)
......@@ -213,7 +212,7 @@ class ShapeFeature(Feature):
# Do not call make_node for test_value
s = Shape_i(i)(r)
try:
s = get_underlying_scalar_constant_value(s)
s = get_scalar_constant_value(s)
except NotScalarConstantError:
pass
return s
......@@ -297,7 +296,7 @@ class ShapeFeature(Feature):
assert len(idx) == 1
idx = idx[0]
try:
i = get_underlying_scalar_constant_value(idx)
i = get_scalar_constant_value(idx)
except NotScalarConstantError:
pass
else:
......@@ -452,7 +451,7 @@ class ShapeFeature(Feature):
)
or self.lscalar_one.equals(merged_shape[i])
or self.lscalar_one.equals(
get_underlying_scalar_constant_value(
get_scalar_constant_value(
merged_shape[i],
only_process_constants=True,
raise_not_constant=False,
......@@ -481,9 +480,7 @@ class ShapeFeature(Feature):
or r.type.shape[idx] != 1
or self.lscalar_one.equals(new_shape[idx])
or self.lscalar_one.equals(
get_underlying_scalar_constant_value(
new_shape[idx], raise_not_constant=False
)
get_scalar_constant_value(new_shape[idx], raise_not_constant=False)
)
for idx in range(r.type.ndim)
)
......
......@@ -999,7 +999,7 @@ def local_useless_subtensor(fgraph, node):
if isinstance(idx.stop, int | np.integer):
length_pos_data = sys.maxsize
try:
length_pos_data = get_underlying_scalar_constant_value(
length_pos_data = get_scalar_constant_value(
length_pos, only_process_constants=True
)
except NotScalarConstantError:
......@@ -1064,7 +1064,7 @@ def local_useless_AdvancedSubtensor1(fgraph, node):
# get length of the indexed tensor along the first axis
try:
length = get_underlying_scalar_constant_value(
length = get_scalar_constant_value(
shape_of[node.inputs[0]][0], only_process_constants=True
)
except NotScalarConstantError:
......@@ -1736,7 +1736,7 @@ def local_join_subtensors(fgraph, node):
axis, tensors = node.inputs[0], node.inputs[1:]
try:
axis = get_underlying_scalar_constant_value(axis)
axis = get_scalar_constant_value(axis)
except NotScalarConstantError:
return
......@@ -1797,12 +1797,7 @@ def local_join_subtensors(fgraph, node):
if step is None:
continue
try:
if (
get_underlying_scalar_constant_value(
step, only_process_constants=True
)
!= 1
):
if get_scalar_constant_value(step, only_process_constants=True) != 1:
return None
except NotScalarConstantError:
return None
......
......@@ -428,7 +428,7 @@ class SpecifyShape(COp):
type_shape[i] = xts
elif not isinstance(s.type, NoneTypeT):
try:
type_shape[i] = int(ptb.get_underlying_scalar_constant_value(s))
type_shape[i] = int(ptb.get_scalar_constant_value(s))
except NotScalarConstantError:
pass
......@@ -580,7 +580,7 @@ def specify_shape(
@_get_vector_length.register(SpecifyShape) # type: ignore
def _get_vector_length_SpecifyShape(op: Op, var: TensorVariable) -> int:
try:
return int(ptb.get_underlying_scalar_constant_value(var.owner.inputs[1]).item())
return int(ptb.get_scalar_constant_value(var.owner.inputs[1]).item())
except NotScalarConstantError:
raise ValueError(f"Length of {var} cannot be determined")
......@@ -661,7 +661,7 @@ class Reshape(COp):
y = shp_list[index]
y = ptb.as_tensor_variable(y)
try:
s_val = ptb.get_underlying_scalar_constant_value(y).item()
s_val = ptb.get_scalar_constant_value(y).item()
if s_val >= 0:
out_shape[index] = s_val
except NotScalarConstantError:
......
......@@ -29,7 +29,7 @@ from pytensor.tensor import (
from pytensor.tensor.basic import (
ScalarFromTensor,
alloc,
get_underlying_scalar_constant_value,
get_scalar_constant_value,
nonzero,
scalar_from_tensor,
)
......@@ -778,7 +778,7 @@ def get_constant_idx(
return slice(conv(val.start), conv(val.stop), conv(val.step))
else:
try:
return get_underlying_scalar_constant_value(
return get_scalar_constant_value(
val,
only_process_constants=only_process_constants,
elemwise=elemwise,
......@@ -855,7 +855,7 @@ class Subtensor(COp):
if value is None:
return value, True
try:
value = get_underlying_scalar_constant_value(value)
value = get_scalar_constant_value(value)
return value, True
except NotScalarConstantError:
return value, False
......@@ -3022,17 +3022,17 @@ def _get_vector_length_Subtensor(op, var):
start = (
None
if indices[0].start is None
else get_underlying_scalar_constant_value(indices[0].start)
else get_scalar_constant_value(indices[0].start)
)
stop = (
None
if indices[0].stop is None
else get_underlying_scalar_constant_value(indices[0].stop)
else get_scalar_constant_value(indices[0].stop)
)
step = (
None
if indices[0].step is None
else get_underlying_scalar_constant_value(indices[0].step)
else get_scalar_constant_value(indices[0].step)
)
if start == stop:
......
Markdown formatting is supported.
0%
You are adding 0 people to this discussion. Please proceed with caution.
Please finish editing this comment first!
Sign in or register to post a comment.