提交 6d5cc513 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Rename get_unique_value to get_unique_constant_value

上级 5c87d741
...@@ -69,7 +69,7 @@ from pytensor.tensor.subtensor import ( ...@@ -69,7 +69,7 @@ from pytensor.tensor.subtensor import (
get_slice_elements, get_slice_elements,
set_subtensor, set_subtensor,
) )
from pytensor.tensor.var import TensorConstant, get_unique_value from pytensor.tensor.var import TensorConstant, get_unique_constant_value
list_opt_slice = [ list_opt_slice = [
...@@ -136,7 +136,7 @@ def remove_constants_and_unused_inputs_scan(fgraph, node): ...@@ -136,7 +136,7 @@ def remove_constants_and_unused_inputs_scan(fgraph, node):
node_inp = node.inputs[idx + 1] node_inp = node.inputs[idx + 1]
if ( if (
isinstance(node_inp, TensorConstant) isinstance(node_inp, TensorConstant)
and get_unique_value(node_inp) is not None and get_unique_constant_value(node_inp) is not None
): ):
try: try:
# This works if input is a constant that has all entries # This works if input is a constant that has all entries
......
...@@ -62,7 +62,11 @@ from pytensor.tensor.type import ( ...@@ -62,7 +62,11 @@ from pytensor.tensor.type import (
uint_dtypes, uint_dtypes,
values_eq_approx_always_true, values_eq_approx_always_true,
) )
from pytensor.tensor.var import TensorConstant, TensorVariable, get_unique_value from pytensor.tensor.var import (
TensorConstant,
TensorVariable,
get_unique_constant_value,
)
if TYPE_CHECKING: if TYPE_CHECKING:
...@@ -323,7 +327,7 @@ def get_underlying_scalar_constant_value( ...@@ -323,7 +327,7 @@ def get_underlying_scalar_constant_value(
raise NotScalarConstantError() raise NotScalarConstantError()
if isinstance(v, Constant): if isinstance(v, Constant):
unique_value = get_unique_value(v) unique_value = get_unique_constant_value(v)
if unique_value is not None: if unique_value is not None:
data = unique_value data = unique_value
else: else:
......
...@@ -101,7 +101,7 @@ from pytensor.tensor.type import ( ...@@ -101,7 +101,7 @@ from pytensor.tensor.type import (
values_eq_approx_remove_inf_nan, values_eq_approx_remove_inf_nan,
values_eq_approx_remove_nan, values_eq_approx_remove_nan,
) )
from pytensor.tensor.var import TensorConstant, get_unique_value from pytensor.tensor.var import TensorConstant, get_unique_constant_value
def scalarconsts_rest(inputs, elemwise=True, only_process_constants=False): def scalarconsts_rest(inputs, elemwise=True, only_process_constants=False):
...@@ -133,7 +133,7 @@ def get_constant(v): ...@@ -133,7 +133,7 @@ def get_constant(v):
""" """
if isinstance(v, Constant): if isinstance(v, Constant):
unique_value = get_unique_value(v) unique_value = get_unique_constant_value(v)
if unique_value is not None: if unique_value is not None:
data = unique_value data = unique_value
else: else:
......
...@@ -986,7 +986,7 @@ class TensorConstantSignature(tuple): ...@@ -986,7 +986,7 @@ class TensorConstantSignature(tuple):
return self._no_nan return self._no_nan
def get_unique_value(x: TensorVariable) -> Optional[Number]: def get_unique_constant_value(x: TensorVariable) -> Optional[Number]:
"""Return the unique value of a tensor, if there is one""" """Return the unique value of a tensor, if there is one"""
if isinstance(x, Constant): if isinstance(x, Constant):
data = x.data data = x.data
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论