Unverified 提交 320bac49 作者: Harshvir Sandhu 提交者: GitHub

Add initial support for PyTorch backend (#764)

上级 efa845a3
......@@ -76,6 +76,7 @@ jobs:
float32: [0, 1]
install-numba: [0]
install-jax: [0]
install-torch: [0]
part:
- "tests --ignore=tests/tensor --ignore=tests/scan --ignore=tests/sparse"
- "tests/scan"
......@@ -116,6 +117,11 @@ jobs:
fast-compile: 0
float32: 0
part: "tests/link/jax"
- install-torch: 1
python-version: "3.10"
fast-compile: 0
float32: 0
part: "tests/link/pytorch"
steps:
- uses: actions/checkout@v4
with:
......@@ -142,9 +148,12 @@ jobs:
- name: Install dependencies
shell: micromamba-shell {0}
run: |
micromamba install --yes -q "python~=${PYTHON_VERSION}=*_cpython" mkl numpy scipy pip mkl-service graphviz cython pytest coverage pytest-cov pytest-benchmark pytest-mock sympy
if [[ $INSTALL_NUMBA == "1" ]]; then micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}=*_cpython" "numba>=0.57"; fi
if [[ $INSTALL_JAX == "1" ]]; then micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}=*_cpython" jax jaxlib numpyro && pip install tensorflow-probability; fi
if [[ $INSTALL_TORCH == "1" ]]; then micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}=*_cpython" pytorch pytorch-cuda=12.1 -c pytorch -c nvidia; fi
pip install -e ./
micromamba list && pip freeze
python -c 'import pytensor; print(pytensor.config.__str__(print_doc=False))'
......@@ -153,6 +162,7 @@ jobs:
PYTHON_VERSION: ${{ matrix.python-version }}
INSTALL_NUMBA: ${{ matrix.install-numba }}
INSTALL_JAX: ${{ matrix.install-jax }}
INSTALL_TORCH: ${{ matrix.install-torch}}
- name: Run tests
shell: micromamba-shell {0}
......@@ -199,7 +209,7 @@ jobs:
- name: Install dependencies
shell: micromamba-shell {0}
run: |
micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}=*_cpython" mkl numpy scipy pip mkl-service cython pytest "numba>=0.57" jax jaxlib pytest-benchmark
micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}=*_cpython" mkl numpy scipy pip mkl-service cython pytest "numba>=0.57" jax jaxlib pytest-benchmark pytorch pytorch-cuda=12.1 -c pytorch -c nvidia
pip install -e ./
micromamba list && pip freeze
python -c 'import pytensor; print(pytensor.config.__str__(print_doc=False))'
......@@ -268,3 +278,4 @@ jobs:
directory: ./coverage/
fail_ci_if_error: true
token: ${{ secrets.CODECOV_TOKEN }}
......@@ -28,6 +28,7 @@ from pytensor.link.basic import Linker, PerformLinker
from pytensor.link.c.basic import CLinker, OpWiseCLinker
from pytensor.link.jax.linker import JAXLinker
from pytensor.link.numba.linker import NumbaLinker
from pytensor.link.pytorch.linker import PytorchLinker
from pytensor.link.vm import VMLinker
......@@ -47,6 +48,7 @@ predefined_linkers = {
"vm_nogc": VMLinker(allow_gc=False, use_cloop=False),
"cvm_nogc": VMLinker(allow_gc=False, use_cloop=True),
"jax": JAXLinker(),
"pytorch": PytorchLinker(),
"numba": NumbaLinker(),
}
......@@ -460,6 +462,18 @@ JAX = Mode(
],
),
)
# Compilation mode for the PyTorch linker: start from the `fast_run`
# rewrites and leave out every rewrite carrying one of the excluded tags.
PYTORCH = Mode(
    PytorchLinker(),
    RewriteDatabaseQuery(
        include=["fast_run"],
        exclude=["cxx_only", "BlasOpt", "fusion", "inplace"],
    ),
)
NUMBA = Mode(
NumbaLinker(),
RewriteDatabaseQuery(
......@@ -474,6 +488,7 @@ predefined_modes = {
"FAST_RUN": FAST_RUN,
"JAX": JAX,
"NUMBA": NUMBA,
"PYTORCH": PYTORCH,
}
instantiated_default_mode = None
......
......@@ -600,6 +600,10 @@ class JITLinker(PerformLinker):
def jit_compile(self, fn: Callable) -> Callable:
"""JIT compile a converted ``FunctionGraph``."""
def input_filter(self, inp: Any) -> Any:
    """Return the value handed to the JITed function for the input ``inp``.

    This default implementation is the identity. Linker subclasses can
    override it to convert inputs before the compiled function is called.
    """
    return inp
def output_filter(self, var: Variable, out: Any) -> Any:
    """Return the value stored for the output ``var`` given the raw result ``out``.

    This default implementation returns ``out`` unchanged. Linker subclasses
    can override it to post-process what the JITed function produced.
    """
    return out
......@@ -657,7 +661,7 @@ class JITLinker(PerformLinker):
thunk_inputs=thunk_inputs,
thunk_outputs=thunk_outputs,
):
outputs = fgraph_jit(*[x[0] for x in thunk_inputs])
outputs = fgraph_jit(*[self.input_filter(x[0]) for x in thunk_inputs])
for o_var, o_storage, o_val in zip(fgraph.outputs, thunk_outputs, outputs):
compute_map[o_var][0] = True
......
# isort: off
from pytensor.link.pytorch.dispatch.basic import pytorch_funcify, pytorch_typify
# # Load dispatch specializations
import pytensor.link.pytorch.dispatch.scalar
import pytensor.link.pytorch.dispatch.elemwise
# isort: on
from functools import singledispatch
import torch
from pytensor.compile.ops import DeepCopyOp
from pytensor.graph.fg import FunctionGraph
from pytensor.link.utils import fgraph_to_python
from pytensor.raise_op import CheckAndRaise
@singledispatch
def pytorch_typify(data, dtype=None, **kwargs):
    r"""Convert instances of PyTensor `Type`\s to PyTorch types.

    This is the generic fallback of the single-dispatch function: ``data`` is
    passed to ``torch.as_tensor`` together with the optional ``dtype``.
    Additional keyword arguments are accepted and ignored.
    """
    as_torch = torch.as_tensor(data, dtype=dtype)
    return as_torch
@singledispatch
def pytorch_funcify(op, node=None, storage_map=None, **kwargs):
    """Create a PyTorch compatible function from a PyTensor `Op`.

    This is the generic fallback of the single-dispatch function. It is only
    reached when no specialization is registered for ``type(op)``.

    Raises
    ------
    NotImplementedError
        Always, since there is no generic conversion.
    """
    message = f"No PyTorch conversion for the given `Op`: {op}.\nCheck out `https://github.com/pymc-devs/pytensor/issues/821` for progress or to request we prioritize this operation"
    raise NotImplementedError(message)
@pytorch_funcify.register(FunctionGraph)
def pytorch_funcify_FunctionGraph(
    fgraph,
    node=None,
    fgraph_name="pytorch_funcified_fgraph",
    **kwargs,
):
    """Convert a whole `FunctionGraph` into a Python function.

    Delegates to ``fgraph_to_python``, using ``pytorch_funcify`` for the
    per-`Op` conversion and ``pytorch_typify`` as the type conversion
    function. ``node`` is accepted for signature compatibility and unused.
    """
    converted = fgraph_to_python(
        fgraph,
        pytorch_funcify,
        type_conversion_fn=pytorch_typify,
        fgraph_name=fgraph_name,
        **kwargs,
    )
    return converted
@pytorch_funcify.register(CheckAndRaise)
def pytorch_funcify_CheckAndRaise(op, **kwargs):
    """Build the runtime check for a `CheckAndRaise` op.

    The returned function receives the value ``x`` followed by any number of
    conditions. It raises ``op.exc_type(op.msg)`` as soon as one condition
    evaluates (via ``.item()``) to a falsy value, and returns ``x`` otherwise.
    """
    exc_type, message = op.exc_type, op.msg

    def assert_fn(x, *conditions):
        # `all` stops at the first failing condition, so later conditions
        # are not evaluated once a failure is found.
        if not all(cond.item() for cond in conditions):
            raise exc_type(message)
        return x

    return assert_fn
@pytorch_funcify.register(DeepCopyOp)
def pytorch_funcify_DeepCopyOp(op, **kwargs):
    """Return a function that copies its argument by calling ``x.clone()``."""

    def deepcopyop(x):
        copied = x.clone()
        return copied

    return deepcopyop
import torch
from pytensor.link.pytorch.dispatch.basic import pytorch_funcify
from pytensor.tensor.elemwise import DimShuffle, Elemwise
@pytorch_funcify.register(Elemwise)
def pytorch_funcify_Elemwise(op, node, **kwargs):
    """Build the PyTorch function for an `Elemwise` op.

    The wrapped scalar op is converted with ``pytorch_funcify``. The returned
    function first runs ``Elemwise._check_runtime_broadcast`` on its inputs
    and then applies the converted scalar function to them.
    """
    scalar_fn = pytorch_funcify(op.scalar_op, node=node, **kwargs)

    def elemwise_fn(*inputs):
        # Validate the runtime shapes against the node before computing.
        Elemwise._check_runtime_broadcast(node, inputs)
        return scalar_fn(*inputs)

    return elemwise_fn
@pytorch_funcify.register(DimShuffle)
def pytorch_funcify_DimShuffle(op, **kwargs):
    """Build the PyTorch function for a `DimShuffle` op.

    The returned function permutes the input with ``op.transposition``, keeps
    the first ``len(op.shuffle)`` dimensions of the result, inserts a length-1
    dimension at every position listed in ``op.augment`` and reshapes to that
    shape. Unless ``op.inplace`` is set, the result is cloned.
    """

    def dimshuffle(x):
        permuted = torch.permute(x, op.transposition)

        # Only the leading `len(op.shuffle)` dimensions are kept.
        n_kept = len(op.shuffle)
        target_shape = list(permuted.shape[:n_kept])
        for position in op.augment:
            target_shape.insert(position, 1)

        reshaped = torch.reshape(permuted, target_shape)
        return reshaped if op.inplace else reshaped.clone()

    return dimshuffle
import torch
from pytensor.link.pytorch.dispatch.basic import pytorch_funcify
from pytensor.scalar.basic import (
ScalarOp,
)
@pytorch_funcify.register(ScalarOp)
def pytorch_funcify_ScalarOp(op, node, **kwargs):
    """Return pytorch function that implements the same computation as the Scalar Op.

    This dispatch is expected to return a pytorch function that works on Array inputs as Elemwise does,
    even though it's dispatched on the Scalar Op.

    The torch function is looked up by the name stored in ``op.nfunc_spec[0]``.
    When ``node`` has more inputs than ``op.nfunc_spec[1]``, the function named
    by ``op.nfunc_variadic`` is used instead: the inputs are broadcast against
    each other, stacked along a new leading axis and reduced over that axis.

    Raises
    ------
    NotImplementedError
        If the op has no ``nfunc_spec``, if ``torch`` has no function with the
        required name, or if the node needs a variadic function and the op
        does not name one that ``torch`` provides.
    """
    nfunc_spec = getattr(op, "nfunc_spec", None)
    if nfunc_spec is None:
        raise NotImplementedError(f"Dispatch not implemented for Scalar Op {op}")

    func_name, n_spec_inputs = nfunc_spec[0], nfunc_spec[1]

    if len(node.inputs) > n_spec_inputs:
        # Some Scalar Ops accept multiple number of inputs, behaving as a variadic function,
        # even though the base Op from `func_name` is specified as a binary Op.
        # This happens with `Add`, which can work as a `Sum` for multiple scalars.
        variadic_name = getattr(op, "nfunc_variadic", None)
        # `getattr(torch, None, None)` raises TypeError, so only look the name
        # up when the op actually provides one.
        pytorch_variadic_func = (
            getattr(torch, variadic_name, None)
            if isinstance(variadic_name, str)
            else None
        )
        if not pytorch_variadic_func:
            raise NotImplementedError(
                f"Dispatch not implemented for Scalar Op {op} with {len(node.inputs)} inputs"
            )

        def pytorch_func(*args):
            return pytorch_variadic_func(
                torch.stack(torch.broadcast_tensors(*args), axis=0), axis=0
            )

        return pytorch_func

    # Names that torch does not expose are reported as "not implemented"
    # rather than leaking an AttributeError.
    pytorch_func = getattr(torch, func_name, None)
    if pytorch_func is None:
        raise NotImplementedError(f"Dispatch not implemented for Scalar Op {op}")
    return pytorch_func
from typing import Any
from pytensor.graph.basic import Variable
from pytensor.link.basic import JITLinker
class PytorchLinker(JITLinker):
    """A `Linker` that converts a graph to PyTorch and compiles it with ``torch.compile``."""

    def input_filter(self, inp: Any) -> Any:
        """Convert an input value with ``pytorch_typify``."""
        # Imported lazily, inside the method, as in the other methods here.
        from pytensor.link.pytorch.dispatch import pytorch_typify

        return pytorch_typify(inp)

    def output_filter(self, var: Variable, out: Any) -> Any:
        """Return the output ``out`` moved to the CPU via ``out.cpu()``."""
        return out.cpu()

    def fgraph_convert(self, fgraph, input_storage, storage_map, **kwargs):
        """Convert ``fgraph`` into a PyTorch function using ``pytorch_funcify``."""
        from pytensor.link.pytorch.dispatch import pytorch_funcify

        converted = pytorch_funcify(
            fgraph, input_storage=input_storage, storage_map=storage_map, **kwargs
        )
        return converted

    def jit_compile(self, fn):
        """Wrap ``fn`` with ``torch.compile``."""
        import torch

        return torch.compile(fn)

    def create_thunk_inputs(self, storage_map):
        """Return the storage cells of the graph inputs, in input order."""
        return [storage_map[inp] for inp in self.fgraph.inputs]
from collections.abc import Callable, Iterable
from functools import partial
import numpy as np
import pytest
from pytensor.compile.function import function
from pytensor.compile.mode import get_mode
from pytensor.compile.sharedvalue import SharedVariable, shared
from pytensor.configdefaults import config
from pytensor.graph.basic import Apply
from pytensor.graph.fg import FunctionGraph
from pytensor.graph.op import Op
from pytensor.raise_op import CheckAndRaise
from pytensor.tensor.type import scalar, vector
torch = pytest.importorskip("torch")
pytorch_mode = get_mode("PYTORCH")
py_mode = get_mode("FAST_COMPILE")
def compare_pytorch_and_py(
    fgraph: FunctionGraph,
    test_inputs: Iterable,
    assert_fn: Callable | None = None,
    must_be_device_array: bool = True,
    pytorch_mode=pytorch_mode,
    py_mode=py_mode,
):
    """Compile a graph in PyTorch mode and in Python mode and compare the results.

    Parameters
    ----------
    fgraph: FunctionGraph
        PyTensor function Graph object
    test_inputs: iter
        Numerical inputs for testing the function graph
    assert_fn: func, opt
        Assert function used to check for equality between python and pytorch. If not
        provided uses np.testing.assert_allclose
    must_be_device_array: Bool
        When true, a list result must contain only ``torch.Tensor`` objects,
        and a non-list result must live on a ``cuda`` device.
    pytorch_mode, py_mode
        Compilation modes used for the PyTorch and the reference function.

    Returns
    -------
    tuple
        The PyTorch-compiled function and the result it produced.
    """
    compare = partial(np.testing.assert_allclose) if assert_fn is None else assert_fn

    # Shared variables are not explicit inputs of the compiled functions.
    explicit_inputs = [
        inp for inp in fgraph.inputs if not isinstance(inp, SharedVariable)
    ]

    pytensor_torch_fn = function(explicit_inputs, fgraph.outputs, mode=pytorch_mode)
    pytorch_res = pytensor_torch_fn(*test_inputs)

    if must_be_device_array:
        if isinstance(pytorch_res, list):
            assert all(isinstance(res, torch.Tensor) for res in pytorch_res)
        else:
            assert pytorch_res.device.type == "cuda"

    pytensor_py_fn = function(explicit_inputs, fgraph.outputs, mode=py_mode)
    py_res = pytensor_py_fn(*test_inputs)

    if len(fgraph.outputs) > 1:
        for torch_out, py_out in zip(pytorch_res, py_res):
            compare(torch_out.cpu(), py_out)
    else:
        compare([pytorch_res[0].cpu()], py_res)

    return pytensor_torch_fn, pytorch_res
@pytest.mark.parametrize("device", ["cpu", "cuda"])
def test_pytorch_FunctionGraph_once(device):
    """Make sure that an output is only computed once when it's referenced multiple times."""
    if device == "cuda" and not torch.cuda.is_available():
        pytest.skip("CUDA is not available")

    from pytensor.link.pytorch.dispatch import pytorch_funcify

    with torch.device(device):
        x = vector("x")
        y = vector("y")

        class TestOp(Op):
            """Identity op that counts how often its PyTorch function runs."""

            def __init__(self):
                self.called = 0

            def make_node(self, *args):
                return Apply(self, list(args), [arg.type() for arg in args])

            def perform(self, inputs, outputs):
                for i, inp in enumerate(inputs):
                    outputs[i][0] = inp[0]

        @pytorch_funcify.register(TestOp)
        def pytorch_funcify_TestOp(op, **kwargs):
            def func(*args, op=op):
                op.called += 1
                # Every argument must already live on the requested device.
                for arg in args:
                    assert arg.device.type == device
                return list(args)

            return func

        op1 = TestOp()
        op2 = TestOp()

        q, r = op1(x, y)
        outs = op2(q + r, q + r)

        out_fg = FunctionGraph([x, y], outs, clone=False)
        assert len(out_fg.outputs) == 2

        out_torch = pytorch_funcify(out_fg)

        float_dtype = getattr(torch, config.floatX)
        x_val = torch.tensor([1, 2]).to(float_dtype)
        y_val = torch.tensor([2, 3]).to(float_dtype)
        expected = torch.tensor([3, 5]).to(float_dtype)

        # Each call of the converted graph must run every op exactly once.
        for n_calls in (1, 2):
            res = out_torch(x_val, y_val)

            for output in res:
                assert torch.equal(output, expected)

            assert len(res) == 2
            assert op1.called == n_calls
            assert op2.called == n_calls
@pytest.mark.parametrize("device", ["cpu", "cuda"])
def test_shared(device):
    """Shared variables can be read, used in expressions and updated via ``set_value``."""
    cuda_missing = device == "cuda" and not torch.cuda.is_available()
    if cuda_missing:
        pytest.skip("CUDA is not available")

    with torch.device(device):
        a = shared(np.array([1, 2, 3], dtype=config.floatX))

        # Returning the shared variable itself.
        identity_fn = function([], a, mode="PYTORCH")
        pytorch_res = identity_fn()

        assert isinstance(pytorch_res, torch.Tensor)
        assert isinstance(a.get_value(), np.ndarray)
        np.testing.assert_allclose(pytorch_res.cpu(), a.get_value())

        # Using the shared variable inside an expression.
        doubling_fn = function([], a * 2, mode="PYTORCH")
        pytorch_res = doubling_fn()

        assert isinstance(pytorch_res, torch.Tensor)
        assert isinstance(a.get_value(), np.ndarray)
        np.testing.assert_allclose(pytorch_res.cpu(), a.get_value() * 2)

        # A new value must be picked up by the already compiled function.
        new_a_value = np.array([3, 4, 5], dtype=config.floatX)
        a.set_value(new_a_value)

        pytorch_res = doubling_fn()
        assert isinstance(pytorch_res, torch.Tensor)
        np.testing.assert_allclose(pytorch_res.cpu(), new_a_value * 2)
@pytest.mark.parametrize("device", ["cpu", "cuda"])
def test_shared_updates(device):
    """Updates of a shared variable are applied after every call."""
    cuda_missing = device == "cuda" and not torch.cuda.is_available()
    if cuda_missing:
        pytest.skip("CUDA is not available")

    with torch.device(device):
        a = shared(0)

        counter_fn = function([], a, updates={a: a + 1}, mode="PYTORCH")

        # Each call returns the current value and then increments it.
        first = counter_fn()
        second = counter_fn()
        assert first == 0
        assert second == 1
        assert a.get_value() == 2
        assert isinstance(a.get_value(), np.ndarray)

        a.set_value(5)
        first = counter_fn()
        second = counter_fn()
        assert first == 5
        assert second == 6
        assert a.get_value() == 7
        assert isinstance(a.get_value(), np.ndarray)
def test_pytorch_checkandraise():
    """`CheckAndRaise` raises on a failed condition and passes the value through otherwise."""
    x = scalar("x")
    checked = CheckAndRaise(AssertionError, "testing")(x, x > 0, x > 3)

    y_fn = function([x], checked, mode="PYTORCH")

    # 0.0 violates both conditions.
    with pytest.raises(AssertionError, match="testing"):
        y_fn(0.0)

    # 4 satisfies both conditions and is returned unchanged.
    assert y_fn(4).item() == 4
import numpy as np
import pytensor.tensor as pt
from pytensor.configdefaults import config
from pytensor.graph.fg import FunctionGraph
from pytensor.tensor import elemwise as pt_elemwise
from pytensor.tensor.type import matrix, tensor, vector
from tests.link.pytorch.test_basic import compare_pytorch_and_py
def test_pytorch_Dimshuffle():
    """Compare `DimShuffle` results for transpose, added and dropped dimensions."""

    def square_input():
        # Fresh array per comparison, as in the individual calls it replaces.
        return np.c_[[1.0, 2.0], [3.0, 4.0]].astype(config.floatX)

    def column_input():
        return np.c_[[1.0, 2.0, 3.0, 4.0]].astype(config.floatX)

    # Transpose.
    a_pt = matrix("a")
    transposed_fg = FunctionGraph([a_pt], [a_pt.T])
    compare_pytorch_and_py(transposed_fg, [square_input()])

    # Append a broadcastable dimension.
    expanded_fg = FunctionGraph([a_pt], [a_pt.dimshuffle([0, 1, "x"])])
    compare_pytorch_and_py(expanded_fg, [square_input()])

    # Drop the length-1 dimension through the variable method.
    a_pt = tensor(dtype=config.floatX, shape=(None, 1))
    dropped_fg = FunctionGraph([a_pt], [a_pt.dimshuffle((0,))])
    compare_pytorch_and_py(dropped_fg, [column_input()])

    # Same, but constructing the `DimShuffle` op directly.
    a_pt = tensor(dtype=config.floatX, shape=(None, 1))
    direct_op = pt_elemwise.DimShuffle([False, True], (0,))
    direct_fg = FunctionGraph([a_pt], [direct_op(a_pt)])
    compare_pytorch_and_py(direct_fg, [column_input()])
def test_multiple_input_output():
    """Compare graphs with two inputs and with one or two outputs."""
    # Two inputs, one output.
    x = vector("x")
    y = vector("y")
    product = pt.mul(x, y)

    single_out_fg = FunctionGraph(outputs=[product], clone=False)
    compare_pytorch_and_py(single_out_fg, [[1.5], [2.5]])

    # Two inputs, two outputs.
    x = vector("x")
    y = vector("y")
    quotient = pt.int_div(x, y)
    total = pt.add(y, x)

    double_out_fg = FunctionGraph(outputs=[quotient, total], clone=False)
    compare_pytorch_and_py(double_out_fg, [[1.5], [2.5]])
def test_pytorch_elemwise():
    """Compare a small elementwise expression, ``log(1 - x)``."""
    x = pt.vector("x")
    expression = pt.log(1 - x)

    elemwise_fg = FunctionGraph([x], [expression])
    compare_pytorch_and_py(elemwise_fg, [[0.9, 0.9]])
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
请注册或者登录后发表评论