Unverified 提交 b67ff220 authored 作者: Adrian Seyboldt's avatar Adrian Seyboldt 提交者: GitHub

Implement `wrap_jax` and rename `as_op` to `wrap_py` (#1614)

上级 7779b07b
......@@ -208,7 +208,7 @@ jobs:
micromamba install --yes -q "python~=${PYTHON_VERSION}" mkl "numpy${NUMPY_VERSION}" scipy pip mkl-service graphviz cython pytest coverage pytest-cov pytest-benchmark pytest-mock pytest-sphinx;
fi
if [[ $INSTALL_NUMBA == "1" ]]; then micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}" "numba>=0.57"; fi
if [[ $INSTALL_JAX == "1" ]]; then micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}" jax jaxlib numpyro && pip install tfp-nightly; fi
if [[ $INSTALL_JAX == "1" ]]; then micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}" jax jaxlib numpyro equinox && pip install tfp-nightly; fi
if [[ $INSTALL_TORCH == "1" ]]; then micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}" pytorch pytorch-cuda=12.1 "mkl<=2024.0" -c pytorch -c nvidia; fi
if [[ $INSTALL_XARRAY == "1" ]]; then micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}" xarray xarray-einstats; fi
......
......@@ -25,4 +25,4 @@ dependencies:
- ablog
- pip
- pip:
- -e ..
- -e ..[jax]
......@@ -803,10 +803,10 @@ You can omit the :meth:`Rop` functions. Try to implement the testing apparatus d
:download:`Solution<extending_pytensor_solution_1.py>`
:func:`as_op`
:func:`wrap_py`
-------------
:func:`as_op` is a Python decorator that converts a Python function into a
:func:`wrap_py` is a Python decorator that converts a Python function into a
basic PyTensor :class:`Op` that will call the supplied function during execution.
This isn't the recommended way to build an :class:`Op`, but allows for a quick implementation.
......@@ -839,11 +839,11 @@ It takes an optional :meth:`infer_shape` parameter that must have this signature
inputs PyTensor variables that were declared.
.. note::
The python function wrapped by the :func:`as_op` decorator needs to return a new
The python function wrapped by the :func:`wrap_py` decorator needs to return a new
data allocation, no views or in place modification of the input.
:func:`as_op` Example
:func:`wrap_py` Example
^^^^^^^^^^^^^^^^^^^^^
.. testcode:: asop
......@@ -852,14 +852,14 @@ It takes an optional :meth:`infer_shape` parameter that must have this signature
import pytensor.tensor as pt
import numpy as np
from pytensor import function
from pytensor.compile.ops import as_op
from pytensor.compile.ops import wrap_py
def infer_shape_numpy_dot(fgraph, node, input_shapes):
ashp, bshp = input_shapes
return [ashp[:-1] + bshp[-1:]]
@as_op(
@wrap_py(
itypes=[pt.dmatrix, pt.dmatrix],
otypes=[pt.dmatrix],
infer_shape=infer_shape_numpy_dot,
......
......@@ -167,9 +167,9 @@ class TestSumDiffOp(utt.InferShapeTester):
import numpy as np
# as_op exercice
# wrap_py exercice
import pytensor
from pytensor.compile.ops import as_op
from pytensor.compile.ops import wrap_py
def infer_shape_numpy_dot(fgraph, node, input_shapes):
......@@ -177,7 +177,7 @@ def infer_shape_numpy_dot(fgraph, node, input_shapes):
return [ashp[:-1] + bshp[-1:]]
@as_op(
@wrap_py(
itypes=[pt.fmatrix, pt.fmatrix],
otypes=[pt.fmatrix],
infer_shape=infer_shape_numpy_dot,
......@@ -192,7 +192,7 @@ def infer_shape_numpy_add_sub(fgraph, node, input_shapes):
return [ashp[0]]
@as_op(
@wrap_py(
itypes=[pt.fmatrix, pt.fmatrix],
otypes=[pt.fmatrix],
infer_shape=infer_shape_numpy_add_sub,
......@@ -201,7 +201,7 @@ def numpy_add(a, b):
return np.add(a, b)
@as_op(
@wrap_py(
itypes=[pt.fmatrix, pt.fmatrix],
otypes=[pt.fmatrix],
infer_shape=infer_shape_numpy_add_sub,
......
......@@ -61,10 +61,16 @@ Convert to Variable
.. autofunction:: pytensor.as_symbolic(...)
Wrap JAX functions
==================
.. autofunction:: wrap_jax(...)
Alias for :func:`pytensor.link.jax.ops.wrap_jax`
Debug
=====
.. autofunction:: pytensor.dprint(...)
Alias for :func:`pytensor.printing.debugprint`
......@@ -166,7 +166,7 @@ from pytensor.scan import checkpoints
from pytensor.scan.basic import scan
from pytensor.scan.views import foldl, foldr, map, reduce
from pytensor.compile.builders import OpFromGraph
from pytensor.link.jax.ops import wrap_jax
# isort: on
......
......@@ -56,6 +56,7 @@ from pytensor.compile.ops import (
register_deep_copy_op_c_code,
register_view_op_c_code,
view_op,
wrap_py,
)
from pytensor.compile.profiling import ProfileStats
from pytensor.compile.sharedvalue import SharedVariable, shared, shared_constructor
"""
This file contains auxiliary Ops, used during the compilation phase and Ops
building class (:class:`FromFunctionOp`) and decorator (:func:`as_op`) that
building class (:class:`FromFunctionOp`) and decorator (:func:`wrap_py`) that
help make new Ops more rapidly.
"""
......@@ -268,12 +268,12 @@ class FromFunctionOp(Op):
obj = load_back(mod, name)
except (ImportError, KeyError, AttributeError):
raise pickle.PicklingError(
f"Can't pickle as_op(), not found as {mod}.{name}"
f"Can't pickle wrap_py(), not found as {mod}.{name}"
)
else:
if obj is not self:
raise pickle.PicklingError(
f"Can't pickle as_op(), not the object at {mod}.{name}"
f"Can't pickle wrap_py(), not the object at {mod}.{name}"
)
return load_back, (mod, name)
......@@ -282,6 +282,18 @@ class FromFunctionOp(Op):
def as_op(itypes, otypes, infer_shape=None):
import warnings
warnings.warn(
"pytensor.as_op is deprecated and will be removed in a future release. "
"Please use pytensor.wrap_py instead.",
FutureWarning,
stacklevel=2,
)
return wrap_py(itypes, otypes, infer_shape)
def wrap_py(itypes, otypes, infer_shape=None):
"""
Decorator that converts a function into a basic PyTensor op that will call
the supplied function as its implementation.
......@@ -301,7 +313,7 @@ def as_op(itypes, otypes, infer_shape=None):
Examples
--------
@as_op(itypes=[pytensor.tensor.fmatrix, pytensor.tensor.fmatrix],
@wrap_py(itypes=[pytensor.tensor.fmatrix, pytensor.tensor.fmatrix],
otypes=[pytensor.tensor.fmatrix])
def numpy_dot(a, b):
return numpy.dot(a, b)
......
......@@ -13,6 +13,7 @@ from pytensor.configdefaults import config
from pytensor.graph import Constant
from pytensor.graph.fg import FunctionGraph
from pytensor.ifelse import IfElse
from pytensor.link.jax.ops import JAXOp
from pytensor.link.utils import fgraph_to_python
from pytensor.raise_op import CheckAndRaise
......@@ -142,3 +143,8 @@ def jax_funcify_OpFromGraph(ofg: OpFromGraph, node=None, **kwargs) -> Callable:
return fgraph_fn(*inputs)
return opfromgraph
@jax_funcify.register(JAXOp)
def jax_op_funcify(op, **kwargs):
return op.perform_jax
差异被折叠。
import pickle
import numpy as np
import pytest
from pytensor import function
from pytensor.compile.ops import as_op
from pytensor.compile.ops import as_op, wrap_py
from pytensor.tensor.type import dmatrix, dvector
from tests import unittest_tools as utt
@as_op([dmatrix, dmatrix], dmatrix)
@wrap_py([dmatrix, dmatrix], dmatrix)
def mul(a, b):
"""
This is for test_pickle, since the function still has to be
......@@ -21,6 +22,21 @@ class TestOpDecorator(utt.InferShapeTester):
def test_1arg(self):
x = dmatrix("x")
@wrap_py(dmatrix, dvector)
def cumprod(x):
return np.cumprod(x)
fn = function([x], cumprod(x))
r = fn([[1.5, 5], [2, 2]])
r0 = np.array([1.5, 7.5, 15.0, 30.0])
assert np.allclose(r, r0), (r, r0)
def test_deprecation(self):
x = dmatrix("x")
with pytest.warns(FutureWarning):
@as_op(dmatrix, dvector)
def cumprod(x):
return np.cumprod(x)
......@@ -37,7 +53,7 @@ class TestOpDecorator(utt.InferShapeTester):
y = dvector("y")
y.tag.test_value = [0, 0, 0, 0]
@as_op([dmatrix, dvector], dvector)
@wrap_py([dmatrix, dvector], dvector)
def cumprod_plus(x, y):
return np.cumprod(x) + y
......@@ -57,7 +73,7 @@ class TestOpDecorator(utt.InferShapeTester):
x, y = shapes
return [y]
@as_op([dmatrix, dvector], dvector, infer_shape)
@wrap_py([dmatrix, dvector], dvector, infer_shape)
def cumprod_plus(x, y):
return np.cumprod(x) + y
......
差异被折叠。
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论