提交 6157b651 authored 作者: unknown's avatar unknown 提交者: Brandon T. Willard

Replace TensorConstant.tag.unique_value with a get_unique_value function

上级 b3f686f7
......@@ -52,7 +52,7 @@ from aesara.tensor.subtensor import (
get_idx_list,
set_subtensor,
)
from aesara.tensor.var import TensorConstant
from aesara.tensor.var import TensorConstant, get_unique_value
_logger = logging.getLogger("aesara.scan.opt")
......@@ -118,7 +118,7 @@ def remove_constants_and_unused_inputs_scan(fgraph, node):
node_inp = node.inputs[idx + 1]
if (
isinstance(node_inp, TensorConstant)
and node_inp.tag.unique_value is not None
and get_unique_value(node_inp) is not None
):
try:
# This works if input is a constant that has all entries
......
......@@ -60,7 +60,7 @@ from aesara.tensor.type import (
uint_dtypes,
values_eq_approx_always_true,
)
from aesara.tensor.var import TensorConstant, TensorVariable
from aesara.tensor.var import TensorConstant, TensorVariable, get_unique_value
_logger = logging.getLogger("aesara.tensor.basic")
......@@ -323,8 +323,9 @@ def get_scalar_constant_value(
raise NotScalarConstantError()
if isinstance(v, Constant):
if getattr(v.tag, "unique_value", None) is not None:
data = v.tag.unique_value
unique_value = get_unique_value(v)
if unique_value is not None:
data = unique_value
else:
data = v.data
......
......@@ -92,7 +92,7 @@ from aesara.tensor.type import (
values_eq_approx_remove_inf_nan,
values_eq_approx_remove_nan,
)
from aesara.tensor.var import TensorConstant
from aesara.tensor.var import TensorConstant, get_unique_value
from aesara.utils import NoDuplicateOptWarningFilter
......@@ -129,8 +129,9 @@ def get_constant(v):
"""
if isinstance(v, Constant):
if getattr(v.tag, "unique_value", None) is not None:
data = v.tag.unique_value
unique_value = get_unique_value(v)
if unique_value is not None:
data = unique_value
else:
data = v.data
if data.ndim == 0:
......
......@@ -2,6 +2,8 @@ import copy
import traceback as tb
import warnings
from collections.abc import Iterable
from numbers import Number
from typing import Optional
import numpy as np
......@@ -957,6 +959,20 @@ class TensorConstantSignature(tuple):
no_nan = property(_get_no_nan)
def get_unique_value(x: TensorVariable) -> Optional[Number]:
"""Return the unique value of a tensor, if there is one"""
if isinstance(x, Constant):
data = x.data
if isinstance(data, np.ndarray) and data.ndim > 0:
flat_data = data.ravel()
if flat_data.shape[0]:
if (flat_data == flat_data[0]).all():
return flat_data[0]
return None
class TensorConstant(TensorVariable, Constant):
"""Subclass to add the tensor operators to the basic `Constant` class.
......@@ -966,16 +982,11 @@ class TensorConstant(TensorVariable, Constant):
def __init__(self, type, data, name=None):
Constant.__init__(self, type, data, name)
self.tag.unique_value = None
if isinstance(data, np.ndarray) and data.ndim > 0:
flat_data = data.ravel()
if flat_data.shape[0]:
if (flat_data == flat_data[0]).all():
self.tag.unique_value = flat_data[0]
def __str__(self):
if self.tag.unique_value is not None:
name = f"{self.data.shape} of {self.tag.unique_value}"
unique_val = get_unique_value(self)
if unique_val is not None:
name = f"{self.data.shape} of {unique_val}"
else:
name = f"{self.data}"
if len(name) > 20:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论