Unverified 提交 9ac65794 authored 作者: Chris Fonnesbeck's avatar Chris Fonnesbeck 提交者: GitHub

Added PyTorch clip dispatch (#1797)

上级 6319fac8
......@@ -5,6 +5,7 @@ import torch
from pytensor.link.pytorch.dispatch.basic import pytorch_funcify
from pytensor.scalar.basic import (
Cast,
Clip,
Invert,
ScalarOp,
)
......@@ -71,6 +72,14 @@ def pytorch_funcify_Softplus(op, node, **kwargs):
return torch.nn.Softplus()
@pytorch_funcify.register(Clip)
def pytorch_funcify_Clip(op, node, **kwargs):
def clip(x, min_val, max_val):
return torch.where(x < min_val, min_val, torch.where(x > max_val, max_val, x))
return clip
@pytorch_funcify.register(ScalarLoop)
def pytorch_funicify_ScalarLoop(op, node, **kwargs):
update = pytorch_funcify(op.fgraph, **kwargs)
......
......@@ -151,6 +151,19 @@ def test_cast():
assert res.dtype == np.int32
@pytest.mark.parametrize(
"x_val, min_val, max_val",
[
(np.array([5.0], dtype=config.floatX), 0.0, 10.0),
(np.array([-5.0], dtype=config.floatX), 0.0, 10.0),
],
)
def test_clip(x_val, min_val, max_val):
x = pt.tensor("x", shape=x_val.shape, dtype=config.floatX)
out = pt.clip(x, min_val, max_val)
compare_pytorch_and_py([x], [out], [x_val])
def test_vmap_elemwise():
from pytensor.link.pytorch.dispatch.basic import pytorch_funcify
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论