Commit a120dc27 authored by ricardoV94, committed by Ricardo Vieira

Cache unique value of TensorConstants and deprecate `get_unique_constant_value`

Parent 2b57f74f
......@@ -71,7 +71,7 @@ from pytensor.tensor.subtensor import (
get_slice_elements,
set_subtensor,
)
from pytensor.tensor.variable import TensorConstant, get_unique_constant_value
from pytensor.tensor.variable import TensorConstant
list_opt_slice = [
......@@ -136,10 +136,7 @@ def remove_constants_and_unused_inputs_scan(fgraph, node):
all_ins = list(graph_inputs(op_outs))
for idx in range(op_info.n_seqs):
node_inp = node.inputs[idx + 1]
if (
isinstance(node_inp, TensorConstant)
and get_unique_constant_value(node_inp) is not None
):
if isinstance(node_inp, TensorConstant) and node_inp.unique_value is not None:
try:
# This works if input is a constant that has all entries
# equal
......
......@@ -491,6 +491,10 @@ class SparseConstant(SparseVariable, TensorConstant):
def __repr__(self):
return str(self)
@property
def unique_value(self):
return None
SparseTensorType.variable_type = SparseVariable
SparseTensorType.constant_type = SparseConstant
......
......@@ -19,7 +19,7 @@ from numpy.core.numeric import normalize_axis_tuple
import pytensor
import pytensor.scalar.sharedvar
from pytensor import compile, config, printing
from pytensor import config, printing
from pytensor import scalar as ps
from pytensor.compile.builders import OpFromGraph
from pytensor.gradient import DisconnectedType, grad_undefined
......@@ -35,7 +35,7 @@ from pytensor.link.c.params_type import ParamsType
from pytensor.printing import Printer, min_informative_str, pprint, set_precedence
from pytensor.raise_op import CheckAndRaise, assert_op
from pytensor.scalar import int32
from pytensor.scalar.basic import ScalarConstant, ScalarVariable
from pytensor.scalar.basic import ScalarConstant, ScalarType, ScalarVariable
from pytensor.tensor import (
_as_tensor_variable,
_get_vector_length,
......@@ -71,10 +71,10 @@ from pytensor.tensor.type import (
uint_dtypes,
values_eq_approx_always_true,
)
from pytensor.tensor.type_other import NoneTypeT
from pytensor.tensor.variable import (
TensorConstant,
TensorVariable,
get_unique_constant_value,
)
......@@ -319,6 +319,8 @@ def get_underlying_scalar_constant_value(
but I'm not sure where it is.
"""
from pytensor.compile.ops import DeepCopyOp, OutputGuard
v = orig_v
while True:
if v is None:
......@@ -336,34 +338,22 @@ def get_underlying_scalar_constant_value(
raise NotScalarConstantError()
if isinstance(v, Constant):
unique_value = get_unique_constant_value(v)
if unique_value is not None:
data = unique_value
else:
data = v.data
if isinstance(v.type, TensorType) and v.unique_value is not None:
return v.unique_value
if isinstance(data, np.ndarray):
try:
return np.array(data.item(), dtype=v.dtype)
except ValueError:
raise NotScalarConstantError()
elif isinstance(v.type, ScalarType):
return v.data
from pytensor.sparse.type import SparseTensorType
elif isinstance(v.type, NoneTypeT):
return None
if isinstance(v.type, SparseTensorType):
raise NotScalarConstantError()
return data
if not only_process_constants and getattr(v, "owner", None) and max_recur > 0:
max_recur -= 1
if isinstance(
v.owner.op,
Alloc
| DimShuffle
| Unbroadcast
| compile.ops.OutputGuard
| compile.DeepCopyOp,
Alloc | DimShuffle | Unbroadcast | OutputGuard | DeepCopyOp,
):
# OutputGuard is only used in debugmode but we
# keep it here to avoid problems with old pickles
......
......@@ -41,7 +41,7 @@ from pytensor.tensor.rewriting.basic import (
register_specialize,
)
from pytensor.tensor.shape import shape_padleft
from pytensor.tensor.variable import TensorConstant, get_unique_constant_value
from pytensor.tensor.variable import TensorConstant
class InplaceElemwiseOptimizer(GraphRewriter):
......@@ -513,7 +513,6 @@ def local_upcast_elemwise_constant_inputs(fgraph, node):
new_inputs.append(i)
else:
try:
# works only for scalars
cval_i = get_underlying_scalar_constant_value(
i, only_process_constants=True
)
......@@ -1218,11 +1217,13 @@ def local_inline_composite_constants(fgraph, node):
node.inputs, composite_op.fgraph.inputs, strict=True
):
# Complex variables don't have a `c_literal` that can be inlined
if "complex" not in outer_inp.type.dtype:
unique_value = get_unique_constant_value(outer_inp)
if unique_value is not None:
if (
isinstance(outer_inp, TensorConstant)
and "complex" not in outer_inp.type.dtype
):
if outer_inp.unique_value is not None:
inner_replacements[inner_inp] = ps.constant(
unique_value, dtype=inner_inp.dtype
outer_inp.unique_value, dtype=inner_inp.dtype
)
continue
new_outer_inputs.append(outer_inp)
......
......@@ -106,7 +106,6 @@ from pytensor.tensor.type import (
from pytensor.tensor.variable import (
TensorConstant,
TensorVariable,
get_unique_constant_value,
)
......@@ -138,16 +137,8 @@ def get_constant(v):
numeric constant. If v is a plain Variable, returns None.
"""
if isinstance(v, Constant):
unique_value = get_unique_constant_value(v)
if unique_value is not None:
data = unique_value
else:
data = v.data
if data.ndim == 0:
return data
else:
return None
if isinstance(v, TensorConstant):
return v.unique_value
elif isinstance(v, Variable):
return None
else:
......@@ -628,7 +619,14 @@ def local_mul_switch_sink(fgraph, node):
# Look for a zero as the first or second branch of the switch
for branch in range(2):
zero_switch_input = switch_node.inputs[1 + branch]
if not get_unique_constant_value(zero_switch_input) == 0.0:
if (
not get_underlying_scalar_constant_value(
zero_switch_input,
only_process_constants=True,
raise_not_constant=False,
)
== 0.0
):
continue
switch_cond = switch_node.inputs[0]
......@@ -685,7 +683,14 @@ def local_div_switch_sink(fgraph, node):
# Look for a zero as the first or second branch of the switch
for branch in range(2):
zero_switch_input = switch_node.inputs[1 + branch]
if not get_unique_constant_value(zero_switch_input) == 0.0:
if (
not get_underlying_scalar_constant_value(
zero_switch_input,
only_process_constants=True,
raise_not_constant=False,
)
== 0.0
):
continue
switch_cond = switch_node.inputs[0]
......
......@@ -20,7 +20,7 @@ from pytensor.tensor import basic as ptb
from pytensor.tensor.elemwise import get_normalized_batch_axes
from pytensor.tensor.exceptions import NotScalarConstantError
from pytensor.tensor.type import DenseTensorType, TensorType, int_dtypes, tensor
from pytensor.tensor.type_other import NoneConst
from pytensor.tensor.type_other import NoneConst, NoneTypeT
from pytensor.tensor.variable import TensorConstant, TensorVariable
......@@ -401,8 +401,6 @@ class SpecifyShape(COp):
_output_type_depends_on_input_value = True
def make_node(self, x, *shape):
from pytensor.tensor.basic import get_underlying_scalar_constant_value
x = ptb.as_tensor_variable(x)
shape = tuple(
......@@ -428,11 +426,9 @@ class SpecifyShape(COp):
for i, (xts, s) in enumerate(zip(x.type.shape, shape, strict=True)):
if xts is not None:
type_shape[i] = xts
else:
elif not isinstance(s.type, NoneTypeT):
try:
type_s = get_underlying_scalar_constant_value(s)
if type_s is not None:
type_shape[i] = int(type_s)
type_shape[i] = int(ptb.get_underlying_scalar_constant_value(s))
except NotScalarConstantError:
pass
......@@ -460,22 +456,13 @@ class SpecifyShape(COp):
def infer_shape(self, fgraph, node, shapes):
xshape, *_ = shapes
shape = node.inputs[1:]
new_shape = []
for dim in range(node.inputs[0].type.ndim):
s = shape[dim]
try:
s = ptb.get_underlying_scalar_constant_value(s)
# We assume that `None` shapes are always retrieved by
# `get_underlying_scalar_constant_value`, and only in that case do we default to
# the shape of the input variable
if s is None:
s = xshape[dim]
except NotScalarConstantError:
pass
new_shape.append(ptb.as_tensor_variable(s))
assert len(new_shape) == len(xshape)
return [new_shape]
# Use x shape if specified dim is None, otherwise the specified shape
return [
[
xshape[i] if isinstance(dim.type, NoneTypeT) else dim
for i, dim in enumerate(shape)
]
]
def connection_pattern(self, node):
return [[True], *[[False]] * len(node.inputs[1:])]
......
......@@ -11,7 +11,10 @@ from pytensor import tensor as pt
from pytensor.configdefaults import config
from pytensor.graph.basic import Constant, OptionalApplyType, Variable
from pytensor.graph.utils import MetaType
from pytensor.scalar import ComplexError, IntegerDivisionError
from pytensor.scalar import (
ComplexError,
IntegerDivisionError,
)
from pytensor.tensor import _get_vector_length
from pytensor.tensor.exceptions import AdvancedIndexingError
from pytensor.tensor.type import TensorType
......@@ -1042,17 +1045,9 @@ class TensorConstantSignature(tuple):
def get_unique_constant_value(x: TensorVariable) -> Number | None:
"""Return the unique value of a tensor, if there is one"""
if isinstance(x, Constant):
data = x.data
if isinstance(data, np.ndarray) and data.size > 0:
if data.size == 1:
return data.squeeze()
flat_data = data.ravel()
if (flat_data == flat_data[0]).all():
return flat_data[0]
warnings.warn("get_unique_constant_value is deprecated.", FutureWarning)
if isinstance(x, TensorConstant):
return x.unique_value
return None
......@@ -1081,6 +1076,30 @@ class TensorConstant(TensorVariable, Constant[_TensorTypeType]):
def signature(self):
return TensorConstantSignature((self.type, self.data))
@property
def unique_value(self) -> Number | None:
"""Return the unique value of a tensor, if there is one"""
try:
return self._unique_value
except AttributeError:
data = self.data
unique_value = None
if data.size > 0:
if data.size == 1:
unique_value = data.squeeze()
else:
flat_data = data.ravel()
if (flat_data == flat_data[0]).all():
unique_value = flat_data[0]
if unique_value is not None:
# Don't allow the unique value to be changed
unique_value.setflags(write=False)
self._unique_value = unique_value
return self._unique_value
def equals(self, other):
# Override Constant.equals to allow to compare with
# numpy.ndarray, and python type.
......
......@@ -3571,10 +3571,11 @@ class TestGetUnderlyingScalarConstantValue:
assert get_underlying_scalar_constant_value(s) == c.data
def test_copy(self):
# Make sure we do not return the internal storage of a constant,
# Make sure we do not return a writeable internal storage of a constant,
# so we cannot change the value of a constant by mistake.
c = constant(3)
d = extract_constant(c)
with pytest.raises(ValueError, match="output array is read-only"):
d += 1
e = extract_constant(c)
assert e == 3, (c, d, e)
......
Markdown formatting is supported
0%
You are adding 0 people to this discussion. Please proceed with caution.
Finish editing this comment first!
Register or sign in to post a comment