Unverified 提交 7f4e0ab4 authored 作者: Maxim Kochurov's avatar Maxim Kochurov 提交者: GitHub

Add a JAX implementation for `FillDiagonal`

上级 22465044
...@@ -956,20 +956,11 @@ def jax_funcify_Bartlett(op, **kwargs): ...@@ -956,20 +956,11 @@ def jax_funcify_Bartlett(op, **kwargs):
@jax_funcify.register(FillDiagonal) @jax_funcify.register(FillDiagonal)
def jax_funcify_FillDiagonal(op, **kwargs): def jax_funcify_FillDiagonal(op, **kwargs):
def filldiagonal(value, diagonal):
i, j = jnp.diag_indices(min(value.shape[-2:]))
return value.at[..., i, j].set(diagonal)
# def filldiagonal(a, val): return filldiagonal
# if a.ndim == 2:
# step = a.shape[1] + 1
# end = a.shape[1] * a.shape[1]
# a.flat[:end:step] = val
# else:
# jnp.fill_diagonal(a, val)
#
# return a
#
# return filldiagonal
raise NotImplementedError("flatiter not implemented in JAX")
@jax_funcify.register(FillDiagonalOffset) @jax_funcify.register(FillDiagonalOffset)
......
...@@ -1205,7 +1205,6 @@ def test_extra_ops(): ...@@ -1205,7 +1205,6 @@ def test_extra_ops():
c = at.as_tensor(5) c = at.as_tensor(5)
with pytest.raises(NotImplementedError):
out = at_extra_ops.fill_diagonal(a, c) out = at_extra_ops.fill_diagonal(a, c)
fgraph = FunctionGraph([a], [out]) fgraph = FunctionGraph([a], [out])
compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs]) compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论