提交 2e35e6c4 authored 作者: qipengchen's avatar qipengchen 提交者: Brandon T. Willard

Use new JAX index update approach in AdvancedIncSubtensor

上级 7cca06f0
......@@ -648,9 +648,20 @@ _ = [jax_funcify.register(op, jax_funcify_IncSubtensor) for op in incsubtensor_o
def jax_funcify_AdvancedIncSubtensor(op, **kwargs):
if getattr(op, "set_instead_of_inc", False):
jax_fn = jax.ops.index_update
jax_fn = getattr(jax.ops, "index_update", None)
if jax_fn is None:
def jax_fn(x, indices, y):
return x.at[indices].set(y)
else:
jax_fn = jax.ops.index_add
jax_fn = getattr(jax.ops, "index_add", None)
if jax_fn is None:
def jax_fn(x, indices, y):
return x.at[indices].add(y)
def advancedincsubtensor(x, y, *ilist, jax_fn=jax_fn):
return jax_fn(x, ilist, y)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论