提交 9ab8df51 authored 作者: Etienne Duchesne's avatar Etienne Duchesne 提交者: Ricardo Vieira

Simplify dispatch of JAX random variables

上级 95ce102d
...@@ -796,7 +796,7 @@ def test_random_custom_implementation(): ...@@ -796,7 +796,7 @@ def test_random_custom_implementation():
@jax_sample_fn.register(CustomRV) @jax_sample_fn.register(CustomRV)
def jax_sample_fn_custom(op, node): def jax_sample_fn_custom(op, node):
def sample_fn(rng, size, dtype, *parameters): def sample_fn(rng, size, dtype, *parameters):
return (rng, 0) return 0
return sample_fn return sample_fn
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论