提交 2e35e6c4 authored 作者: qipengchen's avatar qipengchen 提交者: Brandon T. Willard

Use new JAX index update approach in AdvancedIncSubtensor

上级 7cca06f0
...@@ -648,9 +648,20 @@ _ = [jax_funcify.register(op, jax_funcify_IncSubtensor) for op in incsubtensor_o ...@@ -648,9 +648,20 @@ _ = [jax_funcify.register(op, jax_funcify_IncSubtensor) for op in incsubtensor_o
def jax_funcify_AdvancedIncSubtensor(op, **kwargs): def jax_funcify_AdvancedIncSubtensor(op, **kwargs):
if getattr(op, "set_instead_of_inc", False): if getattr(op, "set_instead_of_inc", False):
jax_fn = jax.ops.index_update jax_fn = getattr(jax.ops, "index_update", None)
if jax_fn is None:
def jax_fn(x, indices, y):
return x.at[indices].set(y)
else: else:
jax_fn = jax.ops.index_add jax_fn = getattr(jax.ops, "index_add", None)
if jax_fn is None:
def jax_fn(x, indices, y):
return x.at[indices].add(y)
def advancedincsubtensor(x, y, *ilist, jax_fn=jax_fn): def advancedincsubtensor(x, y, *ilist, jax_fn=jax_fn):
return jax_fn(x, ilist, y) return jax_fn(x, ilist, y)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论