提交 55f3cd0c authored 作者: ricardoV94's avatar ricardoV94 提交者: Ricardo Vieira

Use more strict `get_scalar_constant_value` when the input must be a scalar

上级 32aadc8c
...@@ -18,7 +18,7 @@ from pytensor.tensor.basic import ( ...@@ -18,7 +18,7 @@ from pytensor.tensor.basic import (
Split, Split,
TensorFromScalar, TensorFromScalar,
Tri, Tri,
get_underlying_scalar_constant_value, get_scalar_constant_value,
) )
from pytensor.tensor.exceptions import NotScalarConstantError from pytensor.tensor.exceptions import NotScalarConstantError
from pytensor.tensor.shape import Shape_i from pytensor.tensor.shape import Shape_i
...@@ -103,7 +103,7 @@ def jax_funcify_Join(op, **kwargs): ...@@ -103,7 +103,7 @@ def jax_funcify_Join(op, **kwargs):
def jax_funcify_Split(op: Split, node, **kwargs): def jax_funcify_Split(op: Split, node, **kwargs):
_, axis, splits = node.inputs _, axis, splits = node.inputs
try: try:
constant_axis = get_underlying_scalar_constant_value(axis) constant_axis = get_scalar_constant_value(axis)
except NotScalarConstantError: except NotScalarConstantError:
constant_axis = None constant_axis = None
warnings.warn( warnings.warn(
...@@ -113,7 +113,7 @@ def jax_funcify_Split(op: Split, node, **kwargs): ...@@ -113,7 +113,7 @@ def jax_funcify_Split(op: Split, node, **kwargs):
try: try:
constant_splits = np.array( constant_splits = np.array(
[ [
get_underlying_scalar_constant_value(splits[i]) get_scalar_constant_value(splits[i])
for i in range(get_vector_length(splits)) for i in range(get_vector_length(splits))
] ]
) )
......
...@@ -484,7 +484,7 @@ def scan( ...@@ -484,7 +484,7 @@ def scan(
n_fixed_steps = int(n_steps) n_fixed_steps = int(n_steps)
else: else:
try: try:
n_fixed_steps = pt.get_underlying_scalar_constant_value(n_steps) n_fixed_steps = pt.get_scalar_constant_value(n_steps)
except NotScalarConstantError: except NotScalarConstantError:
n_fixed_steps = None n_fixed_steps = None
......
...@@ -55,7 +55,6 @@ from pytensor.tensor.basic import ( ...@@ -55,7 +55,6 @@ from pytensor.tensor.basic import (
Alloc, Alloc,
AllocEmpty, AllocEmpty,
get_scalar_constant_value, get_scalar_constant_value,
get_underlying_scalar_constant_value,
) )
from pytensor.tensor.elemwise import DimShuffle, Elemwise from pytensor.tensor.elemwise import DimShuffle, Elemwise
from pytensor.tensor.exceptions import NotScalarConstantError from pytensor.tensor.exceptions import NotScalarConstantError
...@@ -1976,13 +1975,13 @@ class ScanMerge(GraphRewriter): ...@@ -1976,13 +1975,13 @@ class ScanMerge(GraphRewriter):
nsteps = node.inputs[0] nsteps = node.inputs[0]
try: try:
nsteps = int(get_underlying_scalar_constant_value(nsteps)) nsteps = int(get_scalar_constant_value(nsteps))
except NotScalarConstantError: except NotScalarConstantError:
pass pass
rep_nsteps = rep_node.inputs[0] rep_nsteps = rep_node.inputs[0]
try: try:
rep_nsteps = int(get_underlying_scalar_constant_value(rep_nsteps)) rep_nsteps = int(get_scalar_constant_value(rep_nsteps))
except NotScalarConstantError: except NotScalarConstantError:
pass pass
......
...@@ -1808,7 +1808,7 @@ pprint.assign(alloc, printing.FunctionPrinter(["alloc"])) ...@@ -1808,7 +1808,7 @@ pprint.assign(alloc, printing.FunctionPrinter(["alloc"]))
@_get_vector_length.register(Alloc) @_get_vector_length.register(Alloc)
def _get_vector_length_Alloc(var_inst, var): def _get_vector_length_Alloc(var_inst, var):
try: try:
return get_underlying_scalar_constant_value(var.owner.inputs[1]) return get_scalar_constant_value(var.owner.inputs[1])
except NotScalarConstantError: except NotScalarConstantError:
raise ValueError(f"Length of {var} cannot be determined") raise ValueError(f"Length of {var} cannot be determined")
...@@ -2509,7 +2509,7 @@ class Join(COp): ...@@ -2509,7 +2509,7 @@ class Join(COp):
if not isinstance(axis, int): if not isinstance(axis, int):
try: try:
axis = int(get_underlying_scalar_constant_value(axis)) axis = int(get_scalar_constant_value(axis))
except NotScalarConstantError: except NotScalarConstantError:
pass pass
...@@ -2753,7 +2753,7 @@ pprint.assign(Join, printing.FunctionPrinter(["join"])) ...@@ -2753,7 +2753,7 @@ pprint.assign(Join, printing.FunctionPrinter(["join"]))
def _get_vector_length_Join(op, var): def _get_vector_length_Join(op, var):
axis, *arrays = var.owner.inputs axis, *arrays = var.owner.inputs
try: try:
axis = get_underlying_scalar_constant_value(axis) axis = get_scalar_constant_value(axis)
assert axis == 0 and builtins.all(a.ndim == 1 for a in arrays) assert axis == 0 and builtins.all(a.ndim == 1 for a in arrays)
return builtins.sum(get_vector_length(a) for a in arrays) return builtins.sum(get_vector_length(a) for a in arrays)
except NotScalarConstantError: except NotScalarConstantError:
...@@ -4146,7 +4146,7 @@ class Choose(Op): ...@@ -4146,7 +4146,7 @@ class Choose(Op):
static_out_shape = () static_out_shape = ()
for s in out_shape: for s in out_shape:
try: try:
s_val = get_underlying_scalar_constant_value(s) s_val = get_scalar_constant_value(s)
except (NotScalarConstantError, AttributeError): except (NotScalarConstantError, AttributeError):
s_val = None s_val = None
......
...@@ -25,7 +25,7 @@ from pytensor.graph.op import Op ...@@ -25,7 +25,7 @@ from pytensor.graph.op import Op
from pytensor.raise_op import Assert from pytensor.raise_op import Assert
from pytensor.tensor.basic import ( from pytensor.tensor.basic import (
as_tensor_variable, as_tensor_variable,
get_underlying_scalar_constant_value, get_scalar_constant_value,
) )
from pytensor.tensor.exceptions import NotScalarConstantError from pytensor.tensor.exceptions import NotScalarConstantError
from pytensor.tensor.variable import TensorConstant, TensorVariable from pytensor.tensor.variable import TensorConstant, TensorVariable
...@@ -497,8 +497,8 @@ def check_conv_gradinputs_shape( ...@@ -497,8 +497,8 @@ def check_conv_gradinputs_shape(
if given is None or computed is None: if given is None or computed is None:
return True return True
try: try:
given = get_underlying_scalar_constant_value(given) given = get_scalar_constant_value(given)
computed = get_underlying_scalar_constant_value(computed) computed = get_scalar_constant_value(computed)
return int(given) == int(computed) return int(given) == int(computed)
except NotScalarConstantError: except NotScalarConstantError:
# no answer possible, accept for now # no answer possible, accept for now
...@@ -534,7 +534,7 @@ def assert_conv_shape(shape): ...@@ -534,7 +534,7 @@ def assert_conv_shape(shape):
out_shape = [] out_shape = []
for i, n in enumerate(shape): for i, n in enumerate(shape):
try: try:
const_n = get_underlying_scalar_constant_value(n) const_n = get_scalar_constant_value(n)
if i < 2: if i < 2:
if const_n < 0: if const_n < 0:
raise ValueError( raise ValueError(
...@@ -2203,9 +2203,7 @@ class BaseAbstractConv(Op): ...@@ -2203,9 +2203,7 @@ class BaseAbstractConv(Op):
if imshp_i is not None: if imshp_i is not None:
# Components of imshp should be constant or ints # Components of imshp should be constant or ints
try: try:
get_underlying_scalar_constant_value( get_scalar_constant_value(imshp_i, only_process_constants=True)
imshp_i, only_process_constants=True
)
except NotScalarConstantError: except NotScalarConstantError:
raise ValueError( raise ValueError(
"imshp should be None or a tuple of constant int values" "imshp should be None or a tuple of constant int values"
...@@ -2218,9 +2216,7 @@ class BaseAbstractConv(Op): ...@@ -2218,9 +2216,7 @@ class BaseAbstractConv(Op):
if kshp_i is not None: if kshp_i is not None:
# Components of kshp should be constant or ints # Components of kshp should be constant or ints
try: try:
get_underlying_scalar_constant_value( get_scalar_constant_value(kshp_i, only_process_constants=True)
kshp_i, only_process_constants=True
)
except NotScalarConstantError: except NotScalarConstantError:
raise ValueError( raise ValueError(
"kshp should be None or a tuple of constant int values" "kshp should be None or a tuple of constant int values"
......
...@@ -678,7 +678,7 @@ class Repeat(Op): ...@@ -678,7 +678,7 @@ class Repeat(Op):
out_shape = [None] out_shape = [None]
else: else:
try: try:
const_reps = ptb.get_underlying_scalar_constant_value(repeats) const_reps = ptb.get_scalar_constant_value(repeats)
except NotScalarConstantError: except NotScalarConstantError:
const_reps = None const_reps = None
if const_reps == 1: if const_reps == 1:
......
...@@ -57,7 +57,6 @@ from pytensor.tensor.basic import ( ...@@ -57,7 +57,6 @@ from pytensor.tensor.basic import (
cast, cast,
fill, fill,
get_scalar_constant_value, get_scalar_constant_value,
get_underlying_scalar_constant_value,
join, join,
ones_like, ones_like,
register_infer_shape, register_infer_shape,
...@@ -739,7 +738,7 @@ def local_remove_useless_assert(fgraph, node): ...@@ -739,7 +738,7 @@ def local_remove_useless_assert(fgraph, node):
n_conds = len(node.inputs[1:]) n_conds = len(node.inputs[1:])
for c in node.inputs[1:]: for c in node.inputs[1:]:
try: try:
const = get_underlying_scalar_constant_value(c) const = get_scalar_constant_value(c)
if 0 != const.ndim or const == 0: if 0 != const.ndim or const == 0:
# Should we raise an error here? How to be sure it # Should we raise an error here? How to be sure it
...@@ -834,7 +833,7 @@ def local_join_empty(fgraph, node): ...@@ -834,7 +833,7 @@ def local_join_empty(fgraph, node):
return return
new_inputs = [] new_inputs = []
try: try:
join_idx = get_underlying_scalar_constant_value( join_idx = get_scalar_constant_value(
node.inputs[0], only_process_constants=True node.inputs[0], only_process_constants=True
) )
except NotScalarConstantError: except NotScalarConstantError:
......
...@@ -153,18 +153,16 @@ def local_0_dot_x(fgraph, node): ...@@ -153,18 +153,16 @@ def local_0_dot_x(fgraph, node):
x = node.inputs[0] x = node.inputs[0]
y = node.inputs[1] y = node.inputs[1]
replace = False replace = (
try: get_underlying_scalar_constant_value(
if get_underlying_scalar_constant_value(x, only_process_constants=True) == 0: x, only_process_constants=True, raise_not_constant=False
replace = True )
except NotScalarConstantError: == 0
pass or get_underlying_scalar_constant_value(
y, only_process_constants=True, raise_not_constant=False
try: )
if get_underlying_scalar_constant_value(y, only_process_constants=True) == 0: == 0
replace = True )
except NotScalarConstantError:
pass
if replace: if replace:
constant_zero = constant(0, dtype=node.outputs[0].type.dtype) constant_zero = constant(0, dtype=node.outputs[0].type.dtype)
...@@ -2111,7 +2109,7 @@ def local_add_remove_zeros(fgraph, node): ...@@ -2111,7 +2109,7 @@ def local_add_remove_zeros(fgraph, node):
y = get_underlying_scalar_constant_value(inp) y = get_underlying_scalar_constant_value(inp)
except NotScalarConstantError: except NotScalarConstantError:
y = inp y = inp
if np.all(y == 0.0): if y == 0.0:
continue continue
new_inputs.append(inp) new_inputs.append(inp)
...@@ -2209,7 +2207,7 @@ def local_abs_merge(fgraph, node): ...@@ -2209,7 +2207,7 @@ def local_abs_merge(fgraph, node):
) )
except NotScalarConstantError: except NotScalarConstantError:
return False return False
if not (const >= 0).all(): if not const >= 0:
return False return False
inputs.append(i) inputs.append(i)
else: else:
...@@ -2861,7 +2859,7 @@ def _is_1(expr): ...@@ -2861,7 +2859,7 @@ def _is_1(expr):
""" """
try: try:
v = get_underlying_scalar_constant_value(expr) v = get_underlying_scalar_constant_value(expr)
return np.allclose(v, 1) return np.isclose(v, 1)
except NotScalarConstantError: except NotScalarConstantError:
return False return False
...@@ -3029,7 +3027,7 @@ def is_neg(var): ...@@ -3029,7 +3027,7 @@ def is_neg(var):
for idx, mul_input in enumerate(var_node.inputs): for idx, mul_input in enumerate(var_node.inputs):
try: try:
constant = get_underlying_scalar_constant_value(mul_input) constant = get_underlying_scalar_constant_value(mul_input)
is_minus_1 = np.allclose(constant, -1) is_minus_1 = np.isclose(constant, -1)
except NotScalarConstantError: except NotScalarConstantError:
is_minus_1 = False is_minus_1 = False
if is_minus_1: if is_minus_1:
......
...@@ -23,7 +23,6 @@ from pytensor.tensor.basic import ( ...@@ -23,7 +23,6 @@ from pytensor.tensor.basic import (
cast, cast,
constant, constant,
get_scalar_constant_value, get_scalar_constant_value,
get_underlying_scalar_constant_value,
register_infer_shape, register_infer_shape,
stack, stack,
) )
...@@ -213,7 +212,7 @@ class ShapeFeature(Feature): ...@@ -213,7 +212,7 @@ class ShapeFeature(Feature):
# Do not call make_node for test_value # Do not call make_node for test_value
s = Shape_i(i)(r) s = Shape_i(i)(r)
try: try:
s = get_underlying_scalar_constant_value(s) s = get_scalar_constant_value(s)
except NotScalarConstantError: except NotScalarConstantError:
pass pass
return s return s
...@@ -297,7 +296,7 @@ class ShapeFeature(Feature): ...@@ -297,7 +296,7 @@ class ShapeFeature(Feature):
assert len(idx) == 1 assert len(idx) == 1
idx = idx[0] idx = idx[0]
try: try:
i = get_underlying_scalar_constant_value(idx) i = get_scalar_constant_value(idx)
except NotScalarConstantError: except NotScalarConstantError:
pass pass
else: else:
...@@ -452,7 +451,7 @@ class ShapeFeature(Feature): ...@@ -452,7 +451,7 @@ class ShapeFeature(Feature):
) )
or self.lscalar_one.equals(merged_shape[i]) or self.lscalar_one.equals(merged_shape[i])
or self.lscalar_one.equals( or self.lscalar_one.equals(
get_underlying_scalar_constant_value( get_scalar_constant_value(
merged_shape[i], merged_shape[i],
only_process_constants=True, only_process_constants=True,
raise_not_constant=False, raise_not_constant=False,
...@@ -481,9 +480,7 @@ class ShapeFeature(Feature): ...@@ -481,9 +480,7 @@ class ShapeFeature(Feature):
or r.type.shape[idx] != 1 or r.type.shape[idx] != 1
or self.lscalar_one.equals(new_shape[idx]) or self.lscalar_one.equals(new_shape[idx])
or self.lscalar_one.equals( or self.lscalar_one.equals(
get_underlying_scalar_constant_value( get_scalar_constant_value(new_shape[idx], raise_not_constant=False)
new_shape[idx], raise_not_constant=False
)
) )
for idx in range(r.type.ndim) for idx in range(r.type.ndim)
) )
......
...@@ -999,7 +999,7 @@ def local_useless_subtensor(fgraph, node): ...@@ -999,7 +999,7 @@ def local_useless_subtensor(fgraph, node):
if isinstance(idx.stop, int | np.integer): if isinstance(idx.stop, int | np.integer):
length_pos_data = sys.maxsize length_pos_data = sys.maxsize
try: try:
length_pos_data = get_underlying_scalar_constant_value( length_pos_data = get_scalar_constant_value(
length_pos, only_process_constants=True length_pos, only_process_constants=True
) )
except NotScalarConstantError: except NotScalarConstantError:
...@@ -1064,7 +1064,7 @@ def local_useless_AdvancedSubtensor1(fgraph, node): ...@@ -1064,7 +1064,7 @@ def local_useless_AdvancedSubtensor1(fgraph, node):
# get length of the indexed tensor along the first axis # get length of the indexed tensor along the first axis
try: try:
length = get_underlying_scalar_constant_value( length = get_scalar_constant_value(
shape_of[node.inputs[0]][0], only_process_constants=True shape_of[node.inputs[0]][0], only_process_constants=True
) )
except NotScalarConstantError: except NotScalarConstantError:
...@@ -1736,7 +1736,7 @@ def local_join_subtensors(fgraph, node): ...@@ -1736,7 +1736,7 @@ def local_join_subtensors(fgraph, node):
axis, tensors = node.inputs[0], node.inputs[1:] axis, tensors = node.inputs[0], node.inputs[1:]
try: try:
axis = get_underlying_scalar_constant_value(axis) axis = get_scalar_constant_value(axis)
except NotScalarConstantError: except NotScalarConstantError:
return return
...@@ -1797,12 +1797,7 @@ def local_join_subtensors(fgraph, node): ...@@ -1797,12 +1797,7 @@ def local_join_subtensors(fgraph, node):
if step is None: if step is None:
continue continue
try: try:
if ( if get_scalar_constant_value(step, only_process_constants=True) != 1:
get_underlying_scalar_constant_value(
step, only_process_constants=True
)
!= 1
):
return None return None
except NotScalarConstantError: except NotScalarConstantError:
return None return None
......
...@@ -428,7 +428,7 @@ class SpecifyShape(COp): ...@@ -428,7 +428,7 @@ class SpecifyShape(COp):
type_shape[i] = xts type_shape[i] = xts
elif not isinstance(s.type, NoneTypeT): elif not isinstance(s.type, NoneTypeT):
try: try:
type_shape[i] = int(ptb.get_underlying_scalar_constant_value(s)) type_shape[i] = int(ptb.get_scalar_constant_value(s))
except NotScalarConstantError: except NotScalarConstantError:
pass pass
...@@ -580,7 +580,7 @@ def specify_shape( ...@@ -580,7 +580,7 @@ def specify_shape(
@_get_vector_length.register(SpecifyShape) # type: ignore @_get_vector_length.register(SpecifyShape) # type: ignore
def _get_vector_length_SpecifyShape(op: Op, var: TensorVariable) -> int: def _get_vector_length_SpecifyShape(op: Op, var: TensorVariable) -> int:
try: try:
return int(ptb.get_underlying_scalar_constant_value(var.owner.inputs[1]).item()) return int(ptb.get_scalar_constant_value(var.owner.inputs[1]).item())
except NotScalarConstantError: except NotScalarConstantError:
raise ValueError(f"Length of {var} cannot be determined") raise ValueError(f"Length of {var} cannot be determined")
...@@ -661,7 +661,7 @@ class Reshape(COp): ...@@ -661,7 +661,7 @@ class Reshape(COp):
y = shp_list[index] y = shp_list[index]
y = ptb.as_tensor_variable(y) y = ptb.as_tensor_variable(y)
try: try:
s_val = ptb.get_underlying_scalar_constant_value(y).item() s_val = ptb.get_scalar_constant_value(y).item()
if s_val >= 0: if s_val >= 0:
out_shape[index] = s_val out_shape[index] = s_val
except NotScalarConstantError: except NotScalarConstantError:
......
...@@ -29,7 +29,7 @@ from pytensor.tensor import ( ...@@ -29,7 +29,7 @@ from pytensor.tensor import (
from pytensor.tensor.basic import ( from pytensor.tensor.basic import (
ScalarFromTensor, ScalarFromTensor,
alloc, alloc,
get_underlying_scalar_constant_value, get_scalar_constant_value,
nonzero, nonzero,
scalar_from_tensor, scalar_from_tensor,
) )
...@@ -778,7 +778,7 @@ def get_constant_idx( ...@@ -778,7 +778,7 @@ def get_constant_idx(
return slice(conv(val.start), conv(val.stop), conv(val.step)) return slice(conv(val.start), conv(val.stop), conv(val.step))
else: else:
try: try:
return get_underlying_scalar_constant_value( return get_scalar_constant_value(
val, val,
only_process_constants=only_process_constants, only_process_constants=only_process_constants,
elemwise=elemwise, elemwise=elemwise,
...@@ -855,7 +855,7 @@ class Subtensor(COp): ...@@ -855,7 +855,7 @@ class Subtensor(COp):
if value is None: if value is None:
return value, True return value, True
try: try:
value = get_underlying_scalar_constant_value(value) value = get_scalar_constant_value(value)
return value, True return value, True
except NotScalarConstantError: except NotScalarConstantError:
return value, False return value, False
...@@ -3022,17 +3022,17 @@ def _get_vector_length_Subtensor(op, var): ...@@ -3022,17 +3022,17 @@ def _get_vector_length_Subtensor(op, var):
start = ( start = (
None None
if indices[0].start is None if indices[0].start is None
else get_underlying_scalar_constant_value(indices[0].start) else get_scalar_constant_value(indices[0].start)
) )
stop = ( stop = (
None None
if indices[0].stop is None if indices[0].stop is None
else get_underlying_scalar_constant_value(indices[0].stop) else get_scalar_constant_value(indices[0].stop)
) )
step = ( step = (
None None
if indices[0].step is None if indices[0].step is None
else get_underlying_scalar_constant_value(indices[0].step) else get_scalar_constant_value(indices[0].step)
) )
if start == stop: if start == stop:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论