提交 f26a5086 authored 作者: kc611's avatar kc611 提交者: Thomas Wiecki

Add a test for JAX conversion of unimplemented RVs

上级 d127fc14
...@@ -192,9 +192,8 @@ def jax_typify(data, dtype): ...@@ -192,9 +192,8 @@ def jax_typify(data, dtype):
"""Convert instances of Aesara `Type`s to JAX types.""" """Convert instances of Aesara `Type`s to JAX types."""
if dtype is None: if dtype is None:
return data return data
if dtype is not None: else:
return jnp.array(data, dtype=dtype) return jnp.array(data, dtype=dtype)
raise NotImplementedError(f"No JAX conversion for data and dtype: {data}, {dtype}")
@jax_typify.register(np.ndarray) @jax_typify.register(np.ndarray)
......
...@@ -29,7 +29,7 @@ from aesara.tensor.math import clip, cosh, gammaln, log ...@@ -29,7 +29,7 @@ from aesara.tensor.math import clip, cosh, gammaln, log
from aesara.tensor.math import max as aet_max from aesara.tensor.math import max as aet_max
from aesara.tensor.math import maximum, prod from aesara.tensor.math import maximum, prod
from aesara.tensor.math import sum as aet_sum from aesara.tensor.math import sum as aet_sum
from aesara.tensor.random.basic import normal from aesara.tensor.random.basic import RandomVariable, normal
from aesara.tensor.random.utils import RandomStream from aesara.tensor.random.utils import RandomStream
from aesara.tensor.shape import Shape, Shape_i, SpecifyShape, reshape from aesara.tensor.shape import Shape, Shape_i, SpecifyShape, reshape
from aesara.tensor.type import ( from aesara.tensor.type import (
...@@ -978,6 +978,28 @@ def test_random(): ...@@ -978,6 +978,28 @@ def test_random():
compare_jax_and_py(fgraph, []) compare_jax_and_py(fgraph, [])
def test_random_unimplemented():
class NonExistentRV(RandomVariable):
name = "non-existent"
ndim_supp = 0
ndims_params = []
dtype = "floatX"
def __call__(self, size=None, **kwargs):
return super().__call__(size=size, **kwargs)
def rng_fn(cls, rng, size):
return 0
nonexistentrv = NonExistentRV()
rng = shared(np.random.RandomState(123))
out = nonexistentrv(rng=rng)
fgraph = FunctionGraph([out.owner.inputs[0]], [out], clone=False)
with pytest.raises(NotImplementedError):
compare_jax_and_py(fgraph, [])
def test_RandomStream(): def test_RandomStream():
srng = RandomStream(seed=123) srng = RandomStream(seed=123)
out = srng.normal() - srng.normal() out = srng.normal() - srng.normal()
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论