Unverified 提交 320bac49 authored 作者: Harshvir Sandhu's avatar Harshvir Sandhu 提交者: GitHub

Add initial support for PyTorch backend (#764)

上级 efa845a3
...@@ -76,6 +76,7 @@ jobs: ...@@ -76,6 +76,7 @@ jobs:
float32: [0, 1] float32: [0, 1]
install-numba: [0] install-numba: [0]
install-jax: [0] install-jax: [0]
install-torch: [0]
part: part:
- "tests --ignore=tests/tensor --ignore=tests/scan --ignore=tests/sparse" - "tests --ignore=tests/tensor --ignore=tests/scan --ignore=tests/sparse"
- "tests/scan" - "tests/scan"
...@@ -116,6 +117,11 @@ jobs: ...@@ -116,6 +117,11 @@ jobs:
fast-compile: 0 fast-compile: 0
float32: 0 float32: 0
part: "tests/link/jax" part: "tests/link/jax"
- install-torch: 1
python-version: "3.10"
fast-compile: 0
float32: 0
part: "tests/link/pytorch"
steps: steps:
- uses: actions/checkout@v4 - uses: actions/checkout@v4
with: with:
...@@ -142,9 +148,12 @@ jobs: ...@@ -142,9 +148,12 @@ jobs:
- name: Install dependencies - name: Install dependencies
shell: micromamba-shell {0} shell: micromamba-shell {0}
run: | run: |
micromamba install --yes -q "python~=${PYTHON_VERSION}=*_cpython" mkl numpy scipy pip mkl-service graphviz cython pytest coverage pytest-cov pytest-benchmark pytest-mock sympy micromamba install --yes -q "python~=${PYTHON_VERSION}=*_cpython" mkl numpy scipy pip mkl-service graphviz cython pytest coverage pytest-cov pytest-benchmark pytest-mock sympy
if [[ $INSTALL_NUMBA == "1" ]]; then micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}=*_cpython" "numba>=0.57"; fi if [[ $INSTALL_NUMBA == "1" ]]; then micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}=*_cpython" "numba>=0.57"; fi
if [[ $INSTALL_JAX == "1" ]]; then micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}=*_cpython" jax jaxlib numpyro && pip install tensorflow-probability; fi if [[ $INSTALL_JAX == "1" ]]; then micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}=*_cpython" jax jaxlib numpyro && pip install tensorflow-probability; fi
if [[ $INSTALL_TORCH == "1" ]]; then micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}=*_cpython" pytorch pytorch-cuda=12.1 -c pytorch -c nvidia; fi
pip install -e ./ pip install -e ./
micromamba list && pip freeze micromamba list && pip freeze
python -c 'import pytensor; print(pytensor.config.__str__(print_doc=False))' python -c 'import pytensor; print(pytensor.config.__str__(print_doc=False))'
...@@ -153,6 +162,7 @@ jobs: ...@@ -153,6 +162,7 @@ jobs:
PYTHON_VERSION: ${{ matrix.python-version }} PYTHON_VERSION: ${{ matrix.python-version }}
INSTALL_NUMBA: ${{ matrix.install-numba }} INSTALL_NUMBA: ${{ matrix.install-numba }}
INSTALL_JAX: ${{ matrix.install-jax }} INSTALL_JAX: ${{ matrix.install-jax }}
INSTALL_TORCH: ${{ matrix.install-torch }}
- name: Run tests - name: Run tests
shell: micromamba-shell {0} shell: micromamba-shell {0}
...@@ -199,7 +209,7 @@ jobs: ...@@ -199,7 +209,7 @@ jobs:
- name: Install dependencies - name: Install dependencies
shell: micromamba-shell {0} shell: micromamba-shell {0}
run: | run: |
micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}=*_cpython" mkl numpy scipy pip mkl-service cython pytest "numba>=0.57" jax jaxlib pytest-benchmark micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}=*_cpython" mkl numpy scipy pip mkl-service cython pytest "numba>=0.57" jax jaxlib pytest-benchmark pytorch pytorch-cuda=12.1 -c pytorch -c nvidia
pip install -e ./ pip install -e ./
micromamba list && pip freeze micromamba list && pip freeze
python -c 'import pytensor; print(pytensor.config.__str__(print_doc=False))' python -c 'import pytensor; print(pytensor.config.__str__(print_doc=False))'
...@@ -268,3 +278,4 @@ jobs: ...@@ -268,3 +278,4 @@ jobs:
directory: ./coverage/ directory: ./coverage/
fail_ci_if_error: true fail_ci_if_error: true
token: ${{ secrets.CODECOV_TOKEN }} token: ${{ secrets.CODECOV_TOKEN }}
...@@ -28,6 +28,7 @@ from pytensor.link.basic import Linker, PerformLinker ...@@ -28,6 +28,7 @@ from pytensor.link.basic import Linker, PerformLinker
from pytensor.link.c.basic import CLinker, OpWiseCLinker from pytensor.link.c.basic import CLinker, OpWiseCLinker
from pytensor.link.jax.linker import JAXLinker from pytensor.link.jax.linker import JAXLinker
from pytensor.link.numba.linker import NumbaLinker from pytensor.link.numba.linker import NumbaLinker
from pytensor.link.pytorch.linker import PytorchLinker
from pytensor.link.vm import VMLinker from pytensor.link.vm import VMLinker
...@@ -47,6 +48,7 @@ predefined_linkers = { ...@@ -47,6 +48,7 @@ predefined_linkers = {
"vm_nogc": VMLinker(allow_gc=False, use_cloop=False), "vm_nogc": VMLinker(allow_gc=False, use_cloop=False),
"cvm_nogc": VMLinker(allow_gc=False, use_cloop=True), "cvm_nogc": VMLinker(allow_gc=False, use_cloop=True),
"jax": JAXLinker(), "jax": JAXLinker(),
"pytorch": PytorchLinker(),
"numba": NumbaLinker(), "numba": NumbaLinker(),
} }
...@@ -460,6 +462,18 @@ JAX = Mode( ...@@ -460,6 +462,18 @@ JAX = Mode(
], ],
), ),
) )
# Predefined compilation mode that executes graphs through the PyTorch
# backend (PytorchLinker).
# NOTE(review): the exclusion list mirrors the JAX mode defined just above —
# presumably cxx_only/BlasOpt/fusion/inplace rewrites target the C backend or
# in-place semantics that don't apply under torch; confirm against the
# rewrite database.
PYTORCH = Mode(
    PytorchLinker(),
    RewriteDatabaseQuery(
        include=["fast_run"],
        exclude=[
            "cxx_only",
            "BlasOpt",
            "fusion",
            "inplace",
        ],
    ),
)
NUMBA = Mode( NUMBA = Mode(
NumbaLinker(), NumbaLinker(),
RewriteDatabaseQuery( RewriteDatabaseQuery(
...@@ -474,6 +488,7 @@ predefined_modes = { ...@@ -474,6 +488,7 @@ predefined_modes = {
"FAST_RUN": FAST_RUN, "FAST_RUN": FAST_RUN,
"JAX": JAX, "JAX": JAX,
"NUMBA": NUMBA, "NUMBA": NUMBA,
"PYTORCH": PYTORCH,
} }
instantiated_default_mode = None instantiated_default_mode = None
......
...@@ -600,6 +600,10 @@ class JITLinker(PerformLinker): ...@@ -600,6 +600,10 @@ class JITLinker(PerformLinker):
def jit_compile(self, fn: Callable) -> Callable: def jit_compile(self, fn: Callable) -> Callable:
"""JIT compile a converted ``FunctionGraph``.""" """JIT compile a converted ``FunctionGraph``."""
def input_filter(self, inp: Any) -> Any:
    """Hook applied to each input value before the JITed function is called.

    The base implementation is the identity; backend-specific linkers may
    override it to convert inputs to their native array types.
    """
    # Default behavior: pass the value through unchanged.
    return inp
def output_filter(self, var: Variable, out: Any) -> Any: def output_filter(self, var: Variable, out: Any) -> Any:
"""Apply a filter to the data output by a JITed function call.""" """Apply a filter to the data output by a JITed function call."""
return out return out
...@@ -657,7 +661,7 @@ class JITLinker(PerformLinker): ...@@ -657,7 +661,7 @@ class JITLinker(PerformLinker):
thunk_inputs=thunk_inputs, thunk_inputs=thunk_inputs,
thunk_outputs=thunk_outputs, thunk_outputs=thunk_outputs,
): ):
outputs = fgraph_jit(*[x[0] for x in thunk_inputs]) outputs = fgraph_jit(*[self.input_filter(x[0]) for x in thunk_inputs])
for o_var, o_storage, o_val in zip(fgraph.outputs, thunk_outputs, outputs): for o_var, o_storage, o_val in zip(fgraph.outputs, thunk_outputs, outputs):
compute_map[o_var][0] = True compute_map[o_var][0] = True
......
# isort: off
from pytensor.link.pytorch.dispatch.basic import pytorch_funcify, pytorch_typify
# Load dispatch specializations
import pytensor.link.pytorch.dispatch.scalar
import pytensor.link.pytorch.dispatch.elemwise
# isort: on
from functools import singledispatch
import torch
from pytensor.compile.ops import DeepCopyOp
from pytensor.graph.fg import FunctionGraph
from pytensor.link.utils import fgraph_to_python
from pytensor.raise_op import CheckAndRaise
@singledispatch
def pytorch_typify(data, dtype=None, **kwargs):
    r"""Convert instances of PyTensor `Type`\s to PyTorch types.

    Fallback implementation: wrap arbitrary array-like data in a torch
    tensor via ``torch.as_tensor`` (zero-copy when possible).
    """
    tensor = torch.as_tensor(data, dtype=dtype)
    return tensor
@singledispatch
def pytorch_funcify(op, node=None, storage_map=None, **kwargs):
    """Create a PyTorch compatible function from a PyTensor `Op`.

    Fallback implementation: no conversion is registered for this `Op`,
    so raise ``NotImplementedError``.
    """
    message = f"No PyTorch conversion for the given `Op`: {op}.\nCheck out `https://github.com/pymc-devs/pytensor/issues/821` for progress or to request we prioritize this operation"
    raise NotImplementedError(message)
@pytorch_funcify.register(FunctionGraph)
def pytorch_funcify_FunctionGraph(
    fgraph,
    node=None,
    fgraph_name="pytorch_funcified_fgraph",
    **kwargs,
):
    """Convert a whole `FunctionGraph` into a PyTorch-compatible Python callable.

    Delegates to `fgraph_to_python`, dispatching each node's `Op` through
    `pytorch_funcify` and converting constants with `pytorch_typify`.
    """
    return fgraph_to_python(
        fgraph,
        pytorch_funcify,
        type_conversion_fn=pytorch_typify,
        fgraph_name=fgraph_name,
        **kwargs,
    )
@pytorch_funcify.register(CheckAndRaise)
def pytorch_funcify_CheckAndRaise(op, **kwargs):
    """Build a function that raises ``op.exc_type(op.msg)`` unless all conditions hold."""
    exc_type = op.exc_type
    exc_msg = op.msg

    def assert_fn(x, *conditions):
        # Each condition is a scalar tensor; `.item()` extracts its truth value.
        # `all` short-circuits, so the raise happens at the first failing
        # condition, exactly as an explicit loop would.
        if not all(cond.item() for cond in conditions):
            raise exc_type(exc_msg)
        return x

    return assert_fn
@pytorch_funcify.register(DeepCopyOp)
def pytorch_funcify_DeepCopyOp(op, **kwargs):
    """Return a function that copies a tensor into fresh storage."""

    def deepcopyop(x):
        # `Tensor.clone` allocates new storage, giving copy semantics.
        return x.clone()

    return deepcopyop
import torch
from pytensor.link.pytorch.dispatch.basic import pytorch_funcify
from pytensor.tensor.elemwise import DimShuffle, Elemwise
@pytorch_funcify.register(Elemwise)
def pytorch_funcify_Elemwise(op, node, **kwargs):
    """Lower an `Elemwise` by dispatching on its underlying scalar `Op`."""
    base_fn = pytorch_funcify(op.scalar_op, node=node, **kwargs)

    def elemwise_fn(*inputs):
        # Validate the inputs' runtime broadcasting pattern against the
        # node before evaluating the scalar function elementwise.
        Elemwise._check_runtime_broadcast(node, inputs)
        return base_fn(*inputs)

    return elemwise_fn
@pytorch_funcify.register(DimShuffle)
def pytorch_funcify_DimShuffle(op, **kwargs):
    """Lower a `DimShuffle` to a permute followed by a reshape."""

    def dimshuffle(x):
        # First reorder the existing dimensions...
        permuted = torch.permute(x, op.transposition)
        # ...then build the target shape, inserting a length-1 axis at
        # every position listed in `op.augment`.
        new_shape = list(permuted.shape[: len(op.shuffle)])
        for axis in op.augment:
            new_shape.insert(axis, 1)
        reshaped = torch.reshape(permuted, new_shape)
        # A non-inplace DimShuffle must not alias the input's storage.
        return reshaped if op.inplace else reshaped.clone()

    return dimshuffle
import torch
from pytensor.link.pytorch.dispatch.basic import pytorch_funcify
from pytensor.scalar.basic import (
ScalarOp,
)
@pytorch_funcify.register(ScalarOp)
def pytorch_funcify_ScalarOp(op, node, **kwargs):
    """Return pytorch function that implements the same computation as the Scalar Op.

    This dispatch is expected to return a pytorch function that works on Array inputs as Elemwise does,
    even though it's dispatched on the Scalar Op.
    """
    nfunc_spec = getattr(op, "nfunc_spec", None)
    if nfunc_spec is None:
        raise NotImplementedError(f"Dispatch not implemented for Scalar Op {op}")

    func_name = nfunc_spec[0]
    # Look up the base torch function eagerly so a missing name fails here,
    # regardless of the arity taken below.
    pytorch_func = getattr(torch, func_name)

    if len(node.inputs) <= op.nfunc_spec[1]:
        # Arity matches the base torch function; use it directly.
        return pytorch_func

    # Some Scalar Ops accept multiple number of inputs, behaving as a variadic function,
    # even though the base Op from `func_name` is specified as a binary Op.
    # This happens with `Add`, which can work as a `Sum` for multiple scalars.
    pytorch_variadic_func = getattr(torch, op.nfunc_variadic, None)
    if not pytorch_variadic_func:
        raise NotImplementedError(
            f"Dispatch not implemented for Scalar Op {op} with {len(node.inputs)} inputs"
        )

    def variadic_fn(*args):
        # Broadcast the arguments, stack them along a new leading axis, and
        # reduce that axis with the variadic torch function (e.g. torch.sum).
        return pytorch_variadic_func(
            torch.stack(torch.broadcast_tensors(*args), axis=0), axis=0
        )

    return variadic_fn
from typing import Any
from pytensor.graph.basic import Variable
from pytensor.link.basic import JITLinker
class PytorchLinker(JITLinker):
    """A `Linker` that compiles NumPy-based operations using torch.compile."""

    def input_filter(self, inp: Any) -> Any:
        """Convert an input value to a torch tensor before the JITed call."""
        from pytensor.link.pytorch.dispatch import pytorch_typify

        return pytorch_typify(inp)

    def output_filter(self, var: Variable, out: Any) -> Any:
        """Move each output to host memory so callers receive CPU tensors."""
        return out.cpu()

    def fgraph_convert(self, fgraph, input_storage, storage_map, **kwargs):
        """Lower the `FunctionGraph` to a PyTorch-compatible callable."""
        from pytensor.link.pytorch.dispatch import pytorch_funcify

        return pytorch_funcify(
            fgraph, input_storage=input_storage, storage_map=storage_map, **kwargs
        )

    def jit_compile(self, fn):
        """JIT-compile the converted function with ``torch.compile``."""
        import torch

        return torch.compile(fn)

    def create_thunk_inputs(self, storage_map):
        """Collect the storage cells for the graph inputs, in input order."""
        return [storage_map[inp] for inp in self.fgraph.inputs]
from collections.abc import Callable, Iterable
from functools import partial
import numpy as np
import pytest
from pytensor.compile.function import function
from pytensor.compile.mode import get_mode
from pytensor.compile.sharedvalue import SharedVariable, shared
from pytensor.configdefaults import config
from pytensor.graph.basic import Apply
from pytensor.graph.fg import FunctionGraph
from pytensor.graph.op import Op
from pytensor.raise_op import CheckAndRaise
from pytensor.tensor.type import scalar, vector
torch = pytest.importorskip("torch")
pytorch_mode = get_mode("PYTORCH")
py_mode = get_mode("FAST_COMPILE")
def compare_pytorch_and_py(
    fgraph: FunctionGraph,
    test_inputs: Iterable,
    assert_fn: Callable | None = None,
    must_be_device_array: bool = True,
    pytorch_mode=pytorch_mode,
    py_mode=py_mode,
):
    """Function to compare python graph output and pytorch compiled output for testing equality

    Parameters
    ----------
    fgraph: FunctionGraph
        PyTensor function Graph object
    test_inputs: iter
        Numerical inputs for testing the function graph
    assert_fn: func, opt
        Assert function used to check for equality between python and pytorch. If not
        provided uses np.testing.assert_allclose
    must_be_device_array: Bool
        Checks if torch.device.type is cuda
    """
    if assert_fn is None:
        assert_fn = np.testing.assert_allclose

    # Shared variables are baked into the compiled function, not passed in.
    non_shared_inputs = [
        inp for inp in fgraph.inputs if not isinstance(inp, SharedVariable)
    ]

    pytensor_torch_fn = function(non_shared_inputs, fgraph.outputs, mode=pytorch_mode)
    pytorch_res = pytensor_torch_fn(*test_inputs)

    if must_be_device_array:
        if isinstance(pytorch_res, list):
            assert all(isinstance(res, torch.Tensor) for res in pytorch_res)
        else:
            assert pytorch_res.device.type == "cuda"

    pytensor_py_fn = function(non_shared_inputs, fgraph.outputs, mode=py_mode)
    py_res = pytensor_py_fn(*test_inputs)

    # Compare the torch results (moved to host) against the Python backend.
    if len(fgraph.outputs) > 1:
        for torch_out, py_out in zip(pytorch_res, py_res):
            assert_fn(torch_out.cpu(), py_out)
    else:
        assert_fn([pytorch_res[0].cpu()], py_res)

    return pytensor_torch_fn, pytorch_res
@pytest.mark.parametrize("device", ["cpu", "cuda"])
def test_pytorch_FunctionGraph_once(device):
    """Make sure that an output is only computed once when it's referenced multiple times."""
    # Fix: the docstring above was previously placed after the skip statement,
    # making it a no-op string expression instead of the function's docstring.
    if device == "cuda" and not torch.cuda.is_available():
        pytest.skip("CUDA is not available")
    from pytensor.link.pytorch.dispatch import pytorch_funcify

    with torch.device(device):
        x = vector("x")
        y = vector("y")

        class TestOp(Op):
            def __init__(self):
                # Counts how many times the converted function is invoked.
                self.called = 0

            def make_node(self, *args):
                return Apply(self, list(args), [x.type() for x in args])

            def perform(self, inputs, outputs):
                for i, inp in enumerate(inputs):
                    outputs[i][0] = inp[0]

        @pytorch_funcify.register(TestOp)
        def pytorch_funcify_TestOp(op, **kwargs):
            def func(*args, op=op):
                op.called += 1
                for arg in args:
                    assert arg.device.type == device
                return list(args)

            return func

        op1 = TestOp()
        op2 = TestOp()

        q, r = op1(x, y)
        # `q + r` is referenced twice; it must still be evaluated only once.
        outs = op2(q + r, q + r)
        out_fg = FunctionGraph([x, y], outs, clone=False)
        assert len(out_fg.outputs) == 2

        out_torch = pytorch_funcify(out_fg)

        x_val = torch.tensor([1, 2]).to(getattr(torch, config.floatX))
        y_val = torch.tensor([2, 3]).to(getattr(torch, config.floatX))
        expected = torch.tensor([3, 5]).to(getattr(torch, config.floatX))

        # Call twice: each call should bump each op's counter exactly once.
        for call_count in (1, 2):
            res = out_torch(x_val, y_val)
            for output in res:
                assert torch.equal(output, expected)
            assert len(res) == 2
            assert op1.called == call_count
            assert op2.called == call_count
@pytest.mark.parametrize("device", ["cpu", "cuda"])
def test_shared(device):
    """Compiled functions must read shared variables and see later updates."""
    if device == "cuda" and not torch.cuda.is_available():
        pytest.skip("CUDA is not available")
    with torch.device(device):
        a = shared(np.array([1, 2, 3], dtype=config.floatX))

        identity_fn = function([], a, mode="PYTORCH")
        result = identity_fn()
        assert isinstance(result, torch.Tensor)
        # The shared variable's stored value stays a host-side NumPy array.
        assert isinstance(a.get_value(), np.ndarray)
        np.testing.assert_allclose(result.cpu(), a.get_value())

        double_fn = function([], a * 2, mode="PYTORCH")
        result = double_fn()
        assert isinstance(result, torch.Tensor)
        assert isinstance(a.get_value(), np.ndarray)
        np.testing.assert_allclose(result.cpu(), a.get_value() * 2)

        # Updating the shared value must be reflected in subsequent calls.
        new_a_value = np.array([3, 4, 5], dtype=config.floatX)
        a.set_value(new_a_value)

        result = double_fn()
        assert isinstance(result, torch.Tensor)
        np.testing.assert_allclose(result.cpu(), new_a_value * 2)
@pytest.mark.parametrize("device", ["cpu", "cuda"])
def test_shared_updates(device):
    """`updates=` must mutate the shared variable on every call."""
    if device == "cuda" and not torch.cuda.is_available():
        pytest.skip("CUDA is not available")
    with torch.device(device):
        a = shared(0)
        counter_fn = function([], a, updates={a: a + 1}, mode="PYTORCH")

        # Each call returns the pre-update value, then increments the state.
        first, second = counter_fn(), counter_fn()
        assert first == 0
        assert second == 1
        assert a.get_value() == 2
        assert isinstance(a.get_value(), np.ndarray)

        # Resetting the value restarts the counter from there.
        a.set_value(5)
        first, second = counter_fn(), counter_fn()
        assert first == 5
        assert second == 6
        assert a.get_value() == 7
        assert isinstance(a.get_value(), np.ndarray)
def test_pytorch_checkandraise():
    """`CheckAndRaise` compiled to PyTorch raises on failing conditions."""
    check_and_raise = CheckAndRaise(AssertionError, "testing")

    x = scalar("x")
    y = check_and_raise(x, x > 0, x > 3)

    y_fn = function([x], y, mode="PYTORCH")

    # A value violating a condition raises; a valid value passes through.
    with pytest.raises(AssertionError, match="testing"):
        y_fn(0.0)
    assert y_fn(4).item() == 4
import numpy as np
import pytensor.tensor as pt
from pytensor.configdefaults import config
from pytensor.graph.fg import FunctionGraph
from pytensor.tensor import elemwise as pt_elemwise
from pytensor.tensor.type import matrix, tensor, vector
from tests.link.pytorch.test_basic import compare_pytorch_and_py
def test_pytorch_Dimshuffle():
    """DimShuffle variants: transpose, axis insertion, and axis dropping."""
    mat_val = np.c_[[1.0, 2.0], [3.0, 4.0]].astype(config.floatX)
    col_val = np.c_[[1.0, 2.0, 3.0, 4.0]].astype(config.floatX)

    a_pt = matrix("a")

    # Plain transpose.
    x_fg = FunctionGraph([a_pt], [a_pt.T])
    compare_pytorch_and_py(x_fg, [mat_val])

    # Insert a broadcastable ("x") axis.
    x_fg = FunctionGraph([a_pt], [a_pt.dimshuffle([0, 1, "x"])])
    compare_pytorch_and_py(x_fg, [mat_val])

    # Drop a broadcastable axis via the dimshuffle method.
    a_pt = tensor(dtype=config.floatX, shape=(None, 1))
    x_fg = FunctionGraph([a_pt], [a_pt.dimshuffle((0,))])
    compare_pytorch_and_py(x_fg, [col_val])

    # Same drop, applying the DimShuffle Op directly.
    a_pt = tensor(dtype=config.floatX, shape=(None, 1))
    x_fg = FunctionGraph([a_pt], [pt_elemwise.DimShuffle([False, True], (0,))(a_pt)])
    compare_pytorch_and_py(x_fg, [col_val])
def test_multiple_input_output():
    """Graphs with several inputs and/or several outputs compile correctly."""
    # Two inputs, one output.
    x = vector("x")
    y = vector("y")
    product = pt.mul(x, y)
    fg = FunctionGraph(outputs=[product], clone=False)
    compare_pytorch_and_py(fg, [[1.5], [2.5]])

    # Two inputs, two outputs.
    x = vector("x")
    y = vector("y")
    quotient = pt.int_div(x, y)
    total = pt.add(y, x)
    fg = FunctionGraph(outputs=[quotient, total], clone=False)
    compare_pytorch_and_py(fg, [[1.5], [2.5]])
def test_pytorch_elemwise():
    """A composed elementwise expression round-trips through the backend."""
    x = pt.vector("x")
    log_expr = pt.log(1 - x)

    fgraph = FunctionGraph([x], [log_expr])
    compare_pytorch_and_py(fgraph, [[0.9, 0.9]])
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论