Unverified commit 17fa8b13, authored by Habeeb Shopeju, committed by GitHub

PyTorch Softmax Ops (#846)

Parent: f3d2ede9
...
@@ -9,7 +9,7 @@ channels:
 dependencies:
   - python>=3.10
   - compilers
-  - numpy>=1.17.0
+  - numpy>=1.17.0,<2
   - scipy>=0.14,<1.14.0
   - filelock
   - etuples
...
...
@@ -2,6 +2,7 @@ import torch
from pytensor.link.pytorch.dispatch.basic import pytorch_funcify
from pytensor.tensor.elemwise import DimShuffle, Elemwise
from pytensor.tensor.special import LogSoftmax, Softmax, SoftmaxGrad


@pytorch_funcify.register(Elemwise)
...
@@ -34,3 +35,52 @@ def pytorch_funcify_DimShuffle(op, **kwargs):
        return res

    return dimshuffle

@pytorch_funcify.register(Softmax)
def pytorch_funcify_Softmax(op, **kwargs):
    axis = op.axis
    dtype = kwargs["node"].inputs[0].dtype

    if not dtype.startswith("float"):
        raise NotImplementedError(
            "Pytorch Softmax is not currently implemented for non-float types."
        )

    def softmax(x):
        if axis is not None:
            return torch.softmax(x, dim=axis)
        else:
            return torch.softmax(x.ravel(), dim=0).reshape(x.shape)

    return softmax

@pytorch_funcify.register(LogSoftmax)
def pytorch_funcify_LogSoftmax(op, **kwargs):
    axis = op.axis
    dtype = kwargs["node"].inputs[0].dtype

    if not dtype.startswith("float"):
        raise NotImplementedError(
            "Pytorch LogSoftmax is not currently implemented for non-float types."
        )

    def log_softmax(x):
        if axis is not None:
            return torch.log_softmax(x, dim=axis)
        else:
            return torch.log_softmax(x.ravel(), dim=0).reshape(x.shape)

    return log_softmax

@pytorch_funcify.register(SoftmaxGrad)
def pytorch_funcify_SoftmaxGrad(op, **kwargs):
    axis = op.axis

    def softmax_grad(dy, sm):
        dy_times_sm = dy * sm
        return dy_times_sm - torch.sum(dy_times_sm, dim=axis, keepdim=True) * sm

    return softmax_grad
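
The softmax_grad expression above is the standard softmax vector-Jacobian product: for an upstream gradient dy and softmax output sm, the input gradient is sm * (dy - sum(dy * sm)). A minimal standalone check against torch.autograd, not part of this diff and using made-up float64 inputs, could look like:

import torch

def softmax_grad(dy, sm, axis=-1):
    # Same expression as the SoftmaxGrad dispatch above.
    dy_times_sm = dy * sm
    return dy_times_sm - torch.sum(dy_times_sm, dim=axis, keepdim=True) * sm

x = torch.randn(2, 3, dtype=torch.float64, requires_grad=True)
dy = torch.randn(2, 3, dtype=torch.float64)
sm = torch.softmax(x, dim=-1)

# Vector-Jacobian product computed by PyTorch's autograd for comparison.
(autograd_grad,) = torch.autograd.grad(sm, x, grad_outputs=dy)
assert torch.allclose(autograd_grad, softmax_grad(dy, sm.detach(), axis=-1))
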
import numpy as np
import pytest

import pytensor.tensor as pt
from pytensor.configdefaults import config
from pytensor.graph.fg import FunctionGraph
from pytensor.tensor import elemwise as pt_elemwise
from pytensor.tensor.special import SoftmaxGrad, log_softmax, softmax
from pytensor.tensor.type import matrix, tensor, vector
from tests.link.pytorch.test_basic import compare_pytorch_and_py
...
@@ -53,3 +55,50 @@ def test_pytorch_elemwise():
    fg = FunctionGraph([x], [out])
    compare_pytorch_and_py(fg, [[0.9, 0.9]])

@pytest.mark.parametrize("dtype", ["float64", "int64"])
@pytest.mark.parametrize("axis", [None, 0, 1])
def test_softmax(axis, dtype):
x = matrix("x", dtype=dtype)
out = softmax(x, axis=axis)
fgraph = FunctionGraph([x], [out])
test_input = np.arange(6, dtype=config.floatX).reshape(2, 3)
if dtype == "int64":
with pytest.raises(
NotImplementedError,
match="Pytorch Softmax is not currently implemented for non-float types.",
):
compare_pytorch_and_py(fgraph, [test_input])
else:
compare_pytorch_and_py(fgraph, [test_input])
@pytest.mark.parametrize("dtype", ["float64", "int64"])
@pytest.mark.parametrize("axis", [None, 0, 1])
def test_logsoftmax(axis, dtype):
x = matrix("x", dtype=dtype)
out = log_softmax(x, axis=axis)
fgraph = FunctionGraph([x], [out])
test_input = np.arange(6, dtype=config.floatX).reshape(2, 3)
if dtype == "int64":
with pytest.raises(
NotImplementedError,
match="Pytorch LogSoftmax is not currently implemented for non-float types.",
):
compare_pytorch_and_py(fgraph, [test_input])
else:
compare_pytorch_and_py(fgraph, [test_input])
@pytest.mark.parametrize("axis", [None, 0, 1])
def test_softmax_grad(axis):
dy = matrix("dy")
dy_value = np.array([[1, 1, 1], [0, 0, 0]], dtype=config.floatX)
sm = matrix("sm")
sm_value = np.arange(6, dtype=config.floatX).reshape(2, 3)
out = SoftmaxGrad(axis=axis)(dy, sm)
fgraph = FunctionGraph([dy, sm], [out])
compare_pytorch_and_py(fgraph, [dy_value, sm_value])
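
One non-obvious detail exercised by the axis=None cases above: a softmax with axis=None normalizes over every element of the array, which the dispatch reproduces by flattening, applying torch.softmax along dim=0, and reshaping back. A small illustration of that behavior in plain PyTorch, not part of this diff:

import torch

x = torch.arange(6, dtype=torch.float64).reshape(2, 3)

# axis=None branch: one probability distribution over all six elements.
flat_sm = torch.softmax(x.ravel(), dim=0).reshape(x.shape)
assert torch.isclose(flat_sm.sum(), torch.tensor(1.0, dtype=torch.float64))

# axis=1 for contrast: each row sums to 1 on its own.
row_sm = torch.softmax(x, dim=1)
assert torch.allclose(row_sm.sum(dim=1), torch.ones(2, dtype=torch.float64))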