提交 b3f12b26 authored 作者: Rémi Louf's avatar Rémi Louf 提交者: Ricardo Vieira

Implement `TensorFromScalar` as a pass-through

上级 47f76375
...@@ -182,7 +182,7 @@ def jax_funcify_MakeVector(op, **kwargs): ...@@ -182,7 +182,7 @@ def jax_funcify_MakeVector(op, **kwargs):
@jax_funcify.register(TensorFromScalar) @jax_funcify.register(TensorFromScalar)
def jax_funcify_TensorFromScalar(op, **kwargs): def jax_funcify_TensorFromScalar(op, **kwargs):
def tensor_from_scalar(x): def tensor_from_scalar(x):
return jnp.array(x) return x
return tensor_from_scalar return tensor_from_scalar
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论