提交 155db9f3 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Implement basic labeled tensor functionality

上级 f7cf2734
......@@ -82,11 +82,12 @@ jobs:
install-numba: [0]
install-jax: [0]
install-torch: [0]
install-xarray: [0]
part:
- "tests --ignore=tests/tensor --ignore=tests/scan --ignore=tests/sparse"
- "tests --ignore=tests/tensor --ignore=tests/scan --ignore=tests/sparse --ignore=tests/xtensor"
- "tests/scan"
- "tests/sparse"
- "tests/tensor --ignore=tests/tensor/conv --ignore=tests/tensor/rewriting --ignore=tests/tensor/test_math.py --ignore=tests/tensor/test_basic.py --ignore=tests/tensor/test_blas.py --ignore=tests/tensor/test_math_scipy.py --ignore=tests/tensor/test_inplace.py --ignore=tests/tensor/test_elemwise.py"
- "tests/tensor --ignore=tests/tensor/conv --ignore=tests/tensor/rewriting --ignore=tests/tensor/test_math.py --ignore=tests/tensor/test_basic.py --ignore=tests/tensor/test_inplace.py --ignore=tests/tensor/test_blas.py --ignore=tests/tensor/test_elemwise.py --ignore=tests/tensor/test_math_scipy.py"
- "tests/tensor/conv"
- "tests/tensor/rewriting"
- "tests/tensor/test_math.py"
......@@ -115,6 +116,7 @@ jobs:
install-numba: 0
install-jax: 0
install-torch: 0
install-xarray: 0
- install-numba: 1
os: "ubuntu-latest"
python-version: "3.10"
......@@ -150,6 +152,13 @@ jobs:
fast-compile: 0
float32: 0
part: "tests/link/pytorch"
- install-xarray: 1
os: "ubuntu-latest"
python-version: "3.13"
numpy-version: ">=2.0"
fast-compile: 0
float32: 0
part: "tests/xtensor"
- os: macos-15
python-version: "3.13"
numpy-version: ">=2.0"
......@@ -196,6 +205,7 @@ jobs:
if [[ $INSTALL_NUMBA == "1" ]]; then micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}" "numba>=0.57"; fi
if [[ $INSTALL_JAX == "1" ]]; then micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}" jax jaxlib numpyro && pip install tensorflow-probability; fi
if [[ $INSTALL_TORCH == "1" ]]; then micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}" pytorch pytorch-cuda=12.1 "mkl<=2024.0" -c pytorch -c nvidia; fi
if [[ $INSTALL_XARRAY == "1" ]]; then micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}" xarray xarray-einstats; fi
pip install pytest-sphinx
pip install -e ./
......@@ -212,6 +222,7 @@ jobs:
INSTALL_NUMBA: ${{ matrix.install-numba }}
INSTALL_JAX: ${{ matrix.install-jax }}
INSTALL_TORCH: ${{ matrix.install-torch}}
INSTALL_XARRAY: ${{ matrix.install-xarray }}
OS: ${{ matrix.os}}
- name: Run tests
......
......@@ -67,6 +67,8 @@ exclude = []
if not config.cxx:
exclude = ["cxx_only"]
OPT_NONE = RewriteDatabaseQuery(include=[], exclude=exclude)
# Minimum set of rewrites needed to evaluate a function. This is needed for graphs with "dummy" Operations
OPT_MINIMUM = RewriteDatabaseQuery(include=["minimum_compile"], exclude=exclude)
# Even if multiple merge optimizer call will be there, this shouldn't
# impact performance.
OPT_MERGE = RewriteDatabaseQuery(include=["merge"], exclude=exclude)
......@@ -77,6 +79,7 @@ OPT_FAST_COMPILE = RewriteDatabaseQuery(include=["fast_compile"], exclude=exclud
OPT_STABILIZE = RewriteDatabaseQuery(include=["fast_run"], exclude=exclude)
OPT_STABILIZE.position_cutoff = 1.5000001
OPT_NONE.name = "OPT_NONE"
OPT_MINIMUM.name = "OPT_MINIMUM"
OPT_MERGE.name = "OPT_MERGE"
OPT_FAST_RUN.name = "OPT_FAST_RUN"
OPT_FAST_RUN_STABLE.name = "OPT_FAST_RUN_STABLE"
......@@ -95,6 +98,7 @@ predefined_optimizers = {
None: OPT_NONE,
"None": OPT_NONE,
"merge": OPT_MERGE,
"minimum_compile": OPT_MINIMUM,
"o4": OPT_FAST_RUN,
"o3": OPT_O3,
"o2": OPT_O2,
......@@ -191,6 +195,7 @@ optdb.register(
"merge1", MergeOptimizer(), "fast_run", "fast_compile", "merge", position=0
)
# After scan1 opt at 0.5 and before ShapeOpt at 1
# This should only remove nodes.
# The opt should not do anything that need shape inference.
......
import warnings

import pytensor.xtensor.rewriting
from pytensor.xtensor.type import (
    XTensorType,
    as_xtensor,
    xtensor,
    xtensor_constant,
)

# Warn eagerly at import time: the whole xtensor API is new and unstable.
warnings.warn("xtensor module is experimental and full of bugs")
from collections.abc import Sequence
from pytensor.compile.ops import TypeCastingOp
from pytensor.graph import Apply, Op
from pytensor.tensor.type import TensorType
from pytensor.xtensor.type import XTensorType, as_xtensor, xtensor
class XOp(Op):
    """A base class for XOps that shouldn't be materialized"""

    def perform(self, node, inputs, outputs):
        # XOps are symbolic placeholders: they must be rewritten (lowered)
        # into regular tensor operations before a graph can be evaluated,
        # so evaluating one directly is always an error.
        raise NotImplementedError(
            f"xtensor operation {self} must be lowered to equivalent tensor operations"
        )
class XTypeCastOp(TypeCastingOp):
    """Base class for Ops that type cast between TensorType and XTensorType.

    This is like a `ViewOp` but without the expectation the input and output have identical types.
    """
class TensorFromXTensor(XTypeCastOp):
    """Strip the dim labels from an xtensor, yielding a plain tensor variable.

    The output shares the input's dtype and static shape; only the
    ``XTensorType`` wrapper (the dims) is removed.
    """

    __props__ = ()

    def make_node(self, x):
        if not isinstance(x.type, XTensorType):
            # Fixed grammar of the error message ("must be have" -> "must have").
            raise TypeError(f"x must have an XTensorType, got {type(x.type)}")
        output = TensorType(x.type.dtype, shape=x.type.shape)()
        return Apply(self, [x], [output])

    def L_op(self, inputs, outs, g_outs):
        # The gradient re-attaches the dims of the original input.
        [x] = inputs
        [g_out] = g_outs
        return [xtensor_from_tensor(g_out, dims=x.type.dims)]


# Singleton instance: the op has no props, so one instance serves all graphs.
tensor_from_xtensor = TensorFromXTensor()
class XTensorFromTensor(XTypeCastOp):
    """Attach dim labels to a plain tensor, yielding an xtensor variable.

    Parameters
    ----------
    dims
        The dimension names to attach, one per axis of the input tensor.
    """

    __props__ = ("dims",)

    def __init__(self, dims: Sequence[str]):
        super().__init__()
        # Stored as a tuple so the op is hashable via __props__.
        self.dims = tuple(dims)

    def make_node(self, x):
        if not isinstance(x.type, TensorType):
            # Fixed grammar of the error message ("an TensorType type" -> "a TensorType").
            raise TypeError(f"x must be a TensorType, got {type(x.type)}")
        output = xtensor(dtype=x.type.dtype, dims=self.dims, shape=x.type.shape)
        return Apply(self, [x], [output])

    def L_op(self, inputs, outs, g_outs):
        # The gradient simply drops the labels again.
        [g_out] = g_outs
        return [tensor_from_xtensor(g_out)]
def xtensor_from_tensor(x, dims, name=None):
    """Convert a tensor variable `x` into an xtensor variable with the given `dims`."""
    return XTensorFromTensor(dims=dims)(x, name=name)
class Rename(XTypeCastOp):
    """Relabel the dims of an xtensor variable, keeping dtype and shape intact.

    Parameters
    ----------
    new_dims
        The replacement dimension names, one per axis of the input.
    """

    __props__ = ("new_dims",)

    def __init__(self, new_dims: tuple[str, ...]):
        super().__init__()
        self.new_dims = new_dims

    def make_node(self, x):
        x = as_xtensor(x)
        output = x.type.clone(dims=self.new_dims)()
        return Apply(self, [x], [output])

    def L_op(self, inputs, outs, g_outs):
        [x] = inputs
        [g_out] = g_outs
        # BUG FIX: the previous code called ``rename(g_out, dims=x.type.dims)``,
        # but ``rename`` has no ``dims`` parameter — the keyword was swallowed
        # by ``**names`` and always raised ("Cannot rename dims to ...").
        # Apply the op directly to map the gradient back to the input's dims.
        return [Rename(tuple(x.type.dims))(g_out)]
def rename(x, name_dict: dict[str, str] | None = None, **names: str):
    """Rename dims of an xtensor variable, xarray-style.

    Accepts either a mapping ``{old_dim: new_dim}`` or keyword arguments
    ``old_dim=new_dim``, but not both.

    Raises
    ------
    ValueError
        If both a mapping and keyword names are given, or if an old dim
        is not among ``x``'s dims.
    """
    if name_dict is not None:
        if names:
            raise ValueError("Cannot use both positional and keyword names in rename")
        names = name_dict

    x = as_xtensor(x)
    old_names = x.type.dims
    new_names = list(old_names)
    for old_name, new_name in names.items():
        try:
            new_names[old_names.index(old_name)] = new_name
        except ValueError:
            # The ValueError chained from tuple.index adds no information;
            # suppress it so only the domain-level message is shown.
            raise ValueError(
                f"Cannot rename {old_name} to {new_name}: {old_name} not in {old_names}"
            ) from None

    return Rename(tuple(new_names))(x)
# XTensor Module
This module implements an abstraction layer on top of regular tensor operations that behaves like Xarray.
A new type `XTensorType`, generalizes the `TensorType` with the addition of a `dims` attribute,
that labels the dimensions of the tensor.
Variables of `XTensorType` (i.e., `XTensorVariable`s) are the symbolic counterpart to xarray DataArray objects.
The module implements several PyTensor operations (`XOp`s) whose signatures mimic those of xarray (and xarray-einstats) DataArray operations.
These operations, unlike most regular PyTensor operations, cannot be directly evaluated, but require a rewrite (lowering) into
a regular tensor graph that can itself be evaluated as usual.
Like regular PyTensor, we don't need an Op for every possible method or function in the public API of xarray.
If the existing XOps can be composed to produce the desired result, then we can use them directly.
## Coordinates
For now, there's no analogue of xarray coordinates, so you won't be able to do coordinate operations like `.sel`.
The graphs produced by an xarray program without coords are much more amenable to the numpy-like backend of PyTensor.
Coords involve aspects of Pandas/database query and joining that are not trivially expressible in PyTensor.
## Example
```python
import pytensor.tensor as pt
import pytensor.xtensor as px
a = pt.tensor("a", shape=(3,))
b = pt.tensor("b", shape=(4,))
ax = px.as_xtensor(a, dims=["x"])
bx = px.as_xtensor(b, dims=["y"])
zx = ax + bx
assert zx.type == px.type.XTensorType("float64", dims=["x", "y"], shape=(3, 4))
z = zx.values
z.dprint()
# TensorFromXTensor [id A]
# └─ XElemwise{scalar_op=Add()} [id B]
# ├─ XTensorFromTensor{dims=('x',)} [id C]
# │ └─ a [id D]
# └─ XTensorFromTensor{dims=('y',)} [id E]
# └─ b [id F]
```
Once we compile the graph, no `XOp`s are left.
```python
import pytensor
with pytensor.config.change_flags(optimizer_verbose=True):
fn = pytensor.function([a, b], z)
# rewriting: rewrite lower_elemwise replaces XElemwise{scalar_op=Add()}.0 of XElemwise{scalar_op=Add()}(XTensorFromTensor{dims=('x',)}.0, XTensorFromTensor{dims=('y',)}.0) with XTensorFromTensor{dims=('x', 'y')}.0 of XTensorFromTensor{dims=('x', 'y')}(Add.0)
# rewriting: rewrite useless_tensor_from_xtensor replaces TensorFromXTensor.0 of TensorFromXTensor(XTensorFromTensor{dims=('x',)}.0) with a of None
# rewriting: rewrite useless_tensor_from_xtensor replaces TensorFromXTensor.0 of TensorFromXTensor(XTensorFromTensor{dims=('y',)}.0) with b of None
# rewriting: rewrite useless_tensor_from_xtensor replaces TensorFromXTensor.0 of TensorFromXTensor(XTensorFromTensor{dims=('x', 'y')}.0) with Add.0 of Add(ExpandDims{axis=1}.0, ExpandDims{axis=0}.0)
fn.dprint()
# Add [id A] 2
# ├─ ExpandDims{axis=1} [id B] 1
# │ └─ a [id C]
# └─ ExpandDims{axis=0} [id D] 0
# └─ b [id E]
```
import pytensor.xtensor.rewriting.basic
from pytensor.graph import node_rewriter
from pytensor.tensor.basic import register_infer_shape
from pytensor.tensor.rewriting.basic import register_canonicalize, register_useless
from pytensor.xtensor.basic import (
Rename,
TensorFromXTensor,
XTensorFromTensor,
xtensor_from_tensor,
)
from pytensor.xtensor.rewriting.utils import register_lower_xtensor
@register_infer_shape
@register_useless
@register_canonicalize
@register_lower_xtensor
@node_rewriter(tracks=[TensorFromXTensor])
def useless_tensor_from_xtensor(fgraph, node):
    """TensorFromXTensor(XTensorFromTensor(x)) -> x"""
    [inner] = node.inputs
    if inner.owner is None:
        return None
    if not isinstance(inner.owner.op, XTensorFromTensor):
        return None
    # The round trip through XTensorType is a no-op on the underlying tensor.
    return [inner.owner.inputs[0]]
@register_infer_shape
@register_useless
@register_canonicalize
@register_lower_xtensor
@node_rewriter(tracks=[XTensorFromTensor])
def useless_xtensor_from_tensor(fgraph, node):
    """XTensorFromTensor(TensorFromXTensor(x)) -> x"""
    [inner] = node.inputs
    if inner.owner is not None and isinstance(inner.owner.op, TensorFromXTensor):
        # Re-labelling the tensor we just stripped gives back the original xtensor.
        return [inner.owner.inputs[0]]
    return None
@register_lower_xtensor
@node_rewriter(tracks=[TensorFromXTensor])
def useless_tensor_from_xtensor_of_rename(fgraph, node):
    """TensorFromXTensor(Rename(x)) -> TensorFromXTensor(x)"""
    [renamed] = node.inputs
    if renamed.owner is None or not isinstance(renamed.owner.op, Rename):
        return None
    # The dim labels are dropped anyway, so the Rename can be bypassed.
    [unrenamed] = renamed.owner.inputs
    return node.op(unrenamed, return_list=True)
@register_lower_xtensor
@node_rewriter(tracks=[Rename])
def useless_rename(fgraph, node):
    """Fuse a Rename with the op that produced its input.

    Rename(Rename(x, inner_dims), outer_dims) -> Rename(x, outer_dims)
    Rename(XTensorFromTensor(x, inner_dims), outer_dims) -> XTensorFromTensor(x, outer_dims)
    """
    [renamed_x] = node.inputs
    if renamed_x.owner:
        if isinstance(renamed_x.owner.op, Rename):
            # Only the outermost rename matters; drop the inner one.
            [x] = renamed_x.owner.inputs
            return [node.op(x)]
        elif isinstance(renamed_x.owner.op, XTensorFromTensor):
            # BUG FIX: this branch previously checked for TensorFromXTensor,
            # whose output is a plain tensor and therefore can never be the
            # input of a Rename (which takes an xtensor) — the branch was
            # unreachable. The docstring's intent is XTensorFromTensor:
            # label the raw tensor directly with the outer dims.
            [x] = renamed_x.owner.inputs
            return [xtensor_from_tensor(x, dims=node.op.new_dims)]
from pytensor.compile import optdb
from pytensor.graph.rewriting.basic import NodeRewriter
from pytensor.graph.rewriting.db import EquilibriumDB, RewriteDatabase
# Database of rewrites that lower XOps into regular tensor operations.
# An EquilibriumDB applies its rewrites until the graph stops changing.
lower_xtensor_db = EquilibriumDB(ignore_newtrees=False)

# Register very early (position 0.1) so lowering runs before most tensor
# rewrites. Tagged for every mode, including "minimum_compile", because
# XOps cannot be evaluated at all without being lowered first.
optdb.register(
    "lower_xtensor",
    lower_xtensor_db,
    "fast_run",
    "fast_compile",
    "minimum_compile",
    position=0.1,
)
def register_lower_xtensor(
    node_rewriter: RewriteDatabase | NodeRewriter | str, *tags: str, **kwargs
):
    """Register a rewrite in the ``lower_xtensor`` database.

    Usable directly as a decorator, or as a decorator factory when the
    first argument is a string (which becomes an extra tag).
    """
    if not isinstance(node_rewriter, str):
        rewrite_name = kwargs.pop("name", None) or node_rewriter.__name__  # type: ignore
        lower_xtensor_db.register(
            rewrite_name,
            node_rewriter,
            "fast_run",
            "fast_compile",
            "minimum_compile",
            *tags,
            **kwargs,
        )
        return node_rewriter

    # Called with a string: treat it as an additional tag and defer the
    # actual registration to the returned decorator.
    def decorator(inner_rewriter: RewriteDatabase | NodeRewriter):
        return register_lower_xtensor(inner_rewriter, node_rewriter, *tags, **kwargs)

    return decorator
差异被折叠。
# ruff: noqa: E402
import pytest
pytest.importorskip("xarray")
import numpy as np
from xarray import DataArray
from pytensor.graph.basic import equal_computations
from pytensor.tensor import as_tensor, specify_shape, tensor
from pytensor.xtensor import xtensor
from pytensor.xtensor.type import XTensorType, as_xtensor
def test_xtensortype():
    """Equality and is_super relations of XTensorType across dtype/dims/shape."""
    base = XTensorType(dtype="float64", dims=("a", "b"), shape=(2, 3))
    identical = XTensorType(dtype="float64", dims=("a", "b"), shape=(2, 3))
    unknown_row = XTensorType(dtype="float64", dims=("a", "b"), shape=(None, 3))
    other_dims = XTensorType(dtype="float64", dims=("c", "d"), shape=(4, 5))
    other_dtype = XTensorType(dtype="float32", dims=("a", "b"), shape=(2, 3))

    # Identical types are equal and mutually super.
    assert base == identical
    assert base.is_super(identical)
    assert identical.is_super(base)

    # A type with an unknown dim is a strict supertype of the static one.
    assert base != unknown_row
    assert not base.is_super(unknown_row)
    assert unknown_row.is_super(base)

    # Different dims: unrelated in both directions.
    assert base != other_dims
    assert not base.is_super(other_dims)
    assert not other_dims.is_super(base)

    # Different dtype: unrelated in both directions.
    assert base != other_dtype
    assert not base.is_super(other_dtype)
    assert not other_dtype.is_super(base)
def test_xtensortype_filter_variable():
    """filter_variable accepts/adapts compatible variables and rejects the rest."""
    x = xtensor("x", dims=("a", "b"), shape=(2, 3))

    # Exact type match: returned unchanged.
    y1 = xtensor("y1", dims=("a", "b"), shape=(2, 3))
    assert x.type.filter_variable(y1) is y1

    # Same dims in a different order: a transpose is inserted.
    y2 = xtensor("y2", dims=("b", "a"), shape=(3, 2))
    expected_y2 = as_xtensor(y2.values.transpose(), dims=("a", "b"))
    assert equal_computations([x.type.filter_variable(y2)], [expected_y2])

    # Unknown static dim: transpose plus a specify_shape.
    y3 = xtensor("y3", dims=("b", "a"), shape=(3, None))
    expected_y3 = as_xtensor(
        specify_shape(
            as_xtensor(y3.values.transpose(), dims=("a", "b")).values, (2, 3)
        ),
        dims=("a", "b"),
    )
    assert equal_computations([x.type.filter_variable(y3)], [expected_y3])

    # Cases that fail
    with pytest.raises(TypeError):
        # Incompatible static shape.
        y4 = xtensor("y4", dims=("a", "b"), shape=(3, 2))
        x.type.filter_variable(y4)

    with pytest.raises(TypeError):
        # Unknown dim name "c".
        y5 = xtensor("y5", dims=("a", "c"), shape=(2, 3))
        x.type.filter_variable(y5)

    with pytest.raises(TypeError):
        # Extra dimension.
        y6 = xtensor("y6", dims=("a", "b", "c"), shape=(2, 3, 4))
        x.type.filter_variable(y6)

    with pytest.raises(TypeError):
        # Mismatched dtype.
        y7 = xtensor("y7", dims=("a", "b"), shape=(2, 3), dtype="int32")
        x.type.filter_variable(y7)

    # A plain tensor is converted, adding specify_shape for unknown dims.
    z1 = tensor("z1", shape=(2, None))
    expected_z1 = as_xtensor(specify_shape(z1, (2, 3)), dims=("a", "b"))
    assert equal_computations([x.type.filter_variable(z1)], [expected_z1])

    # Cases that fail
    with pytest.raises(TypeError):
        z2 = tensor("z2", shape=(3, 2))
        x.type.filter_variable(z2)

    with pytest.raises(TypeError):
        z3 = tensor("z3", shape=(1, 2, 3))
        x.type.filter_variable(z3)

    with pytest.raises(TypeError):
        z4 = tensor("z4", shape=(2, 3), dtype="int32")
        x.type.filter_variable(z4)
def test_xtensor_constant():
    """Constants built from a DataArray and from ndarray+dims are equivalent."""
    x = as_xtensor(DataArray(np.ones((2, 3)), dims=("a", "b")))
    assert x.type == XTensorType(dtype="float64", dims=("a", "b"), shape=(2, 3))

    y = as_xtensor(np.ones((2, 3)), dims=("a", "b"))
    assert y.type == x.type
    assert x.signature() == y.signature()
    assert x.equals(y)
    x_eval = x.eval()
    assert isinstance(x.eval(), np.ndarray)
    np.testing.assert_array_equal(x_eval, y.eval(), strict=True)

    # Same data with swapped dims is a different constant (values transposed).
    z = as_xtensor(np.ones((3, 2)), dims=("b", "a"))
    assert z.type != x.type
    assert z.signature() != x.signature()
    assert not x.equals(z)
    np.testing.assert_array_equal(x_eval, z.eval().T, strict=True)
def test_as_tensor():
    """as_tensor refuses XTensorVariables unless conversion is explicitly allowed."""
    x = xtensor("x", dims=("a", "b"), shape=(2, 3))

    with pytest.raises(
        TypeError,
        match="PyTensor forbids automatic conversion of XTensorVariable to TensorVariable",
    ):
        as_tensor(x)

    # Opting in converts via .values (i.e., TensorFromXTensor).
    x_pt = as_tensor(x, allow_xtensor_conversion=True)
    assert equal_computations([x_pt], [x.values])
def test_minimum_compile():
    """The "minimum_compile" optimizer lowers XOps enough to evaluate a graph."""
    from pytensor.compile.mode import Mode

    x = xtensor("x", dims=("a", "b"), shape=(2, 3))
    y = as_xtensor(x.values.transpose(), dims=("b", "a"))
    # Without the lowering rewrites this eval would hit XOp.perform and fail.
    minimum_mode = Mode(linker="py", optimizer="minimum_compile")
    result = y.eval({"x": np.ones((2, 3))}, mode=minimum_mode)
    np.testing.assert_array_equal(result, np.ones((3, 2)))
# ruff: noqa: E402
import pytest
pytest.importorskip("xarray")
import numpy as np
from xarray import DataArray
from xarray.testing import assert_allclose
from pytensor import function
from pytensor.xtensor.type import XTensorType
def xr_function(*args, **kwargs):
    """Compile and wrap a PyTensor function to return xarray DataArrays."""
    compiled = function(*args, **kwargs)
    sym_outs = compiled.maker.fgraph.outputs
    assert all(
        isinstance(out.type, XTensorType) for out in sym_outs
    ), "All outputs must be xtensor"

    def xfn(*xr_inputs):
        # Strip DataArray wrappers down to raw arrays before calling.
        raw_inputs = [
            arg.values if isinstance(arg, DataArray) else arg for arg in xr_inputs
        ]
        raw_outputs = compiled(*raw_inputs)
        if isinstance(raw_outputs, tuple | list):
            # Re-wrap every output with the dims recorded on the graph.
            return tuple(
                DataArray(res, dims=out.type.dims)
                for res, out in zip(raw_outputs, sym_outs)
            )
        return DataArray(raw_outputs, dims=sym_outs[0].type.dims)

    # Expose the underlying compiled function for introspection in tests.
    xfn.fn = compiled
    return xfn
def xr_assert_allclose(x, y, *args, **kwargs):
    """Assert two DataArrays are numerically close, ignoring coordinates."""
    bare_x = x.drop_vars(x.coords)
    bare_y = y.drop_vars(y.coords)
    assert_allclose(bare_x, bare_y, *args, **kwargs)
def xr_arange_like(x):
    """Build a DataArray matching ``x``'s shape/dtype/dims, filled with 0..n-1."""
    n_elements = np.prod(x.type.shape)
    flat = np.arange(n_elements, dtype=x.type.dtype)
    return DataArray(flat.reshape(x.type.shape), dims=x.type.dims)
def xr_random_like(x, rng=None):
    """Build a DataArray like ``x`` filled with standard-normal draws."""
    generator = np.random.default_rng() if rng is None else rng
    draws = generator.standard_normal(size=x.type.shape, dtype=x.type.dtype)
    return DataArray(draws, dims=x.type.dims)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论