提交 08c97f34 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Add optional strict to Type.is_valid_value

上级 174117f9
...@@ -180,10 +180,10 @@ class Type(MetaObject): ...@@ -180,10 +180,10 @@ class Type(MetaObject):
return None return None
def is_valid_value(self, data: D) -> bool: def is_valid_value(self, data: D, strict: bool = True) -> bool:
"""Return ``True`` for any python object that would be a legal value for a `Variable` of this `Type`.""" """Return ``True`` for any python object that would be a legal value for a `Variable` of this `Type`."""
try: try:
self.filter(data, strict=True) self.filter(data, strict=strict)
return True return True
except (TypeError, ValueError): except (TypeError, ValueError):
return False return False
......
from typing import Generic, TypeVar
import numpy as np import numpy as np
import aesara import aesara
from aesara.graph.type import Type from aesara.graph.type import Type
T = TypeVar("T", np.random.RandomState, np.random.Generator)
gen_states_keys = { gen_states_keys = {
"MT19937": (["state"], ["key", "pos"]), "MT19937": (["state"], ["key", "pos"]),
"PCG64": (["state", "has_uint32", "uinteger"], ["state", "inc"]), "PCG64": (["state", "has_uint32", "uinteger"], ["state", "inc"]),
...@@ -18,22 +23,15 @@ gen_states_keys = { ...@@ -18,22 +23,15 @@ gen_states_keys = {
numpy_bit_gens = {0: "MT19937", 1: "PCG64", 2: "Philox", 3: "SFC64"} numpy_bit_gens = {0: "MT19937", 1: "PCG64", 2: "Philox", 3: "SFC64"}
class RandomType(Type): class RandomType(Type, Generic[T]):
r"""A Type wrapper for `numpy.random.Generator` and `numpy.random.RandomState`.""" r"""A Type wrapper for `numpy.random.Generator` and `numpy.random.RandomState`."""
@classmethod
def filter(cls, data, strict=False, allow_downcast=None):
if cls.is_valid_value(data, strict):
return data
else:
raise TypeError()
@staticmethod @staticmethod
def may_share_memory(a, b): def may_share_memory(a: T, b: T):
return a._bit_generator is b._bit_generator return a._bit_generator is b._bit_generator
class RandomStateType(RandomType): class RandomStateType(RandomType[np.random.RandomState]):
r"""A Type wrapper for `numpy.random.RandomState`. r"""A Type wrapper for `numpy.random.RandomState`.
The reason this exists (and `Generic` doesn't suffice) is that The reason this exists (and `Generic` doesn't suffice) is that
...@@ -49,28 +47,38 @@ class RandomStateType(RandomType): ...@@ -49,28 +47,38 @@ class RandomStateType(RandomType):
def __repr__(self): def __repr__(self):
return "RandomStateType" return "RandomStateType"
@staticmethod def filter(self, data, strict: bool = False, allow_downcast=None):
def is_valid_value(a, strict): """
if isinstance(a, np.random.RandomState): XXX: This doesn't convert `data` to the same type of underlying RNG type
return True as `self`. It really only checks that `data` is of the appropriate type
to be a valid `RandomStateType`.
In other words, it serves as a `Type.is_valid_value` implementation,
but, because the default `Type.is_valid_value` depends on
`Type.filter`, we need to have it here to avoid surprising circular
dependencies in sub-classes.
"""
if isinstance(data, np.random.RandomState):
return data
if not strict and isinstance(a, dict): if not strict and isinstance(data, dict):
gen_keys = ["bit_generator", "gauss", "has_gauss", "state"] gen_keys = ["bit_generator", "gauss", "has_gauss", "state"]
state_keys = ["key", "pos"] state_keys = ["key", "pos"]
for key in gen_keys: for key in gen_keys:
if key not in a: if key not in data:
return False raise TypeError()
for key in state_keys: for key in state_keys:
if key not in a["state"]: if key not in data["state"]:
return False raise TypeError()
state_key = a["state"]["key"] state_key = data["state"]["key"]
if state_key.shape == (624,) and state_key.dtype == np.uint32: if state_key.shape == (624,) and state_key.dtype == np.uint32:
return True # TODO: Add an option to convert to a `RandomState` instance?
return data
return False raise TypeError()
@staticmethod @staticmethod
def values_eq(a, b): def values_eq(a, b):
...@@ -114,7 +122,7 @@ aesara.compile.register_view_op_c_code( ...@@ -114,7 +122,7 @@ aesara.compile.register_view_op_c_code(
random_state_type = RandomStateType() random_state_type = RandomStateType()
class RandomGeneratorType(RandomType): class RandomGeneratorType(RandomType[np.random.Generator]):
r"""A Type wrapper for `numpy.random.Generator`. r"""A Type wrapper for `numpy.random.Generator`.
The reason this exists (and `Generic` doesn't suffice) is that The reason this exists (and `Generic` doesn't suffice) is that
...@@ -130,16 +138,25 @@ class RandomGeneratorType(RandomType): ...@@ -130,16 +138,25 @@ class RandomGeneratorType(RandomType):
def __repr__(self): def __repr__(self):
return "RandomGeneratorType" return "RandomGeneratorType"
@staticmethod def filter(self, data, strict=False, allow_downcast=None):
def is_valid_value(a, strict): """
if isinstance(a, np.random.Generator): XXX: This doesn't convert `data` to the same type of underlying RNG type
return True as `self`. It really only checks that `data` is of the appropriate type
to be a valid `RandomGeneratorType`.
In other words, it serves as a `Type.is_valid_value` implementation,
but, because the default `Type.is_valid_value` depends on
`Type.filter`, we need to have it here to avoid surprising circular
dependencies in sub-classes.
"""
if isinstance(data, np.random.Generator):
return data
if not strict and isinstance(a, dict): if not strict and isinstance(data, dict):
if "bit_generator" not in a: if "bit_generator" not in data:
return False raise TypeError()
else: else:
bit_gen_key = a["bit_generator"] bit_gen_key = data["bit_generator"]
if hasattr(bit_gen_key, "_value"): if hasattr(bit_gen_key, "_value"):
bit_gen_key = int(bit_gen_key._value) bit_gen_key = int(bit_gen_key._value)
...@@ -148,16 +165,16 @@ class RandomGeneratorType(RandomType): ...@@ -148,16 +165,16 @@ class RandomGeneratorType(RandomType):
gen_keys, state_keys = gen_states_keys[bit_gen_key] gen_keys, state_keys = gen_states_keys[bit_gen_key]
for key in gen_keys: for key in gen_keys:
if key not in a: if key not in data:
return False raise TypeError()
for key in state_keys: for key in state_keys:
if key not in a["state"]: if key not in data["state"]:
return False raise TypeError()
return True return data
return False raise TypeError()
@staticmethod @staticmethod
def values_eq(a, b): def values_eq(a, b):
......
...@@ -56,15 +56,17 @@ class TestRandomStateType: ...@@ -56,15 +56,17 @@ class TestRandomStateType:
with pytest.raises(TypeError): with pytest.raises(TypeError):
rng_type.filter(1) rng_type.filter(1)
rng = rng.get_state(legacy=False) rng_dict = rng.get_state(legacy=False)
assert rng_type.is_valid_value(rng, strict=False)
rng["state"] = {} assert rng_type.is_valid_value(rng_dict) is False
assert rng_type.is_valid_value(rng_dict, strict=False)
assert rng_type.is_valid_value(rng, strict=False) is False rng_dict["state"] = {}
rng = {} assert rng_type.is_valid_value(rng_dict, strict=False) is False
assert rng_type.is_valid_value(rng, strict=False) is False
rng_dict = {}
assert rng_type.is_valid_value(rng_dict, strict=False) is False
def test_values_eq(self): def test_values_eq(self):
...@@ -147,15 +149,17 @@ class TestRandomGeneratorType: ...@@ -147,15 +149,17 @@ class TestRandomGeneratorType:
with pytest.raises(TypeError): with pytest.raises(TypeError):
rng_type.filter(1) rng_type.filter(1)
rng = rng.__getstate__() rng_dict = rng.__getstate__()
assert rng_type.is_valid_value(rng, strict=False)
assert rng_type.is_valid_value(rng_dict) is False
assert rng_type.is_valid_value(rng_dict, strict=False)
rng["state"] = {} rng_dict["state"] = {}
assert rng_type.is_valid_value(rng, strict=False) is False assert rng_type.is_valid_value(rng_dict, strict=False) is False
rng = {} rng_dict = {}
assert rng_type.is_valid_value(rng, strict=False) is False assert rng_type.is_valid_value(rng_dict, strict=False) is False
def test_values_eq(self): def test_values_eq(self):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论