提交 b6266896 authored 作者: Adrien Corenflos's avatar Adrien Corenflos 提交者: Thomas Wiecki

Split RNG keys before using them in JAX backend

上级 5c63ee70
...@@ -125,8 +125,9 @@ def jax_sample_fn_generic(op): ...@@ -125,8 +125,9 @@ def jax_sample_fn_generic(op):
def sample_fn(rng, size, dtype, *parameters): def sample_fn(rng, size, dtype, *parameters):
rng_key = rng["jax_state"] rng_key = rng["jax_state"]
sample = jax_op(rng_key, *parameters, shape=size, dtype=dtype) rng_key, sampling_key = jax.random.split(rng_key, 2)
rng["jax_state"] = jax.random.split(rng_key, num=1)[0] sample = jax_op(sampling_key, *parameters, shape=size, dtype=dtype)
rng["jax_state"] = rng_key
return (rng, sample) return (rng, sample)
return sample_fn return sample_fn
...@@ -151,9 +152,10 @@ def jax_sample_fn_loc_scale(op): ...@@ -151,9 +152,10 @@ def jax_sample_fn_loc_scale(op):
def sample_fn(rng, size, dtype, *parameters): def sample_fn(rng, size, dtype, *parameters):
rng_key = rng["jax_state"] rng_key = rng["jax_state"]
rng_key, sampling_key = jax.random.split(rng_key, 2)
loc, scale = parameters loc, scale = parameters
sample = loc + jax_op(rng_key, size, dtype) * scale sample = loc + jax_op(sampling_key, size, dtype) * scale
rng["jax_state"] = jax.random.split(rng_key, num=1)[0] rng["jax_state"] = rng_key
return (rng, sample) return (rng, sample)
return sample_fn return sample_fn
...@@ -168,8 +170,9 @@ def jax_sample_fn_no_dtype(op): ...@@ -168,8 +170,9 @@ def jax_sample_fn_no_dtype(op):
def sample_fn(rng, size, dtype, *parameters): def sample_fn(rng, size, dtype, *parameters):
rng_key = rng["jax_state"] rng_key = rng["jax_state"]
sample = jax_op(rng_key, *parameters, shape=size) rng_key, sampling_key = jax.random.split(rng_key, 2)
rng["jax_state"] = jax.random.split(rng_key, num=1)[0] sample = jax_op(sampling_key, *parameters, shape=size)
rng["jax_state"] = rng_key
return (rng, sample) return (rng, sample)
return sample_fn return sample_fn
...@@ -189,9 +192,12 @@ def jax_sample_fn_uniform(op): ...@@ -189,9 +192,12 @@ def jax_sample_fn_uniform(op):
def sample_fn(rng, size, dtype, *parameters): def sample_fn(rng, size, dtype, *parameters):
rng_key = rng["jax_state"] rng_key = rng["jax_state"]
rng_key, sampling_key = jax.random.split(rng_key, 2)
minval, maxval = parameters minval, maxval = parameters
sample = jax_op(rng_key, shape=size, dtype=dtype, minval=minval, maxval=maxval) sample = jax_op(
rng["jax_state"] = jax.random.split(rng_key, num=1)[0] sampling_key, shape=size, dtype=dtype, minval=minval, maxval=maxval
)
rng["jax_state"] = rng_key
return (rng, sample) return (rng, sample)
return sample_fn return sample_fn
...@@ -211,9 +217,10 @@ def jax_sample_fn_shape_rate(op): ...@@ -211,9 +217,10 @@ def jax_sample_fn_shape_rate(op):
def sample_fn(rng, size, dtype, *parameters): def sample_fn(rng, size, dtype, *parameters):
rng_key = rng["jax_state"] rng_key = rng["jax_state"]
rng_key, sampling_key = jax.random.split(rng_key, 2)
(shape, rate) = parameters (shape, rate) = parameters
sample = jax_op(rng_key, shape, size, dtype) / rate sample = jax_op(sampling_key, shape, size, dtype) / rate
rng["jax_state"] = jax.random.split(rng_key, num=1)[0] rng["jax_state"] = rng_key
return (rng, sample) return (rng, sample)
return sample_fn return sample_fn
...@@ -225,9 +232,10 @@ def jax_sample_fn_exponential(op): ...@@ -225,9 +232,10 @@ def jax_sample_fn_exponential(op):
def sample_fn(rng, size, dtype, *parameters): def sample_fn(rng, size, dtype, *parameters):
rng_key = rng["jax_state"] rng_key = rng["jax_state"]
rng_key, sampling_key = jax.random.split(rng_key, 2)
(scale,) = parameters (scale,) = parameters
sample = jax.random.exponential(rng_key, size, dtype) * scale sample = jax.random.exponential(sampling_key, size, dtype) * scale
rng["jax_state"] = jax.random.split(rng_key, num=1)[0] rng["jax_state"] = rng_key
return (rng, sample) return (rng, sample)
return sample_fn return sample_fn
...@@ -239,13 +247,14 @@ def jax_sample_fn_t(op): ...@@ -239,13 +247,14 @@ def jax_sample_fn_t(op):
def sample_fn(rng, size, dtype, *parameters): def sample_fn(rng, size, dtype, *parameters):
rng_key = rng["jax_state"] rng_key = rng["jax_state"]
rng_key, sampling_key = jax.random.split(rng_key, 2)
( (
df, df,
loc, loc,
scale, scale,
) = parameters ) = parameters
sample = loc + jax.random.t(rng_key, df, size, dtype) * scale sample = loc + jax.random.t(sampling_key, df, size, dtype) * scale
rng["jax_state"] = jax.random.split(rng_key, num=1)[0] rng["jax_state"] = rng_key
return (rng, sample) return (rng, sample)
return sample_fn return sample_fn
...@@ -257,9 +266,10 @@ def jax_funcify_choice(op): ...@@ -257,9 +266,10 @@ def jax_funcify_choice(op):
def sample_fn(rng, size, dtype, *parameters): def sample_fn(rng, size, dtype, *parameters):
rng_key = rng["jax_state"] rng_key = rng["jax_state"]
rng_key, sampling_key = jax.random.split(rng_key, 2)
(a, p, replace) = parameters (a, p, replace) = parameters
smpl_value = jax.random.choice(rng_key, a, size, replace, p) smpl_value = jax.random.choice(sampling_key, a, size, replace, p)
rng["jax_state"] = jax.random.split(rng_key, num=1)[0] rng["jax_state"] = rng_key
return (rng, smpl_value) return (rng, smpl_value)
return sample_fn return sample_fn
...@@ -271,9 +281,10 @@ def jax_sample_fn_permutation(op): ...@@ -271,9 +281,10 @@ def jax_sample_fn_permutation(op):
def sample_fn(rng, size, dtype, *parameters): def sample_fn(rng, size, dtype, *parameters):
rng_key = rng["jax_state"] rng_key = rng["jax_state"]
rng_key, sampling_key = jax.random.split(rng_key, 2)
(x,) = parameters (x,) = parameters
sample = jax.random.permutation(rng_key, x) sample = jax.random.permutation(sampling_key, x)
rng["jax_state"] = jax.random.split(rng_key, num=1)[0] rng["jax_state"] = rng_key
return (rng, sample) return (rng, sample)
return sample_fn return sample_fn
...@@ -285,10 +296,11 @@ def jax_sample_fn_lognormal(op): ...@@ -285,10 +296,11 @@ def jax_sample_fn_lognormal(op):
def sample_fn(rng, size, dtype, *parameters): def sample_fn(rng, size, dtype, *parameters):
rng_key = rng["jax_state"] rng_key = rng["jax_state"]
rng_key, sampling_key = jax.random.split(rng_key, 2)
loc, scale = parameters loc, scale = parameters
sample = loc + jax.random.normal(rng_key, size, dtype) * scale sample = loc + jax.random.normal(sampling_key, size, dtype) * scale
sample_exp = jax.numpy.exp(sample) sample_exp = jax.numpy.exp(sample)
rng["jax_state"] = jax.random.split(rng_key, num=1)[0] rng["jax_state"] = rng_key
return (rng, sample_exp) return (rng, sample_exp)
return sample_fn return sample_fn
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论