Unverified commit 781073b6, authored by Harshvir Sandhu, committed by GitHub

Add docs on implementing PyTorch Ops (and CumOp) (#837)

* Add Pytorch support for Cum Op * Modify test for Cum op * Raise TypeError if axis not int or None * Fix init method of CumOp * Extend tutorial on documentation for Pytorch * Add tab for Numba * Add intersphinx mapping * Parametrize dtype
Parent e57e25bf
...@@ -32,8 +32,16 @@ extensions = [ ...@@ -32,8 +32,16 @@ extensions = [
"sphinx.ext.napoleon", "sphinx.ext.napoleon",
"sphinx.ext.linkcode", "sphinx.ext.linkcode",
"sphinx.ext.mathjax", "sphinx.ext.mathjax",
"sphinx_design",
"sphinx.ext.intersphinx"
] ]
intersphinx_mapping = {
"jax": ("https://jax.readthedocs.io/en/latest", None),
"numpy": ("https://numpy.org/doc/stable", None),
"torch": ("https://pytorch.org/docs/stable", None),
}
needs_sphinx = "3" needs_sphinx = "3"
todo_include_todos = True todo_include_todos = True
......
...@@ -13,6 +13,7 @@ dependencies: ...@@ -13,6 +13,7 @@ dependencies:
- mock - mock
- pillow - pillow
- pymc-sphinx-theme - pymc-sphinx-theme
- sphinx-design
- pip - pip
- pip: - pip:
- -e .. - -e ..
...@@ -4,4 +4,5 @@ from pytensor.link.pytorch.dispatch.basic import pytorch_funcify, pytorch_typify ...@@ -4,4 +4,5 @@ from pytensor.link.pytorch.dispatch.basic import pytorch_funcify, pytorch_typify
# # Load dispatch specializations # # Load dispatch specializations
import pytensor.link.pytorch.dispatch.scalar import pytensor.link.pytorch.dispatch.scalar
import pytensor.link.pytorch.dispatch.elemwise import pytensor.link.pytorch.dispatch.elemwise
import pytensor.link.pytorch.dispatch.extra_ops
# isort: on # isort: on
import torch
from pytensor.link.pytorch.dispatch.basic import pytorch_funcify
from pytensor.tensor.extra_ops import CumOp
@pytorch_funcify.register(CumOp)
def pytorch_funcify_Cumop(op, **kwargs):
    """Build a PyTorch callable implementing a PyTensor `CumOp`.

    Reads `op.axis` and `op.mode` ("add" → cumulative sum, otherwise
    cumulative product) and returns a function of a single tensor.
    """
    axis = op.axis
    mode = op.mode

    def cumop(x):
        dim = axis
        # A `None` axis cumulates over the flattened input, matching
        # NumPy's cumsum/cumprod semantics.
        if dim is None:
            x = x.reshape(-1)
            dim = 0
        # `mode` is restricted to "add"/"mul" by `CumOp.__init__`.
        if mode == "add":
            return torch.cumsum(x, dim=dim)
        return torch.cumprod(x, dim=dim)

    return cumop
...@@ -283,6 +283,8 @@ class CumOp(COp): ...@@ -283,6 +283,8 @@ class CumOp(COp):
def __init__(self, axis: int | None = None, mode="add"):
    """Set up a cumulative operation along `axis`.

    Parameters
    ----------
    axis
        Axis along which to cumulate; ``None`` cumulates over the
        flattened input.
    mode
        ``"add"`` for a cumulative sum, ``"mul"`` for a cumulative
        product.

    Raises
    ------
    ValueError
        If `mode` is neither ``"add"`` nor ``"mul"``.
    TypeError
        If `axis` is neither an ``int`` nor ``None``.
    """
    if mode not in ("add", "mul"):
        raise ValueError(f'{type(self).__name__}: Unknown mode "{mode}"')
    # Reject non-int axes (e.g. tuples) up front so backends receive a
    # well-typed `axis`.
    if not (isinstance(axis, int) or axis is None):
        raise TypeError("axis must be an integer or None.")
    self.axis = axis
    self.mode = mode
......
import numpy as np
import pytest
import pytensor.tensor as pt
from pytensor.graph import FunctionGraph
from tests.link.pytorch.test_basic import compare_pytorch_and_py
@pytest.mark.parametrize(
    "dtype",
    ["float64", "int64"],
)
@pytest.mark.parametrize(
    "axis",
    [None, 1, (0,)],
)
def test_pytorch_CumOp(axis, dtype):
    """Test PyTorch conversion of the `CumOp` `Op` for both modes."""
    # Symbolic matrix input for `CumOp`.
    a = pt.matrix("a", dtype=dtype)

    # Concrete test value.
    test_value = np.arange(9, dtype=dtype).reshape((3, 3))

    if isinstance(axis, tuple):
        # Tuple axes are rejected by `CumOp.__init__` at graph-build time.
        with pytest.raises(TypeError, match="axis must be an integer or None."):
            pt.cumsum(a, axis=axis)
        with pytest.raises(TypeError, match="axis must be an integer or None."):
            pt.cumprod(a, axis=axis)
    else:
        # mode="add" (cumulative sum)
        out = pt.cumsum(a, axis=axis)
        fgraph = FunctionGraph([a], [out])
        compare_pytorch_and_py(fgraph, [test_value])

        # mode="mul" (cumulative product)
        out = pt.cumprod(a, axis=axis)
        fgraph = FunctionGraph([a], [out])
        compare_pytorch_and_py(fgraph, [test_value])
Markdown format
0%
You are adding 0 people to this discussion. Please proceed with caution.
Please finish editing this comment first!
Register or sign in to post a comment