提交 605b609f authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Use new JAX index update approach when available

上级 94f5ddfd
...@@ -616,9 +616,20 @@ def jax_funcify_IncSubtensor(op, **kwargs): ...@@ -616,9 +616,20 @@ def jax_funcify_IncSubtensor(op, **kwargs):
idx_list = getattr(op, "idx_list", None) idx_list = getattr(op, "idx_list", None)
if getattr(op, "set_instead_of_inc", False): if getattr(op, "set_instead_of_inc", False):
jax_fn = jax.ops.index_update jax_fn = getattr(jax.ops, "index_update", None)
if jax_fn is None:
def jax_fn(x, indices, y):
return x.at[indices].set(y)
else: else:
jax_fn = jax.ops.index_add jax_fn = getattr(jax.ops, "index_add", None)
if jax_fn is None:
def jax_fn(x, indices, y):
return x.at[indices].add(y)
def incsubtensor(x, y, *ilist, jax_fn=jax_fn, idx_list=idx_list): def incsubtensor(x, y, *ilist, jax_fn=jax_fn, idx_list=idx_list):
indices = indices_from_subtensor(ilist, idx_list) indices = indices_from_subtensor(ilist, idx_list)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论