提交 efb4996f authored 作者: Rémi Louf's avatar Rémi Louf 提交者: Ricardo Vieira

Document discrepancy with `Clip` and `jax.numpy.clip`

上级 ff609353
...@@ -56,6 +56,14 @@ def jax_funcify_Identity(op, **kwargs): ...@@ -56,6 +56,14 @@ def jax_funcify_Identity(op, **kwargs):
@jax_funcify.register(Clip) @jax_funcify.register(Clip)
def jax_funcify_Clip(op, **kwargs): def jax_funcify_Clip(op, **kwargs):
"""Register the translation for the `Clip` `Op`.
PyTensor's `Clip` operator operates differently from NumPy's when the
specified `min` is larger than the `max` so we cannot reuse `jax.numpy.clip`
to maintain consistency with PyTensor.
"""
def clip(x, min, max): def clip(x, min, max):
return jnp.where(x < min, min, jnp.where(x > max, max, x)) return jnp.where(x < min, min, jnp.where(x > max, max, x))
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论