Unverified 提交 b67ff220 authored 作者: Adrian Seyboldt's avatar Adrian Seyboldt 提交者: GitHub

Implement `wrap_jax` and rename `as_op` to `wrap_py` (#1614)

上级 7779b07b
...@@ -208,7 +208,7 @@ jobs: ...@@ -208,7 +208,7 @@ jobs:
micromamba install --yes -q "python~=${PYTHON_VERSION}" mkl "numpy${NUMPY_VERSION}" scipy pip mkl-service graphviz cython pytest coverage pytest-cov pytest-benchmark pytest-mock pytest-sphinx; micromamba install --yes -q "python~=${PYTHON_VERSION}" mkl "numpy${NUMPY_VERSION}" scipy pip mkl-service graphviz cython pytest coverage pytest-cov pytest-benchmark pytest-mock pytest-sphinx;
fi fi
if [[ $INSTALL_NUMBA == "1" ]]; then micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}" "numba>=0.57"; fi if [[ $INSTALL_NUMBA == "1" ]]; then micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}" "numba>=0.57"; fi
if [[ $INSTALL_JAX == "1" ]]; then micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}" jax jaxlib numpyro && pip install tfp-nightly; fi if [[ $INSTALL_JAX == "1" ]]; then micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}" jax jaxlib numpyro equinox && pip install tfp-nightly; fi
if [[ $INSTALL_TORCH == "1" ]]; then micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}" pytorch pytorch-cuda=12.1 "mkl<=2024.0" -c pytorch -c nvidia; fi if [[ $INSTALL_TORCH == "1" ]]; then micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}" pytorch pytorch-cuda=12.1 "mkl<=2024.0" -c pytorch -c nvidia; fi
if [[ $INSTALL_XARRAY == "1" ]]; then micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}" xarray xarray-einstats; fi if [[ $INSTALL_XARRAY == "1" ]]; then micromamba install --yes -q -c conda-forge "python~=${PYTHON_VERSION}" xarray xarray-einstats; fi
......
...@@ -25,4 +25,4 @@ dependencies: ...@@ -25,4 +25,4 @@ dependencies:
- ablog - ablog
- pip - pip
- pip: - pip:
- -e .. - -e ..[jax]
...@@ -803,10 +803,10 @@ You can omit the :meth:`Rop` functions. Try to implement the testing apparatus d ...@@ -803,10 +803,10 @@ You can omit the :meth:`Rop` functions. Try to implement the testing apparatus d
:download:`Solution<extending_pytensor_solution_1.py>` :download:`Solution<extending_pytensor_solution_1.py>`
:func:`as_op` :func:`wrap_py`
------------- -------------
:func:`as_op` is a Python decorator that converts a Python function into a :func:`wrap_py` is a Python decorator that converts a Python function into a
basic PyTensor :class:`Op` that will call the supplied function during execution. basic PyTensor :class:`Op` that will call the supplied function during execution.
This isn't the recommended way to build an :class:`Op`, but allows for a quick implementation. This isn't the recommended way to build an :class:`Op`, but allows for a quick implementation.
...@@ -839,11 +839,11 @@ It takes an optional :meth:`infer_shape` parameter that must have this signature ...@@ -839,11 +839,11 @@ It takes an optional :meth:`infer_shape` parameter that must have this signature
inputs PyTensor variables that were declared. inputs PyTensor variables that were declared.
.. note:: .. note::
The python function wrapped by the :func:`as_op` decorator needs to return a new The python function wrapped by the :func:`wrap_py` decorator needs to return a new
data allocation, no views or in place modification of the input. data allocation, no views or in place modification of the input.
:func:`as_op` Example :func:`wrap_py` Example
^^^^^^^^^^^^^^^^^^^^^ ^^^^^^^^^^^^^^^^^^^^^
.. testcode:: asop .. testcode:: asop
...@@ -852,14 +852,14 @@ It takes an optional :meth:`infer_shape` parameter that must have this signature ...@@ -852,14 +852,14 @@ It takes an optional :meth:`infer_shape` parameter that must have this signature
import pytensor.tensor as pt import pytensor.tensor as pt
import numpy as np import numpy as np
from pytensor import function from pytensor import function
from pytensor.compile.ops import as_op from pytensor.compile.ops import wrap_py
def infer_shape_numpy_dot(fgraph, node, input_shapes):
    """Shape-inference helper for a dot product.

    Given the shapes of the two operands, the result keeps every leading
    dimension of the first operand and the trailing dimension of the second,
    i.e. (..., k) @ (k, m) -> (..., m).
    """
    shape_a, shape_b = input_shapes
    return [shape_a[:-1] + shape_b[-1:]]
@as_op( @wrap_py(
itypes=[pt.dmatrix, pt.dmatrix], itypes=[pt.dmatrix, pt.dmatrix],
otypes=[pt.dmatrix], otypes=[pt.dmatrix],
infer_shape=infer_shape_numpy_dot, infer_shape=infer_shape_numpy_dot,
......
...@@ -167,9 +167,9 @@ class TestSumDiffOp(utt.InferShapeTester): ...@@ -167,9 +167,9 @@ class TestSumDiffOp(utt.InferShapeTester):
import numpy as np import numpy as np
# as_op exercise # wrap_py exercise
import pytensor import pytensor
from pytensor.compile.ops import as_op from pytensor.compile.ops import wrap_py
def infer_shape_numpy_dot(fgraph, node, input_shapes): def infer_shape_numpy_dot(fgraph, node, input_shapes):
...@@ -177,7 +177,7 @@ def infer_shape_numpy_dot(fgraph, node, input_shapes): ...@@ -177,7 +177,7 @@ def infer_shape_numpy_dot(fgraph, node, input_shapes):
return [ashp[:-1] + bshp[-1:]] return [ashp[:-1] + bshp[-1:]]
@as_op( @wrap_py(
itypes=[pt.fmatrix, pt.fmatrix], itypes=[pt.fmatrix, pt.fmatrix],
otypes=[pt.fmatrix], otypes=[pt.fmatrix],
infer_shape=infer_shape_numpy_dot, infer_shape=infer_shape_numpy_dot,
...@@ -192,7 +192,7 @@ def infer_shape_numpy_add_sub(fgraph, node, input_shapes): ...@@ -192,7 +192,7 @@ def infer_shape_numpy_add_sub(fgraph, node, input_shapes):
return [ashp[0]] return [ashp[0]]
@as_op( @wrap_py(
itypes=[pt.fmatrix, pt.fmatrix], itypes=[pt.fmatrix, pt.fmatrix],
otypes=[pt.fmatrix], otypes=[pt.fmatrix],
infer_shape=infer_shape_numpy_add_sub, infer_shape=infer_shape_numpy_add_sub,
...@@ -201,7 +201,7 @@ def numpy_add(a, b): ...@@ -201,7 +201,7 @@ def numpy_add(a, b):
return np.add(a, b) return np.add(a, b)
@as_op( @wrap_py(
itypes=[pt.fmatrix, pt.fmatrix], itypes=[pt.fmatrix, pt.fmatrix],
otypes=[pt.fmatrix], otypes=[pt.fmatrix],
infer_shape=infer_shape_numpy_add_sub, infer_shape=infer_shape_numpy_add_sub,
......
...@@ -61,10 +61,16 @@ Convert to Variable ...@@ -61,10 +61,16 @@ Convert to Variable
.. autofunction:: pytensor.as_symbolic(...) .. autofunction:: pytensor.as_symbolic(...)
Wrap JAX functions
==================
.. autofunction:: wrap_jax(...)
Alias for :func:`pytensor.link.jax.ops.wrap_jax`
Debug Debug
===== =====
.. autofunction:: pytensor.dprint(...) .. autofunction:: pytensor.dprint(...)
Alias for :func:`pytensor.printing.debugprint` Alias for :func:`pytensor.printing.debugprint`
...@@ -166,7 +166,7 @@ from pytensor.scan import checkpoints ...@@ -166,7 +166,7 @@ from pytensor.scan import checkpoints
from pytensor.scan.basic import scan from pytensor.scan.basic import scan
from pytensor.scan.views import foldl, foldr, map, reduce from pytensor.scan.views import foldl, foldr, map, reduce
from pytensor.compile.builders import OpFromGraph from pytensor.compile.builders import OpFromGraph
from pytensor.link.jax.ops import wrap_jax
# isort: on # isort: on
......
...@@ -56,6 +56,7 @@ from pytensor.compile.ops import ( ...@@ -56,6 +56,7 @@ from pytensor.compile.ops import (
register_deep_copy_op_c_code, register_deep_copy_op_c_code,
register_view_op_c_code, register_view_op_c_code,
view_op, view_op,
wrap_py,
) )
from pytensor.compile.profiling import ProfileStats from pytensor.compile.profiling import ProfileStats
from pytensor.compile.sharedvalue import SharedVariable, shared, shared_constructor from pytensor.compile.sharedvalue import SharedVariable, shared, shared_constructor
""" """
This file contains auxiliary Ops, used during the compilation phase and Ops This file contains auxiliary Ops, used during the compilation phase and Ops
building class (:class:`FromFunctionOp`) and decorator (:func:`as_op`) that building class (:class:`FromFunctionOp`) and decorator (:func:`wrap_py`) that
help make new Ops more rapidly. help make new Ops more rapidly.
""" """
...@@ -268,12 +268,12 @@ class FromFunctionOp(Op): ...@@ -268,12 +268,12 @@ class FromFunctionOp(Op):
obj = load_back(mod, name) obj = load_back(mod, name)
except (ImportError, KeyError, AttributeError): except (ImportError, KeyError, AttributeError):
raise pickle.PicklingError( raise pickle.PicklingError(
f"Can't pickle as_op(), not found as {mod}.{name}" f"Can't pickle wrap_py(), not found as {mod}.{name}"
) )
else: else:
if obj is not self: if obj is not self:
raise pickle.PicklingError( raise pickle.PicklingError(
f"Can't pickle as_op(), not the object at {mod}.{name}" f"Can't pickle wrap_py(), not the object at {mod}.{name}"
) )
return load_back, (mod, name) return load_back, (mod, name)
...@@ -282,6 +282,18 @@ class FromFunctionOp(Op): ...@@ -282,6 +282,18 @@ class FromFunctionOp(Op):
def as_op(itypes, otypes, infer_shape=None):
    """Deprecated alias for :func:`wrap_py`.

    Emits a ``FutureWarning`` pointing callers at ``pytensor.wrap_py`` and
    forwards all arguments unchanged.
    """
    import warnings

    # stacklevel=2 so the warning is attributed to the caller's code,
    # not to this shim.
    msg = (
        "pytensor.as_op is deprecated and will be removed in a future release. "
        "Please use pytensor.wrap_py instead."
    )
    warnings.warn(msg, FutureWarning, stacklevel=2)
    return wrap_py(itypes, otypes, infer_shape)
def wrap_py(itypes, otypes, infer_shape=None):
""" """
Decorator that converts a function into a basic PyTensor op that will call Decorator that converts a function into a basic PyTensor op that will call
the supplied function as its implementation. the supplied function as its implementation.
...@@ -301,8 +313,8 @@ def as_op(itypes, otypes, infer_shape=None): ...@@ -301,8 +313,8 @@ def as_op(itypes, otypes, infer_shape=None):
Examples Examples
-------- --------
@as_op(itypes=[pytensor.tensor.fmatrix, pytensor.tensor.fmatrix], @wrap_py(itypes=[pytensor.tensor.fmatrix, pytensor.tensor.fmatrix],
otypes=[pytensor.tensor.fmatrix]) otypes=[pytensor.tensor.fmatrix])
def numpy_dot(a, b): def numpy_dot(a, b):
return numpy.dot(a, b) return numpy.dot(a, b)
......
...@@ -13,6 +13,7 @@ from pytensor.configdefaults import config ...@@ -13,6 +13,7 @@ from pytensor.configdefaults import config
from pytensor.graph import Constant from pytensor.graph import Constant
from pytensor.graph.fg import FunctionGraph from pytensor.graph.fg import FunctionGraph
from pytensor.ifelse import IfElse from pytensor.ifelse import IfElse
from pytensor.link.jax.ops import JAXOp
from pytensor.link.utils import fgraph_to_python from pytensor.link.utils import fgraph_to_python
from pytensor.raise_op import CheckAndRaise from pytensor.raise_op import CheckAndRaise
...@@ -142,3 +143,8 @@ def jax_funcify_OpFromGraph(ofg: OpFromGraph, node=None, **kwargs) -> Callable: ...@@ -142,3 +143,8 @@ def jax_funcify_OpFromGraph(ofg: OpFromGraph, node=None, **kwargs) -> Callable:
return fgraph_fn(*inputs) return fgraph_fn(*inputs)
return opfromgraph return opfromgraph
@jax_funcify.register(JAXOp)
def jax_op_funcify(op, **kwargs):
    # A JAXOp already wraps a JAX implementation, so funcify-ing it is just
    # handing back that callable — presumably `perform_jax` is the
    # jax-traceable entry point; confirm against the JAXOp class definition.
    return op.perform_jax
差异被折叠。
import pickle import pickle
import numpy as np import numpy as np
import pytest
from pytensor import function from pytensor import function
from pytensor.compile.ops import as_op from pytensor.compile.ops import as_op, wrap_py
from pytensor.tensor.type import dmatrix, dvector from pytensor.tensor.type import dmatrix, dvector
from tests import unittest_tools as utt from tests import unittest_tools as utt
@as_op([dmatrix, dmatrix], dmatrix) @wrap_py([dmatrix, dmatrix], dmatrix)
def mul(a, b): def mul(a, b):
""" """
This is for test_pickle, since the function still has to be This is for test_pickle, since the function still has to be
...@@ -21,7 +22,7 @@ class TestOpDecorator(utt.InferShapeTester): ...@@ -21,7 +22,7 @@ class TestOpDecorator(utt.InferShapeTester):
def test_1arg(self): def test_1arg(self):
x = dmatrix("x") x = dmatrix("x")
@as_op(dmatrix, dvector) @wrap_py(dmatrix, dvector)
def cumprod(x): def cumprod(x):
return np.cumprod(x) return np.cumprod(x)
...@@ -31,13 +32,28 @@ class TestOpDecorator(utt.InferShapeTester): ...@@ -31,13 +32,28 @@ class TestOpDecorator(utt.InferShapeTester):
assert np.allclose(r, r0), (r, r0) assert np.allclose(r, r0), (r, r0)
def test_deprecation(self):
    """`as_op` must warn with FutureWarning yet still build a working Op."""
    x = dmatrix("x")

    # Decorating through the deprecated alias should trigger the warning...
    with pytest.warns(FutureWarning):

        @as_op(dmatrix, dvector)
        def cumprod(x):
            return np.cumprod(x)

    # ...while the resulting Op still computes a flattened cumulative product.
    result = function([x], cumprod(x))([[1.5, 5], [2, 2]])
    expected = np.array([1.5, 7.5, 15.0, 30.0])
    assert np.allclose(result, expected), (result, expected)
def test_2arg(self): def test_2arg(self):
x = dmatrix("x") x = dmatrix("x")
x.tag.test_value = np.zeros((2, 2)) x.tag.test_value = np.zeros((2, 2))
y = dvector("y") y = dvector("y")
y.tag.test_value = [0, 0, 0, 0] y.tag.test_value = [0, 0, 0, 0]
@as_op([dmatrix, dvector], dvector) @wrap_py([dmatrix, dvector], dvector)
def cumprod_plus(x, y): def cumprod_plus(x, y):
return np.cumprod(x) + y return np.cumprod(x) + y
...@@ -57,7 +73,7 @@ class TestOpDecorator(utt.InferShapeTester): ...@@ -57,7 +73,7 @@ class TestOpDecorator(utt.InferShapeTester):
x, y = shapes x, y = shapes
return [y] return [y]
@as_op([dmatrix, dvector], dvector, infer_shape) @wrap_py([dmatrix, dvector], dvector, infer_shape)
def cumprod_plus(x, y): def cumprod_plus(x, y):
return np.cumprod(x) + y return np.cumprod(x) + y
......
差异被折叠。
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论