Unverified 提交 17fa8b13 authored 作者: Habeeb Shopeju's avatar Habeeb Shopeju 提交者: GitHub

PyTorch Softmax Ops (#846)

上级 f3d2ede9
......@@ -9,7 +9,7 @@ channels:
dependencies:
- python>=3.10
- compilers
- numpy>=1.17.0
- numpy>=1.17.0,<2
- scipy>=0.14,<1.14.0
- filelock
- etuples
......
......@@ -2,6 +2,7 @@ import torch
from pytensor.link.pytorch.dispatch.basic import pytorch_funcify
from pytensor.tensor.elemwise import DimShuffle, Elemwise
from pytensor.tensor.special import LogSoftmax, Softmax, SoftmaxGrad
@pytorch_funcify.register(Elemwise)
......@@ -34,3 +35,52 @@ def pytorch_funcify_DimShuffle(op, **kwargs):
return res
return dimshuffle
@pytorch_funcify.register(Softmax)
def pytorch_funcify_Softmax(op, **kwargs):
axis = op.axis
dtype = kwargs["node"].inputs[0].dtype
if not dtype.startswith("float"):
raise NotImplementedError(
"Pytorch Softmax is not currently implemented for non-float types."
)
def softmax(x):
if axis is not None:
return torch.softmax(x, dim=axis)
else:
return torch.softmax(x.ravel(), dim=0).reshape(x.shape)
return softmax
@pytorch_funcify.register(LogSoftmax)
def pytorch_funcify_LogSoftmax(op, **kwargs):
axis = op.axis
dtype = kwargs["node"].inputs[0].dtype
if not dtype.startswith("float"):
raise NotImplementedError(
"Pytorch LogSoftmax is not currently implemented for non-float types."
)
def log_softmax(x):
if axis is not None:
return torch.log_softmax(x, dim=axis)
else:
return torch.log_softmax(x.ravel(), dim=0).reshape(x.shape)
return log_softmax
@pytorch_funcify.register(SoftmaxGrad)
def jax_funcify_SoftmaxGrad(op, **kwargs):
axis = op.axis
def softmax_grad(dy, sm):
dy_times_sm = dy * sm
return dy_times_sm - torch.sum(dy_times_sm, dim=axis, keepdim=True) * sm
return softmax_grad
import numpy as np
import pytest
import pytensor.tensor as pt
from pytensor.configdefaults import config
from pytensor.graph.fg import FunctionGraph
from pytensor.tensor import elemwise as pt_elemwise
from pytensor.tensor.special import SoftmaxGrad, log_softmax, softmax
from pytensor.tensor.type import matrix, tensor, vector
from tests.link.pytorch.test_basic import compare_pytorch_and_py
......@@ -53,3 +55,50 @@ def test_pytorch_elemwise():
fg = FunctionGraph([x], [out])
compare_pytorch_and_py(fg, [[0.9, 0.9]])
@pytest.mark.parametrize("dtype", ["float64", "int64"])
@pytest.mark.parametrize("axis", [None, 0, 1])
def test_softmax(axis, dtype):
x = matrix("x", dtype=dtype)
out = softmax(x, axis=axis)
fgraph = FunctionGraph([x], [out])
test_input = np.arange(6, dtype=config.floatX).reshape(2, 3)
if dtype == "int64":
with pytest.raises(
NotImplementedError,
match="Pytorch Softmax is not currently implemented for non-float types.",
):
compare_pytorch_and_py(fgraph, [test_input])
else:
compare_pytorch_and_py(fgraph, [test_input])
@pytest.mark.parametrize("dtype", ["float64", "int64"])
@pytest.mark.parametrize("axis", [None, 0, 1])
def test_logsoftmax(axis, dtype):
x = matrix("x", dtype=dtype)
out = log_softmax(x, axis=axis)
fgraph = FunctionGraph([x], [out])
test_input = np.arange(6, dtype=config.floatX).reshape(2, 3)
if dtype == "int64":
with pytest.raises(
NotImplementedError,
match="Pytorch LogSoftmax is not currently implemented for non-float types.",
):
compare_pytorch_and_py(fgraph, [test_input])
else:
compare_pytorch_and_py(fgraph, [test_input])
@pytest.mark.parametrize("axis", [None, 0, 1])
def test_softmax_grad(axis):
dy = matrix("dy")
dy_value = np.array([[1, 1, 1], [0, 0, 0]], dtype=config.floatX)
sm = matrix("sm")
sm_value = np.arange(6, dtype=config.floatX).reshape(2, 3)
out = SoftmaxGrad(axis=axis)(dy, sm)
fgraph = FunctionGraph([dy, sm], [out])
compare_pytorch_and_py(fgraph, [dy_value, sm_value])
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论