Unverified 提交 7f4e0ab4 authored 作者: Maxim Kochurov's avatar Maxim Kochurov 提交者: GitHub

Add a JAX implementation for `FillDiagonal`

上级 22465044
......@@ -956,20 +956,11 @@ def jax_funcify_Bartlett(op, **kwargs):
@jax_funcify.register(FillDiagonal)
def jax_funcify_FillDiagonal(op, **kwargs):
def filldiagonal(value, diagonal):
i, j = jnp.diag_indices(min(value.shape[-2:]))
return value.at[..., i, j].set(diagonal)
# def filldiagonal(a, val):
# if a.ndim == 2:
# step = a.shape[1] + 1
# end = a.shape[1] * a.shape[1]
# a.flat[:end:step] = val
# else:
# jnp.fill_diagonal(a, val)
#
# return a
#
# return filldiagonal
raise NotImplementedError("flatiter not implemented in JAX")
return filldiagonal
@jax_funcify.register(FillDiagonalOffset)
......
......@@ -1205,10 +1205,9 @@ def test_extra_ops():
c = at.as_tensor(5)
with pytest.raises(NotImplementedError):
out = at_extra_ops.fill_diagonal(a, c)
fgraph = FunctionGraph([a], [out])
compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])
out = at_extra_ops.fill_diagonal(a, c)
fgraph = FunctionGraph([a], [out])
compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])
with pytest.raises(NotImplementedError):
out = at_extra_ops.fill_diagonal_offset(a, c, c)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论