提交 a120dc27 authored 作者: ricardoV94's avatar ricardoV94 提交者: Ricardo Vieira

Cache unique value of TensorConstants and deprecate `get_unique_constant_value`

上级 2b57f74f
...@@ -71,7 +71,7 @@ from pytensor.tensor.subtensor import ( ...@@ -71,7 +71,7 @@ from pytensor.tensor.subtensor import (
get_slice_elements, get_slice_elements,
set_subtensor, set_subtensor,
) )
from pytensor.tensor.variable import TensorConstant, get_unique_constant_value from pytensor.tensor.variable import TensorConstant
list_opt_slice = [ list_opt_slice = [
...@@ -136,10 +136,7 @@ def remove_constants_and_unused_inputs_scan(fgraph, node): ...@@ -136,10 +136,7 @@ def remove_constants_and_unused_inputs_scan(fgraph, node):
all_ins = list(graph_inputs(op_outs)) all_ins = list(graph_inputs(op_outs))
for idx in range(op_info.n_seqs): for idx in range(op_info.n_seqs):
node_inp = node.inputs[idx + 1] node_inp = node.inputs[idx + 1]
if ( if isinstance(node_inp, TensorConstant) and node_inp.unique_value is not None:
isinstance(node_inp, TensorConstant)
and get_unique_constant_value(node_inp) is not None
):
try: try:
# This works if input is a constant that has all entries # This works if input is a constant that has all entries
# equal # equal
......
...@@ -491,6 +491,10 @@ class SparseConstant(SparseVariable, TensorConstant): ...@@ -491,6 +491,10 @@ class SparseConstant(SparseVariable, TensorConstant):
def __repr__(self): def __repr__(self):
return str(self) return str(self)
@property
def unique_value(self):
return None
SparseTensorType.variable_type = SparseVariable SparseTensorType.variable_type = SparseVariable
SparseTensorType.constant_type = SparseConstant SparseTensorType.constant_type = SparseConstant
......
...@@ -19,7 +19,7 @@ from numpy.core.numeric import normalize_axis_tuple ...@@ -19,7 +19,7 @@ from numpy.core.numeric import normalize_axis_tuple
import pytensor import pytensor
import pytensor.scalar.sharedvar import pytensor.scalar.sharedvar
from pytensor import compile, config, printing from pytensor import config, printing
from pytensor import scalar as ps from pytensor import scalar as ps
from pytensor.compile.builders import OpFromGraph from pytensor.compile.builders import OpFromGraph
from pytensor.gradient import DisconnectedType, grad_undefined from pytensor.gradient import DisconnectedType, grad_undefined
...@@ -35,7 +35,7 @@ from pytensor.link.c.params_type import ParamsType ...@@ -35,7 +35,7 @@ from pytensor.link.c.params_type import ParamsType
from pytensor.printing import Printer, min_informative_str, pprint, set_precedence from pytensor.printing import Printer, min_informative_str, pprint, set_precedence
from pytensor.raise_op import CheckAndRaise, assert_op from pytensor.raise_op import CheckAndRaise, assert_op
from pytensor.scalar import int32 from pytensor.scalar import int32
from pytensor.scalar.basic import ScalarConstant, ScalarVariable from pytensor.scalar.basic import ScalarConstant, ScalarType, ScalarVariable
from pytensor.tensor import ( from pytensor.tensor import (
_as_tensor_variable, _as_tensor_variable,
_get_vector_length, _get_vector_length,
...@@ -71,10 +71,10 @@ from pytensor.tensor.type import ( ...@@ -71,10 +71,10 @@ from pytensor.tensor.type import (
uint_dtypes, uint_dtypes,
values_eq_approx_always_true, values_eq_approx_always_true,
) )
from pytensor.tensor.type_other import NoneTypeT
from pytensor.tensor.variable import ( from pytensor.tensor.variable import (
TensorConstant, TensorConstant,
TensorVariable, TensorVariable,
get_unique_constant_value,
) )
...@@ -319,6 +319,8 @@ def get_underlying_scalar_constant_value( ...@@ -319,6 +319,8 @@ def get_underlying_scalar_constant_value(
but I'm not sure where it is. but I'm not sure where it is.
""" """
from pytensor.compile.ops import DeepCopyOp, OutputGuard
v = orig_v v = orig_v
while True: while True:
if v is None: if v is None:
...@@ -336,34 +338,22 @@ def get_underlying_scalar_constant_value( ...@@ -336,34 +338,22 @@ def get_underlying_scalar_constant_value(
raise NotScalarConstantError() raise NotScalarConstantError()
if isinstance(v, Constant): if isinstance(v, Constant):
unique_value = get_unique_constant_value(v) if isinstance(v.type, TensorType) and v.unique_value is not None:
if unique_value is not None: return v.unique_value
data = unique_value
else:
data = v.data
if isinstance(data, np.ndarray): elif isinstance(v.type, ScalarType):
try: return v.data
return np.array(data.item(), dtype=v.dtype)
except ValueError:
raise NotScalarConstantError()
from pytensor.sparse.type import SparseTensorType elif isinstance(v.type, NoneTypeT):
return None
if isinstance(v.type, SparseTensorType):
raise NotScalarConstantError() raise NotScalarConstantError()
return data
if not only_process_constants and getattr(v, "owner", None) and max_recur > 0: if not only_process_constants and getattr(v, "owner", None) and max_recur > 0:
max_recur -= 1 max_recur -= 1
if isinstance( if isinstance(
v.owner.op, v.owner.op,
Alloc Alloc | DimShuffle | Unbroadcast | OutputGuard | DeepCopyOp,
| DimShuffle
| Unbroadcast
| compile.ops.OutputGuard
| compile.DeepCopyOp,
): ):
# OutputGuard is only used in debugmode but we # OutputGuard is only used in debugmode but we
# keep it here to avoid problems with old pickles # keep it here to avoid problems with old pickles
......
...@@ -41,7 +41,7 @@ from pytensor.tensor.rewriting.basic import ( ...@@ -41,7 +41,7 @@ from pytensor.tensor.rewriting.basic import (
register_specialize, register_specialize,
) )
from pytensor.tensor.shape import shape_padleft from pytensor.tensor.shape import shape_padleft
from pytensor.tensor.variable import TensorConstant, get_unique_constant_value from pytensor.tensor.variable import TensorConstant
class InplaceElemwiseOptimizer(GraphRewriter): class InplaceElemwiseOptimizer(GraphRewriter):
...@@ -513,7 +513,6 @@ def local_upcast_elemwise_constant_inputs(fgraph, node): ...@@ -513,7 +513,6 @@ def local_upcast_elemwise_constant_inputs(fgraph, node):
new_inputs.append(i) new_inputs.append(i)
else: else:
try: try:
# works only for scalars
cval_i = get_underlying_scalar_constant_value( cval_i = get_underlying_scalar_constant_value(
i, only_process_constants=True i, only_process_constants=True
) )
...@@ -1218,11 +1217,13 @@ def local_inline_composite_constants(fgraph, node): ...@@ -1218,11 +1217,13 @@ def local_inline_composite_constants(fgraph, node):
node.inputs, composite_op.fgraph.inputs, strict=True node.inputs, composite_op.fgraph.inputs, strict=True
): ):
# Complex variables don't have a `c_literal` that can be inlined # Complex variables don't have a `c_literal` that can be inlined
if "complex" not in outer_inp.type.dtype: if (
unique_value = get_unique_constant_value(outer_inp) isinstance(outer_inp, TensorConstant)
if unique_value is not None: and "complex" not in outer_inp.type.dtype
):
if outer_inp.unique_value is not None:
inner_replacements[inner_inp] = ps.constant( inner_replacements[inner_inp] = ps.constant(
unique_value, dtype=inner_inp.dtype outer_inp.unique_value, dtype=inner_inp.dtype
) )
continue continue
new_outer_inputs.append(outer_inp) new_outer_inputs.append(outer_inp)
......
...@@ -106,7 +106,6 @@ from pytensor.tensor.type import ( ...@@ -106,7 +106,6 @@ from pytensor.tensor.type import (
from pytensor.tensor.variable import ( from pytensor.tensor.variable import (
TensorConstant, TensorConstant,
TensorVariable, TensorVariable,
get_unique_constant_value,
) )
...@@ -138,16 +137,8 @@ def get_constant(v): ...@@ -138,16 +137,8 @@ def get_constant(v):
numeric constant. If v is a plain Variable, returns None. numeric constant. If v is a plain Variable, returns None.
""" """
if isinstance(v, Constant): if isinstance(v, TensorConstant):
unique_value = get_unique_constant_value(v) return v.unique_value
if unique_value is not None:
data = unique_value
else:
data = v.data
if data.ndim == 0:
return data
else:
return None
elif isinstance(v, Variable): elif isinstance(v, Variable):
return None return None
else: else:
...@@ -628,7 +619,14 @@ def local_mul_switch_sink(fgraph, node): ...@@ -628,7 +619,14 @@ def local_mul_switch_sink(fgraph, node):
# Look for a zero as the first or second branch of the switch # Look for a zero as the first or second branch of the switch
for branch in range(2): for branch in range(2):
zero_switch_input = switch_node.inputs[1 + branch] zero_switch_input = switch_node.inputs[1 + branch]
if not get_unique_constant_value(zero_switch_input) == 0.0: if (
not get_underlying_scalar_constant_value(
zero_switch_input,
only_process_constants=True,
raise_not_constant=False,
)
== 0.0
):
continue continue
switch_cond = switch_node.inputs[0] switch_cond = switch_node.inputs[0]
...@@ -685,7 +683,14 @@ def local_div_switch_sink(fgraph, node): ...@@ -685,7 +683,14 @@ def local_div_switch_sink(fgraph, node):
# Look for a zero as the first or second branch of the switch # Look for a zero as the first or second branch of the switch
for branch in range(2): for branch in range(2):
zero_switch_input = switch_node.inputs[1 + branch] zero_switch_input = switch_node.inputs[1 + branch]
if not get_unique_constant_value(zero_switch_input) == 0.0: if (
not get_underlying_scalar_constant_value(
zero_switch_input,
only_process_constants=True,
raise_not_constant=False,
)
== 0.0
):
continue continue
switch_cond = switch_node.inputs[0] switch_cond = switch_node.inputs[0]
......
...@@ -20,7 +20,7 @@ from pytensor.tensor import basic as ptb ...@@ -20,7 +20,7 @@ from pytensor.tensor import basic as ptb
from pytensor.tensor.elemwise import get_normalized_batch_axes from pytensor.tensor.elemwise import get_normalized_batch_axes
from pytensor.tensor.exceptions import NotScalarConstantError from pytensor.tensor.exceptions import NotScalarConstantError
from pytensor.tensor.type import DenseTensorType, TensorType, int_dtypes, tensor from pytensor.tensor.type import DenseTensorType, TensorType, int_dtypes, tensor
from pytensor.tensor.type_other import NoneConst from pytensor.tensor.type_other import NoneConst, NoneTypeT
from pytensor.tensor.variable import TensorConstant, TensorVariable from pytensor.tensor.variable import TensorConstant, TensorVariable
...@@ -401,8 +401,6 @@ class SpecifyShape(COp): ...@@ -401,8 +401,6 @@ class SpecifyShape(COp):
_output_type_depends_on_input_value = True _output_type_depends_on_input_value = True
def make_node(self, x, *shape): def make_node(self, x, *shape):
from pytensor.tensor.basic import get_underlying_scalar_constant_value
x = ptb.as_tensor_variable(x) x = ptb.as_tensor_variable(x)
shape = tuple( shape = tuple(
...@@ -428,11 +426,9 @@ class SpecifyShape(COp): ...@@ -428,11 +426,9 @@ class SpecifyShape(COp):
for i, (xts, s) in enumerate(zip(x.type.shape, shape, strict=True)): for i, (xts, s) in enumerate(zip(x.type.shape, shape, strict=True)):
if xts is not None: if xts is not None:
type_shape[i] = xts type_shape[i] = xts
else: elif not isinstance(s.type, NoneTypeT):
try: try:
type_s = get_underlying_scalar_constant_value(s) type_shape[i] = int(ptb.get_underlying_scalar_constant_value(s))
if type_s is not None:
type_shape[i] = int(type_s)
except NotScalarConstantError: except NotScalarConstantError:
pass pass
...@@ -460,22 +456,13 @@ class SpecifyShape(COp): ...@@ -460,22 +456,13 @@ class SpecifyShape(COp):
def infer_shape(self, fgraph, node, shapes): def infer_shape(self, fgraph, node, shapes):
xshape, *_ = shapes xshape, *_ = shapes
shape = node.inputs[1:] shape = node.inputs[1:]
new_shape = [] # Use x shape if specified dim is None, otherwise the specified shape
for dim in range(node.inputs[0].type.ndim): return [
s = shape[dim] [
try: xshape[i] if isinstance(dim.type, NoneTypeT) else dim
s = ptb.get_underlying_scalar_constant_value(s) for i, dim in enumerate(shape)
# We assume that `None` shapes are always retrieved by ]
# `get_underlying_scalar_constant_value`, and only in that case do we default to ]
# the shape of the input variable
if s is None:
s = xshape[dim]
except NotScalarConstantError:
pass
new_shape.append(ptb.as_tensor_variable(s))
assert len(new_shape) == len(xshape)
return [new_shape]
def connection_pattern(self, node): def connection_pattern(self, node):
return [[True], *[[False]] * len(node.inputs[1:])] return [[True], *[[False]] * len(node.inputs[1:])]
......
...@@ -11,7 +11,10 @@ from pytensor import tensor as pt ...@@ -11,7 +11,10 @@ from pytensor import tensor as pt
from pytensor.configdefaults import config from pytensor.configdefaults import config
from pytensor.graph.basic import Constant, OptionalApplyType, Variable from pytensor.graph.basic import Constant, OptionalApplyType, Variable
from pytensor.graph.utils import MetaType from pytensor.graph.utils import MetaType
from pytensor.scalar import ComplexError, IntegerDivisionError from pytensor.scalar import (
ComplexError,
IntegerDivisionError,
)
from pytensor.tensor import _get_vector_length from pytensor.tensor import _get_vector_length
from pytensor.tensor.exceptions import AdvancedIndexingError from pytensor.tensor.exceptions import AdvancedIndexingError
from pytensor.tensor.type import TensorType from pytensor.tensor.type import TensorType
...@@ -1042,17 +1045,9 @@ class TensorConstantSignature(tuple): ...@@ -1042,17 +1045,9 @@ class TensorConstantSignature(tuple):
def get_unique_constant_value(x: TensorVariable) -> Number | None: def get_unique_constant_value(x: TensorVariable) -> Number | None:
"""Return the unique value of a tensor, if there is one""" """Return the unique value of a tensor, if there is one"""
if isinstance(x, Constant): warnings.warn("get_unique_constant_value is deprecated.", FutureWarning)
data = x.data if isinstance(x, TensorConstant):
return x.unique_value
if isinstance(data, np.ndarray) and data.size > 0:
if data.size == 1:
return data.squeeze()
flat_data = data.ravel()
if (flat_data == flat_data[0]).all():
return flat_data[0]
return None return None
...@@ -1081,6 +1076,30 @@ class TensorConstant(TensorVariable, Constant[_TensorTypeType]): ...@@ -1081,6 +1076,30 @@ class TensorConstant(TensorVariable, Constant[_TensorTypeType]):
def signature(self): def signature(self):
return TensorConstantSignature((self.type, self.data)) return TensorConstantSignature((self.type, self.data))
@property
def unique_value(self) -> Number | None:
"""Return the unique value of a tensor, if there is one"""
try:
return self._unique_value
except AttributeError:
data = self.data
unique_value = None
if data.size > 0:
if data.size == 1:
unique_value = data.squeeze()
else:
flat_data = data.ravel()
if (flat_data == flat_data[0]).all():
unique_value = flat_data[0]
if unique_value is not None:
# Don't allow the unique value to be changed
unique_value.setflags(write=False)
self._unique_value = unique_value
return self._unique_value
def equals(self, other): def equals(self, other):
# Override Constant.equals to allow to compare with # Override Constant.equals to allow to compare with
# numpy.ndarray, and python type. # numpy.ndarray, and python type.
......
...@@ -3571,10 +3571,11 @@ class TestGetUnderlyingScalarConstantValue: ...@@ -3571,10 +3571,11 @@ class TestGetUnderlyingScalarConstantValue:
assert get_underlying_scalar_constant_value(s) == c.data assert get_underlying_scalar_constant_value(s) == c.data
def test_copy(self): def test_copy(self):
# Make sure we do not return the internal storage of a constant, # Make sure we do not return a writeable internal storage of a constant,
# so we cannot change the value of a constant by mistake. # so we cannot change the value of a constant by mistake.
c = constant(3) c = constant(3)
d = extract_constant(c) d = extract_constant(c)
with pytest.raises(ValueError, match="output array is read-only"):
d += 1 d += 1
e = extract_constant(c) e = extract_constant(c)
assert e == 3, (c, d, e) assert e == 3, (c, d, e)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论