提交 6157b651 authored 作者: unknown's avatar unknown 提交者: Brandon T. Willard

Replace TensorConstant.tag.unique_value with a get_unique_value function

上级 b3f686f7
...@@ -52,7 +52,7 @@ from aesara.tensor.subtensor import ( ...@@ -52,7 +52,7 @@ from aesara.tensor.subtensor import (
get_idx_list, get_idx_list,
set_subtensor, set_subtensor,
) )
from aesara.tensor.var import TensorConstant from aesara.tensor.var import TensorConstant, get_unique_value
_logger = logging.getLogger("aesara.scan.opt") _logger = logging.getLogger("aesara.scan.opt")
...@@ -118,7 +118,7 @@ def remove_constants_and_unused_inputs_scan(fgraph, node): ...@@ -118,7 +118,7 @@ def remove_constants_and_unused_inputs_scan(fgraph, node):
node_inp = node.inputs[idx + 1] node_inp = node.inputs[idx + 1]
if ( if (
isinstance(node_inp, TensorConstant) isinstance(node_inp, TensorConstant)
and node_inp.tag.unique_value is not None and get_unique_value(node_inp) is not None
): ):
try: try:
# This works if input is a constant that has all entries # This works if input is a constant that has all entries
......
...@@ -60,7 +60,7 @@ from aesara.tensor.type import ( ...@@ -60,7 +60,7 @@ from aesara.tensor.type import (
uint_dtypes, uint_dtypes,
values_eq_approx_always_true, values_eq_approx_always_true,
) )
from aesara.tensor.var import TensorConstant, TensorVariable from aesara.tensor.var import TensorConstant, TensorVariable, get_unique_value
_logger = logging.getLogger("aesara.tensor.basic") _logger = logging.getLogger("aesara.tensor.basic")
...@@ -323,8 +323,9 @@ def get_scalar_constant_value( ...@@ -323,8 +323,9 @@ def get_scalar_constant_value(
raise NotScalarConstantError() raise NotScalarConstantError()
if isinstance(v, Constant): if isinstance(v, Constant):
if getattr(v.tag, "unique_value", None) is not None: unique_value = get_unique_value(v)
data = v.tag.unique_value if unique_value is not None:
data = unique_value
else: else:
data = v.data data = v.data
......
...@@ -92,7 +92,7 @@ from aesara.tensor.type import ( ...@@ -92,7 +92,7 @@ from aesara.tensor.type import (
values_eq_approx_remove_inf_nan, values_eq_approx_remove_inf_nan,
values_eq_approx_remove_nan, values_eq_approx_remove_nan,
) )
from aesara.tensor.var import TensorConstant from aesara.tensor.var import TensorConstant, get_unique_value
from aesara.utils import NoDuplicateOptWarningFilter from aesara.utils import NoDuplicateOptWarningFilter
...@@ -129,8 +129,9 @@ def get_constant(v): ...@@ -129,8 +129,9 @@ def get_constant(v):
""" """
if isinstance(v, Constant): if isinstance(v, Constant):
if getattr(v.tag, "unique_value", None) is not None: unique_value = get_unique_value(v)
data = v.tag.unique_value if unique_value is not None:
data = unique_value
else: else:
data = v.data data = v.data
if data.ndim == 0: if data.ndim == 0:
......
...@@ -2,6 +2,8 @@ import copy ...@@ -2,6 +2,8 @@ import copy
import traceback as tb import traceback as tb
import warnings import warnings
from collections.abc import Iterable from collections.abc import Iterable
from numbers import Number
from typing import Optional
import numpy as np import numpy as np
...@@ -957,6 +959,20 @@ class TensorConstantSignature(tuple): ...@@ -957,6 +959,20 @@ class TensorConstantSignature(tuple):
no_nan = property(_get_no_nan) no_nan = property(_get_no_nan)
def get_unique_value(x: TensorVariable) -> Optional[Number]:
"""Return the unique value of a tensor, if there is one"""
if isinstance(x, Constant):
data = x.data
if isinstance(data, np.ndarray) and data.ndim > 0:
flat_data = data.ravel()
if flat_data.shape[0]:
if (flat_data == flat_data[0]).all():
return flat_data[0]
return None
class TensorConstant(TensorVariable, Constant): class TensorConstant(TensorVariable, Constant):
"""Subclass to add the tensor operators to the basic `Constant` class. """Subclass to add the tensor operators to the basic `Constant` class.
...@@ -966,16 +982,11 @@ class TensorConstant(TensorVariable, Constant): ...@@ -966,16 +982,11 @@ class TensorConstant(TensorVariable, Constant):
def __init__(self, type, data, name=None): def __init__(self, type, data, name=None):
Constant.__init__(self, type, data, name) Constant.__init__(self, type, data, name)
self.tag.unique_value = None
if isinstance(data, np.ndarray) and data.ndim > 0:
flat_data = data.ravel()
if flat_data.shape[0]:
if (flat_data == flat_data[0]).all():
self.tag.unique_value = flat_data[0]
def __str__(self): def __str__(self):
if self.tag.unique_value is not None: unique_val = get_unique_value(self)
name = f"{self.data.shape} of {self.tag.unique_value}" if unique_val is not None:
name = f"{self.data.shape} of {unique_val}"
else: else:
name = f"{self.data}" name = f"{self.data}"
if len(name) > 20: if len(name) > 20:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论