提交 84af50df authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Remove use of numpy_scalar in get_scalar_constant_value

上级 f6760ad7
......@@ -37,7 +37,7 @@ from aesara.tensor import (
get_vector_length,
)
from aesara.tensor.elemwise import DimShuffle, Elemwise, scalar_elemwise
from aesara.tensor.exceptions import EmptyConstantError, NotScalarConstantError
from aesara.tensor.exceptions import NotScalarConstantError
from aesara.tensor.shape import (
Shape,
Shape_i,
......@@ -257,31 +257,6 @@ def _obj_is_wrappable_as_tensor(x):
return False
def numpy_scalar(data):
"""Return a scalar stored in a numpy ndarray.
Raises
------
NotScalarConstantError
If the numpy ndarray is not a scalar.
EmptyConstantError
"""
# handle case where data is numpy.array([])
if data.ndim > 0 and (len(data.shape) == 0 or builtins.max(data.shape) == 0):
assert np.all(np.array([]) == data)
raise EmptyConstantError()
try:
complex(data) # works for all numeric scalars
return data
except Exception:
raise NotScalarConstantError(
"v.data is non-numeric, non-scalar, or has more than one" " unique value",
data,
)
_scalar_constant_value_elemwise_ops = (
aes.Cast,
aes.Switch,
......@@ -344,7 +319,10 @@ def get_scalar_constant_value(
return np.asarray(v)
if isinstance(v, np.ndarray):
return numpy_scalar(v).copy()
try:
return np.array(v.item(), dtype=v.dtype)
except ValueError:
raise NotScalarConstantError()
if isinstance(v, Constant):
if getattr(v.tag, "unique_value", None) is not None:
......@@ -353,7 +331,10 @@ def get_scalar_constant_value(
data = v.data
if isinstance(data, np.ndarray):
return numpy_scalar(data).copy()
try:
return np.array(data.item(), dtype=v.dtype)
except ValueError:
raise NotScalarConstantError()
else:
return data
......@@ -575,7 +556,7 @@ def get_scalar_constant_value(
if isinstance(grandparent, Constant):
return np.asarray(np.shape(grandparent.data)[idx])
raise NotScalarConstantError(v)
raise NotScalarConstantError()
class TensorFromScalar(Op):
......
......@@ -9,13 +9,6 @@ class NotScalarConstantError(Exception):
"""
class EmptyConstantError(NotScalarConstantError):
"""
Raised by get_scalar_const_value if called on something that is a
zero dimensional constant.
"""
class AdvancedIndexingError(TypeError):
"""
Raised when Subtensor is asked to perform advanced indexing.
......
......@@ -87,7 +87,7 @@ from aesara.tensor.basic import (
zeros_like,
)
from aesara.tensor.elemwise import DimShuffle
from aesara.tensor.exceptions import EmptyConstantError, NotScalarConstantError
from aesara.tensor.exceptions import NotScalarConstantError
from aesara.tensor.math import dense_dot, eq
from aesara.tensor.math import sum as aet_sum
from aesara.tensor.shape import Reshape, Shape, Shape_i, shape_padright
......@@ -3140,6 +3140,15 @@ def test_dimshuffle_duplicate():
class TestGetScalarConstantValue:
def test_basic(self):
res = get_scalar_constant_value(aet.as_tensor(10))
assert res == 10
assert isinstance(res, np.ndarray)
res = get_scalar_constant_value(np.array(10))
assert res == 10
assert isinstance(res, np.ndarray)
a = aet.stack([1, 2, 3])
assert get_scalar_constant_value(a[0]) == 1
assert get_scalar_constant_value(a[1]) == 2
......@@ -3195,7 +3204,7 @@ class TestGetScalarConstantValue:
assert get_scalar_constant_value(np.array(3)) == 3
with pytest.raises(NotScalarConstantError):
get_scalar_constant_value(np.array([0, 1]))
with pytest.raises(EmptyConstantError):
with pytest.raises(NotScalarConstantError):
get_scalar_constant_value(np.array([]))
def test_make_vector(self):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论