提交 84af50df authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Remove use of numpy_scalar in get_scalar_constant_value

上级 f6760ad7
...@@ -37,7 +37,7 @@ from aesara.tensor import ( ...@@ -37,7 +37,7 @@ from aesara.tensor import (
get_vector_length, get_vector_length,
) )
from aesara.tensor.elemwise import DimShuffle, Elemwise, scalar_elemwise from aesara.tensor.elemwise import DimShuffle, Elemwise, scalar_elemwise
from aesara.tensor.exceptions import EmptyConstantError, NotScalarConstantError from aesara.tensor.exceptions import NotScalarConstantError
from aesara.tensor.shape import ( from aesara.tensor.shape import (
Shape, Shape,
Shape_i, Shape_i,
...@@ -257,31 +257,6 @@ def _obj_is_wrappable_as_tensor(x): ...@@ -257,31 +257,6 @@ def _obj_is_wrappable_as_tensor(x):
return False return False
def numpy_scalar(data):
"""Return a scalar stored in a numpy ndarray.
Raises
------
NotScalarConstantError
If the numpy ndarray is not a scalar.
EmptyConstantError
"""
# handle case where data is numpy.array([])
if data.ndim > 0 and (len(data.shape) == 0 or builtins.max(data.shape) == 0):
assert np.all(np.array([]) == data)
raise EmptyConstantError()
try:
complex(data) # works for all numeric scalars
return data
except Exception:
raise NotScalarConstantError(
"v.data is non-numeric, non-scalar, or has more than one" " unique value",
data,
)
_scalar_constant_value_elemwise_ops = ( _scalar_constant_value_elemwise_ops = (
aes.Cast, aes.Cast,
aes.Switch, aes.Switch,
...@@ -344,7 +319,10 @@ def get_scalar_constant_value( ...@@ -344,7 +319,10 @@ def get_scalar_constant_value(
return np.asarray(v) return np.asarray(v)
if isinstance(v, np.ndarray): if isinstance(v, np.ndarray):
return numpy_scalar(v).copy() try:
return np.array(v.item(), dtype=v.dtype)
except ValueError:
raise NotScalarConstantError()
if isinstance(v, Constant): if isinstance(v, Constant):
if getattr(v.tag, "unique_value", None) is not None: if getattr(v.tag, "unique_value", None) is not None:
...@@ -353,7 +331,10 @@ def get_scalar_constant_value( ...@@ -353,7 +331,10 @@ def get_scalar_constant_value(
data = v.data data = v.data
if isinstance(data, np.ndarray): if isinstance(data, np.ndarray):
return numpy_scalar(data).copy() try:
return np.array(data.item(), dtype=v.dtype)
except ValueError:
raise NotScalarConstantError()
else: else:
return data return data
...@@ -575,7 +556,7 @@ def get_scalar_constant_value( ...@@ -575,7 +556,7 @@ def get_scalar_constant_value(
if isinstance(grandparent, Constant): if isinstance(grandparent, Constant):
return np.asarray(np.shape(grandparent.data)[idx]) return np.asarray(np.shape(grandparent.data)[idx])
raise NotScalarConstantError(v) raise NotScalarConstantError()
class TensorFromScalar(Op): class TensorFromScalar(Op):
......
...@@ -9,13 +9,6 @@ class NotScalarConstantError(Exception): ...@@ -9,13 +9,6 @@ class NotScalarConstantError(Exception):
""" """
class EmptyConstantError(NotScalarConstantError):
"""
Raised by get_scalar_const_value if called on something that is a
zero dimensional constant.
"""
class AdvancedIndexingError(TypeError): class AdvancedIndexingError(TypeError):
""" """
Raised when Subtensor is asked to perform advanced indexing. Raised when Subtensor is asked to perform advanced indexing.
......
...@@ -87,7 +87,7 @@ from aesara.tensor.basic import ( ...@@ -87,7 +87,7 @@ from aesara.tensor.basic import (
zeros_like, zeros_like,
) )
from aesara.tensor.elemwise import DimShuffle from aesara.tensor.elemwise import DimShuffle
from aesara.tensor.exceptions import EmptyConstantError, NotScalarConstantError from aesara.tensor.exceptions import NotScalarConstantError
from aesara.tensor.math import dense_dot, eq from aesara.tensor.math import dense_dot, eq
from aesara.tensor.math import sum as aet_sum from aesara.tensor.math import sum as aet_sum
from aesara.tensor.shape import Reshape, Shape, Shape_i, shape_padright from aesara.tensor.shape import Reshape, Shape, Shape_i, shape_padright
...@@ -3140,6 +3140,15 @@ def test_dimshuffle_duplicate(): ...@@ -3140,6 +3140,15 @@ def test_dimshuffle_duplicate():
class TestGetScalarConstantValue: class TestGetScalarConstantValue:
def test_basic(self): def test_basic(self):
res = get_scalar_constant_value(aet.as_tensor(10))
assert res == 10
assert isinstance(res, np.ndarray)
res = get_scalar_constant_value(np.array(10))
assert res == 10
assert isinstance(res, np.ndarray)
a = aet.stack([1, 2, 3]) a = aet.stack([1, 2, 3])
assert get_scalar_constant_value(a[0]) == 1 assert get_scalar_constant_value(a[0]) == 1
assert get_scalar_constant_value(a[1]) == 2 assert get_scalar_constant_value(a[1]) == 2
...@@ -3195,7 +3204,7 @@ class TestGetScalarConstantValue: ...@@ -3195,7 +3204,7 @@ class TestGetScalarConstantValue:
assert get_scalar_constant_value(np.array(3)) == 3 assert get_scalar_constant_value(np.array(3)) == 3
with pytest.raises(NotScalarConstantError): with pytest.raises(NotScalarConstantError):
get_scalar_constant_value(np.array([0, 1])) get_scalar_constant_value(np.array([0, 1]))
with pytest.raises(EmptyConstantError): with pytest.raises(NotScalarConstantError):
get_scalar_constant_value(np.array([])) get_scalar_constant_value(np.array([]))
def test_make_vector(self): def test_make_vector(self):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论