提交 4235ccc3 authored 作者: Rémi Louf's avatar Rémi Louf 提交者: Ricardo Vieira

Typify 0-dim arrays to corresponding number

上级 0087e562
...@@ -30,6 +30,8 @@ def jax_typify(data, dtype=None, **kwargs): ...@@ -30,6 +30,8 @@ def jax_typify(data, dtype=None, **kwargs):
@jax_typify.register(np.ndarray) @jax_typify.register(np.ndarray)
def jax_typify_ndarray(data, dtype=None, **kwargs): def jax_typify_ndarray(data, dtype=None, **kwargs):
if len(data.shape) == 0:
return data.item()
return jnp.array(data, dtype=dtype) return jnp.array(data, dtype=dtype)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论