提交 08c97f34 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Add optional strict to Type.is_valid_value

上级 174117f9
......@@ -180,10 +180,10 @@ class Type(MetaObject):
return None
def is_valid_value(self, data: D) -> bool:
def is_valid_value(self, data: D, strict: bool = True) -> bool:
"""Return ``True`` for any python object that would be a legal value for a `Variable` of this `Type`."""
try:
self.filter(data, strict=True)
self.filter(data, strict=strict)
return True
except (TypeError, ValueError):
return False
......
from typing import Generic, TypeVar
import numpy as np
import aesara
from aesara.graph.type import Type
T = TypeVar("T", np.random.RandomState, np.random.Generator)
gen_states_keys = {
"MT19937": (["state"], ["key", "pos"]),
"PCG64": (["state", "has_uint32", "uinteger"], ["state", "inc"]),
......@@ -18,22 +23,15 @@ gen_states_keys = {
numpy_bit_gens = {0: "MT19937", 1: "PCG64", 2: "Philox", 3: "SFC64"}
class RandomType(Type):
class RandomType(Type, Generic[T]):
r"""A Type wrapper for `numpy.random.Generator` and `numpy.random.RandomState`."""
@classmethod
def filter(cls, data, strict=False, allow_downcast=None):
if cls.is_valid_value(data, strict):
return data
else:
raise TypeError()
@staticmethod
def may_share_memory(a, b):
def may_share_memory(a: T, b: T):
return a._bit_generator is b._bit_generator
class RandomStateType(RandomType):
class RandomStateType(RandomType[np.random.RandomState]):
r"""A Type wrapper for `numpy.random.RandomState`.
The reason this exists (and `Generic` doesn't suffice) is that
......@@ -49,28 +47,38 @@ class RandomStateType(RandomType):
def __repr__(self):
return "RandomStateType"
@staticmethod
def is_valid_value(a, strict):
if isinstance(a, np.random.RandomState):
return True
def filter(self, data, strict: bool = False, allow_downcast=None):
"""
XXX: This doesn't convert `data` to the same type of underlying RNG type
as `self`. It really only checks that `data` is of the appropriate type
to be a valid `RandomStateType`.
In other words, it serves as a `Type.is_valid_value` implementation,
but, because the default `Type.is_valid_value` depends on
`Type.filter`, we need to have it here to avoid surprising circular
dependencies in sub-classes.
"""
if isinstance(data, np.random.RandomState):
return data
if not strict and isinstance(a, dict):
if not strict and isinstance(data, dict):
gen_keys = ["bit_generator", "gauss", "has_gauss", "state"]
state_keys = ["key", "pos"]
for key in gen_keys:
if key not in a:
return False
if key not in data:
raise TypeError()
for key in state_keys:
if key not in a["state"]:
return False
if key not in data["state"]:
raise TypeError()
state_key = a["state"]["key"]
state_key = data["state"]["key"]
if state_key.shape == (624,) and state_key.dtype == np.uint32:
return True
# TODO: Add an option to convert to a `RandomState` instance?
return data
return False
raise TypeError()
@staticmethod
def values_eq(a, b):
......@@ -114,7 +122,7 @@ aesara.compile.register_view_op_c_code(
random_state_type = RandomStateType()
class RandomGeneratorType(RandomType):
class RandomGeneratorType(RandomType[np.random.Generator]):
r"""A Type wrapper for `numpy.random.Generator`.
The reason this exists (and `Generic` doesn't suffice) is that
......@@ -130,16 +138,25 @@ class RandomGeneratorType(RandomType):
def __repr__(self):
return "RandomGeneratorType"
@staticmethod
def is_valid_value(a, strict):
if isinstance(a, np.random.Generator):
return True
def filter(self, data, strict=False, allow_downcast=None):
"""
XXX: This doesn't convert `data` to the same type of underlying RNG type
as `self`. It really only checks that `data` is of the appropriate type
to be a valid `RandomGeneratorType`.
In other words, it serves as a `Type.is_valid_value` implementation,
but, because the default `Type.is_valid_value` depends on
`Type.filter`, we need to have it here to avoid surprising circular
dependencies in sub-classes.
"""
if isinstance(data, np.random.Generator):
return data
if not strict and isinstance(a, dict):
if "bit_generator" not in a:
return False
if not strict and isinstance(data, dict):
if "bit_generator" not in data:
raise TypeError()
else:
bit_gen_key = a["bit_generator"]
bit_gen_key = data["bit_generator"]
if hasattr(bit_gen_key, "_value"):
bit_gen_key = int(bit_gen_key._value)
......@@ -148,16 +165,16 @@ class RandomGeneratorType(RandomType):
gen_keys, state_keys = gen_states_keys[bit_gen_key]
for key in gen_keys:
if key not in a:
return False
if key not in data:
raise TypeError()
for key in state_keys:
if key not in a["state"]:
return False
if key not in data["state"]:
raise TypeError()
return True
return data
return False
raise TypeError()
@staticmethod
def values_eq(a, b):
......
......@@ -56,15 +56,17 @@ class TestRandomStateType:
with pytest.raises(TypeError):
rng_type.filter(1)
rng = rng.get_state(legacy=False)
assert rng_type.is_valid_value(rng, strict=False)
rng_dict = rng.get_state(legacy=False)
rng["state"] = {}
assert rng_type.is_valid_value(rng_dict) is False
assert rng_type.is_valid_value(rng_dict, strict=False)
assert rng_type.is_valid_value(rng, strict=False) is False
rng_dict["state"] = {}
rng = {}
assert rng_type.is_valid_value(rng, strict=False) is False
assert rng_type.is_valid_value(rng_dict, strict=False) is False
rng_dict = {}
assert rng_type.is_valid_value(rng_dict, strict=False) is False
def test_values_eq(self):
......@@ -147,15 +149,17 @@ class TestRandomGeneratorType:
with pytest.raises(TypeError):
rng_type.filter(1)
rng = rng.__getstate__()
assert rng_type.is_valid_value(rng, strict=False)
rng_dict = rng.__getstate__()
assert rng_type.is_valid_value(rng_dict) is False
assert rng_type.is_valid_value(rng_dict, strict=False)
rng["state"] = {}
rng_dict["state"] = {}
assert rng_type.is_valid_value(rng, strict=False) is False
assert rng_type.is_valid_value(rng_dict, strict=False) is False
rng = {}
assert rng_type.is_valid_value(rng, strict=False) is False
rng_dict = {}
assert rng_type.is_valid_value(rng_dict, strict=False) is False
def test_values_eq(self):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论