提交 7d07260b authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Remove RandomType's get_shape_info and get_size methods

上级 d9fd640e
import sys
import numpy as np import numpy as np
import aesara import aesara
...@@ -30,10 +28,6 @@ class RandomType(Type): ...@@ -30,10 +28,6 @@ class RandomType(Type):
else: else:
raise TypeError() raise TypeError()
@staticmethod
def get_shape_info(obj):
return obj.get_value(borrow=True)
@staticmethod @staticmethod
def may_share_memory(a, b): def may_share_memory(a, b):
return a._bit_generator is b._bit_generator return a._bit_generator is b._bit_generator
...@@ -99,10 +93,6 @@ class RandomStateType(RandomType): ...@@ -99,10 +93,6 @@ class RandomStateType(RandomType):
return _eq(sa, sb) return _eq(sa, sb)
@staticmethod
def get_size(shape_info):
return sys.getsizeof(shape_info.get_state(legacy=False))
# Register `RandomStateType`'s C code for `ViewOp`. # Register `RandomStateType`'s C code for `ViewOp`.
aesara.compile.register_view_op_c_code( aesara.compile.register_view_op_c_code(
...@@ -184,11 +174,6 @@ class RandomGeneratorType(RandomType): ...@@ -184,11 +174,6 @@ class RandomGeneratorType(RandomType):
return _eq(sa, sb) return _eq(sa, sb)
@staticmethod
def get_size(shape_info):
state = shape_info.__getstate__()
return sys.getsizeof(state)
# Register `RandomGeneratorType`'s C code for `ViewOp`. # Register `RandomGeneratorType`'s C code for `ViewOp`.
aesara.compile.register_view_op_c_code( aesara.compile.register_view_op_c_code(
......
import pickle import pickle
import sys
import numpy as np import numpy as np
import pytest import pytest
...@@ -95,21 +94,6 @@ class TestRandomStateType: ...@@ -95,21 +94,6 @@ class TestRandomStateType:
assert not rng_type.values_eq(rng_g, rng_a) assert not rng_type.values_eq(rng_g, rng_a)
assert not rng_type.values_eq(rng_e, rng_g) assert not rng_type.values_eq(rng_e, rng_g)
def test_get_shape_info(self):
rng = np.random.RandomState(12)
rng_a = shared(rng)
assert isinstance(
random_state_type.get_shape_info(rng_a), np.random.RandomState
)
def test_get_size(self):
rng = np.random.RandomState(12)
rng_a = shared(rng)
shape_info = random_state_type.get_shape_info(rng_a)
size = random_state_type.get_size(shape_info)
assert size == sys.getsizeof(rng.get_state(legacy=False))
def test_may_share_memory(self): def test_may_share_memory(self):
bg1 = np.random.MT19937() bg1 = np.random.MT19937()
bg2 = np.random.MT19937() bg2 = np.random.MT19937()
...@@ -119,16 +103,23 @@ class TestRandomStateType: ...@@ -119,16 +103,23 @@ class TestRandomStateType:
rng_var_a = shared(rng_a, borrow=True) rng_var_a = shared(rng_a, borrow=True)
rng_var_b = shared(rng_b, borrow=True) rng_var_b = shared(rng_b, borrow=True)
shape_info_a = random_state_type.get_shape_info(rng_var_a)
shape_info_b = random_state_type.get_shape_info(rng_var_b)
assert random_state_type.may_share_memory(shape_info_a, shape_info_b) is False assert (
random_state_type.may_share_memory(
rng_var_a.get_value(borrow=True), rng_var_b.get_value(borrow=True)
)
is False
)
rng_c = np.random.RandomState(bg2) rng_c = np.random.RandomState(bg2)
rng_var_c = shared(rng_c, borrow=True) rng_var_c = shared(rng_c, borrow=True)
shape_info_c = random_state_type.get_shape_info(rng_var_c)
assert random_state_type.may_share_memory(shape_info_b, shape_info_c) is True assert (
random_state_type.may_share_memory(
rng_var_b.get_value(borrow=True), rng_var_c.get_value(borrow=True)
)
is True
)
class TestRandomGeneratorType: class TestRandomGeneratorType:
...@@ -197,21 +188,6 @@ class TestRandomGeneratorType: ...@@ -197,21 +188,6 @@ class TestRandomGeneratorType:
assert rng_type.is_valid_value(bitgen_g, strict=True) assert rng_type.is_valid_value(bitgen_g, strict=True)
assert rng_type.is_valid_value(bitgen_h.__getstate__(), strict=False) assert rng_type.is_valid_value(bitgen_h.__getstate__(), strict=False)
def test_get_shape_info(self):
rng = np.random.default_rng(12)
rng_a = shared(rng)
assert isinstance(
random_generator_type.get_shape_info(rng_a), np.random.Generator
)
def test_get_size(self):
rng = np.random.Generator(np.random.PCG64(12))
rng_a = shared(rng)
shape_info = random_generator_type.get_shape_info(rng_a)
size = random_generator_type.get_size(shape_info)
assert size == sys.getsizeof(rng.__getstate__())
def test_may_share_memory(self): def test_may_share_memory(self):
bg_a = np.random.PCG64() bg_a = np.random.PCG64()
bg_b = np.random.PCG64() bg_b = np.random.PCG64()
...@@ -220,13 +196,20 @@ class TestRandomGeneratorType: ...@@ -220,13 +196,20 @@ class TestRandomGeneratorType:
rng_var_a = shared(rng_a, borrow=True) rng_var_a = shared(rng_a, borrow=True)
rng_var_b = shared(rng_b, borrow=True) rng_var_b = shared(rng_b, borrow=True)
shape_info_a = random_state_type.get_shape_info(rng_var_a)
shape_info_b = random_state_type.get_shape_info(rng_var_b)
assert random_state_type.may_share_memory(shape_info_a, shape_info_b) is False assert (
random_state_type.may_share_memory(
rng_var_a.get_value(borrow=True), rng_var_b.get_value(borrow=True)
)
is False
)
rng_c = np.random.Generator(bg_b) rng_c = np.random.Generator(bg_b)
rng_var_c = shared(rng_c, borrow=True) rng_var_c = shared(rng_c, borrow=True)
shape_info_c = random_state_type.get_shape_info(rng_var_c)
assert random_state_type.may_share_memory(shape_info_b, shape_info_c) is True assert (
random_state_type.may_share_memory(
rng_var_b.get_value(borrow=True), rng_var_c.get_value(borrow=True)
)
is True
)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论