Commit f35ce262 authored by HangenYuu, committed by Ricardo Vieira

Reorganized JAX link folder structure

Parent 308bc019
pytensor/link/jax/dispatch/__init__.py
@@ -2,18 +2,20 @@
 from pytensor.link.jax.dispatch.basic import jax_funcify, jax_typify

 # Load dispatch specializations
-import pytensor.link.jax.dispatch.scalar
-import pytensor.link.jax.dispatch.tensor_basic
-import pytensor.link.jax.dispatch.subtensor
-import pytensor.link.jax.dispatch.shape
+import pytensor.link.jax.dispatch.blas
+import pytensor.link.jax.dispatch.blockwise
+import pytensor.link.jax.dispatch.elemwise
 import pytensor.link.jax.dispatch.extra_ops
+import pytensor.link.jax.dispatch.math
 import pytensor.link.jax.dispatch.nlinalg
-import pytensor.link.jax.dispatch.slinalg
 import pytensor.link.jax.dispatch.random
-import pytensor.link.jax.dispatch.elemwise
+import pytensor.link.jax.dispatch.scalar
 import pytensor.link.jax.dispatch.scan
-import pytensor.link.jax.dispatch.sparse
-import pytensor.link.jax.dispatch.blockwise
+import pytensor.link.jax.dispatch.shape
+import pytensor.link.jax.dispatch.slinalg
 import pytensor.link.jax.dispatch.sort
+import pytensor.link.jax.dispatch.sparse
+import pytensor.link.jax.dispatch.subtensor
+import pytensor.link.jax.dispatch.tensor_basic

 # isort: on
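These submodule imports exist purely for their side effects: each one runs its `@jax_funcify.register(...)` decorators on import, populating the dispatch registry. A minimal self-contained sketch of the pattern (with a hypothetical `MyOp`, assuming `jax_funcify` behaves like `functools.singledispatch`):

from functools import singledispatch


@singledispatch
def jax_funcify(op, **kwargs):
    # Fallback for Ops that have no registered JAX conversion
    raise NotImplementedError(f"No JAX conversion for {type(op).__name__}")


class MyOp:  # stand-in for a PyTensor Op
    pass


@jax_funcify.register(MyOp)  # what each dispatch submodule does on import
def jax_funcify_MyOp(op, **kwargs):
    return lambda x: x


print(jax_funcify(MyOp()))  # resolves to the registered implementation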
pytensor/link/jax/dispatch/blas.py (new file)
import jax.numpy as jnp

from pytensor.link.jax.dispatch import jax_funcify
from pytensor.tensor.blas import BatchedDot


@jax_funcify.register(BatchedDot)
def jax_funcify_BatchedDot(op, **kwargs):
    def batched_dot(a, b):
        if a.shape[0] != b.shape[0]:
            raise TypeError("Shapes must match along the first dimension of BatchedDot")
        return jnp.matmul(a, b)

    return batched_dot
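For context (not part of the commit), a quick standalone check of the behavior `batched_dot` relies on: `jnp.matmul` contracts the last two axes and maps over the shared leading batch axis, matching an explicit per-slice matrix product. Requires a working jax install.

import jax.numpy as jnp
import numpy as np

a = jnp.asarray(np.linspace(-1, 1, 10 * 5 * 3).reshape((10, 5, 3)))
b = jnp.asarray(np.linspace(1, -1, 10 * 3 * 2).reshape((10, 3, 2)))

out = jnp.matmul(a, b)  # one (5, 3) @ (3, 2) product per batch entry
assert out.shape == (10, 5, 2)

# Reference: an explicit loop over the batch dimension
expected = np.stack([np.asarray(a[i]) @ np.asarray(b[i]) for i in range(10)])
np.testing.assert_allclose(np.asarray(out), expected, rtol=1e-6)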
pytensor/link/jax/dispatch/math.py (new file)
import jax.numpy as jnp
import numpy as np

from pytensor.link.jax.dispatch import jax_funcify
from pytensor.tensor.math import Argmax, Dot, Max


@jax_funcify.register(Dot)
def jax_funcify_Dot(op, **kwargs):
    def dot(x, y):
        return jnp.dot(x, y)

    return dot


@jax_funcify.register(Max)
def jax_funcify_Max(op, **kwargs):
    axis = op.axis

    def max(x):
        max_res = jnp.max(x, axis)
        return max_res

    return max


@jax_funcify.register(Argmax)
def jax_funcify_Argmax(op, **kwargs):
    axis = op.axis

    def argmax(x):
        if axis is None:
            axes = tuple(range(x.ndim))
        else:
            axes = tuple(int(ax) for ax in axis)

        # NumPy does not support multiple axes for argmax; this is a
        # work-around: move the non-reduced axes to the front ...
        keep_axes = np.array([i for i in range(x.ndim) if i not in axes], dtype="int64")
        transposed_x = jnp.transpose(
            x, tuple(np.concatenate((keep_axes, np.array(axes, dtype="int64"))))
        )
        kept_shape = transposed_x.shape[: len(keep_axes)]
        reduced_shape = transposed_x.shape[len(keep_axes) :]

        # ... then flatten the reduced axes into one and take a single argmax.
        # np.prod returns 1.0 for an empty argument, so cast to int64;
        # otherwise reshape would complain about a float argument.
        new_shape = (
            *kept_shape,
            np.prod(np.array(reduced_shape, dtype="int64"), dtype="int64"),
        )
        reshaped_x = transposed_x.reshape(tuple(new_shape))

        max_idx_res = jnp.argmax(reshaped_x, axis=-1).astype("int64")

        return max_idx_res

    return argmax
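To make the work-around concrete, here is a small NumPy-only illustration (not part of the commit) of the same transpose-and-reshape trick for reducing over multiple axes at once:

import numpy as np

x = np.arange(24).reshape((2, 3, 4))
axes = (1, 2)  # reduce over the last two axes

keep_axes = tuple(i for i in range(x.ndim) if i not in axes)
transposed = np.transpose(x, keep_axes + axes)           # shape (2, 3, 4)
flattened = transposed.reshape(transposed.shape[0], -1)  # shape (2, 12)

# A single argmax over the flattened trailing axis replaces the
# unsupported multi-axis argmax
flat_idx = np.argmax(flattened, axis=-1)
assert flat_idx.tolist() == [11, 11]  # last element of each (3, 4) slice

# The flat index can be unraveled back to per-axis coordinates if needed
coords = np.unravel_index(flat_idx, (3, 4))
assert coords[0].tolist() == [2, 2] and coords[1].tolist() == [3, 3]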
pytensor/link/jax/dispatch/nlinalg.py
 import jax.numpy as jnp
-import numpy as np

 from pytensor.link.jax.dispatch import jax_funcify
-from pytensor.tensor.blas import BatchedDot
-from pytensor.tensor.math import Argmax, Dot, Max
 from pytensor.tensor.nlinalg import (
     SVD,
     Det,
@@ -80,14 +77,6 @@ def jax_funcify_QRFull(op, **kwargs):
     return qr_full


-@jax_funcify.register(Dot)
-def jax_funcify_Dot(op, **kwargs):
-    def dot(x, y):
-        return jnp.dot(x, y)
-
-    return dot
-
-
 @jax_funcify.register(MatrixPinv)
 def jax_funcify_Pinv(op, **kwargs):
     def pinv(x):
@@ -96,66 +85,9 @@ def jax_funcify_Pinv(op, **kwargs):
     return pinv


-@jax_funcify.register(BatchedDot)
-def jax_funcify_BatchedDot(op, **kwargs):
-    def batched_dot(a, b):
-        if a.shape[0] != b.shape[0]:
-            raise TypeError("Shapes must match in the 0-th dimension")
-        return jnp.matmul(a, b)
-
-    return batched_dot
-
-
 @jax_funcify.register(KroneckerProduct)
 def jax_funcify_KroneckerProduct(op, **kwargs):
     def _kron(x, y):
         return jnp.kron(x, y)

     return _kron

-
-@jax_funcify.register(Max)
-def jax_funcify_Max(op, **kwargs):
-    axis = op.axis
-
-    def max(x):
-        max_res = jnp.max(x, axis)
-        return max_res
-
-    return max
-
-
-@jax_funcify.register(Argmax)
-def jax_funcify_Argmax(op, **kwargs):
-    axis = op.axis
-
-    def argmax(x):
-        if axis is None:
-            axes = tuple(range(x.ndim))
-        else:
-            axes = tuple(int(ax) for ax in axis)
-
-        # NumPy does not support multiple axes for argmax; this is a
-        # work-around
-        keep_axes = np.array([i for i in range(x.ndim) if i not in axes], dtype="int64")
-        # Not-reduced axes in front
-        transposed_x = jnp.transpose(
-            x, tuple(np.concatenate((keep_axes, np.array(axes, dtype="int64"))))
-        )
-        kept_shape = transposed_x.shape[: len(keep_axes)]
-        reduced_shape = transposed_x.shape[len(keep_axes) :]
-
-        # Numpy.prod returns 1.0 when arg is empty, so we cast it to int64
-        # Otherwise reshape would complain citing float arg
-        new_shape = (
-            *kept_shape,
-            np.prod(np.array(reduced_shape, dtype="int64"), dtype="int64"),
-        )
-        reshaped_x = transposed_x.reshape(tuple(new_shape))
-
-        max_idx_res = jnp.argmax(reshaped_x, axis=-1).astype("int64")
-
-        return max_idx_res
-
-    return argmax
tests/link/jax/test_blas.py (new file)
import numpy as np
import pytest

from pytensor.compile.function import function
from pytensor.compile.mode import Mode
from pytensor.configdefaults import config
from pytensor.graph.fg import FunctionGraph
from pytensor.graph.op import get_test_value
from pytensor.graph.rewriting.db import RewriteDatabaseQuery
from pytensor.link.jax import JAXLinker
from pytensor.tensor import blas as pt_blas
from pytensor.tensor.type import tensor3
from tests.link.jax.test_basic import compare_jax_and_py


def test_jax_BatchedDot():
    # tensor3 . tensor3
    a = tensor3("a")
    a.tag.test_value = (
        np.linspace(-1, 1, 10 * 5 * 3).astype(config.floatX).reshape((10, 5, 3))
    )
    b = tensor3("b")
    b.tag.test_value = (
        np.linspace(1, -1, 10 * 3 * 2).astype(config.floatX).reshape((10, 3, 2))
    )
    out = pt_blas.BatchedDot()(a, b)
    fgraph = FunctionGraph([a, b], [out])
    compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])

    # A dimension mismatch should raise a TypeError for compatibility
    inputs = [get_test_value(a)[:-1], get_test_value(b)]
    opts = RewriteDatabaseQuery(include=[None], exclude=["cxx_only", "BlasOpt"])
    jax_mode = Mode(JAXLinker(), opts)
    pytensor_jax_fn = function(fgraph.inputs, fgraph.outputs, mode=jax_mode)
    with pytest.raises(TypeError):
        pytensor_jax_fn(*inputs)
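The `Mode(JAXLinker(), opts)` construction above is how this test forces JAX compilation by hand. As a hedged usage sketch (assuming jax is installed; the graph `x * 2` is an arbitrary example, not from the commit), the same mode can compile any small graph:

import numpy as np

from pytensor.compile.function import function
from pytensor.compile.mode import Mode
from pytensor.graph.rewriting.db import RewriteDatabaseQuery
from pytensor.link.jax import JAXLinker
from pytensor.tensor.type import vector

x = vector("x")
out = x * 2

# Exclude the C-only and BLAS rewrites, as the test does, so the graph
# stays in a JAX-compatible form
opts = RewriteDatabaseQuery(include=[None], exclude=["cxx_only", "BlasOpt"])
jax_fn = function([x], out, mode=Mode(JAXLinker(), opts))
print(jax_fn(np.array([1.0, 2.0, 3.0])))  # -> [2. 4. 6.]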
tests/link/jax/test_math.py (new file)
import numpy as np
import pytest

from pytensor.configdefaults import config
from pytensor.graph.fg import FunctionGraph
from pytensor.graph.op import get_test_value
from pytensor.tensor.math import Argmax, Max, maximum
from pytensor.tensor.math import max as pt_max
from pytensor.tensor.type import dvector, matrix, scalar, vector
from tests.link.jax.test_basic import compare_jax_and_py

jax = pytest.importorskip("jax")


def test_jax_max_and_argmax():
    # Test that a single output of a multi-output `Op` can be used as input to
    # another `Op`
    x = dvector()
    mx = Max([0])(x)
    amx = Argmax([0])(x)
    out = mx * amx
    out_fg = FunctionGraph([x], [out])
    compare_jax_and_py(out_fg, [np.r_[1, 2]])


def test_dot():
    y = vector("y")
    y.tag.test_value = np.r_[1.0, 2.0].astype(config.floatX)
    x = vector("x")
    x.tag.test_value = np.r_[3.0, 4.0].astype(config.floatX)
    A = matrix("A")
    A.tag.test_value = np.empty((2, 2), dtype=config.floatX)
    alpha = scalar("alpha")
    alpha.tag.test_value = np.array(3.0, dtype=config.floatX)
    beta = scalar("beta")
    beta.tag.test_value = np.array(5.0, dtype=config.floatX)

    # This should be converted into a `Gemv` `Op` when the non-JAX compatible
    # optimizations are turned on; however, when using JAX mode, it should
    # leave the expression alone.
    out = y.dot(alpha * A).dot(x) + beta * y
    fgraph = FunctionGraph([y, x, A, alpha, beta], [out])
    compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])

    out = maximum(y, x)
    fgraph = FunctionGraph([y, x], [out])
    compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])

    out = pt_max(y)
    fgraph = FunctionGraph([y], [out])
    compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])
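These tests all funnel through the `compare_jax_and_py` helper from tests/link/jax/test_basic.py. Its internals are not shown in this commit; the following is only a hypothetical simplification, inferred from how the tests call it, of what such a helper presumably does:

import numpy as np

from pytensor.compile.function import function
from pytensor.compile.mode import Mode
from pytensor.graph.rewriting.db import RewriteDatabaseQuery
from pytensor.link.jax import JAXLinker


def compare_jax_and_py_sketch(fgraph, inputs, assert_fn=None):
    # Hypothetical: compile the graph once with the JAX linker and once
    # with a plain Python mode, then compare the outputs pairwise.
    if assert_fn is None:
        assert_fn = lambda a, b: np.testing.assert_allclose(a, b, rtol=1e-4)
    opts = RewriteDatabaseQuery(include=[None], exclude=["cxx_only", "BlasOpt"])
    jax_fn = function(fgraph.inputs, fgraph.outputs, mode=Mode(JAXLinker(), opts))
    py_fn = function(fgraph.inputs, fgraph.outputs, mode="FAST_COMPILE")
    for jax_res, py_res in zip(jax_fn(*inputs), py_fn(*inputs)):
        assert_fn(jax_res, py_res)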
tests/link/jax/test_nlinalg.py
@@ -2,46 +2,16 @@ import numpy as np
 import pytest

 from pytensor.compile.function import function
-from pytensor.compile.mode import Mode
 from pytensor.configdefaults import config
 from pytensor.graph.fg import FunctionGraph
-from pytensor.graph.op import get_test_value
-from pytensor.graph.rewriting.db import RewriteDatabaseQuery
-from pytensor.link.jax import JAXLinker
-from pytensor.tensor import blas as pt_blas
 from pytensor.tensor import nlinalg as pt_nlinalg
-from pytensor.tensor.math import Argmax, Max, maximum
-from pytensor.tensor.math import max as pt_max
-from pytensor.tensor.type import dvector, matrix, scalar, tensor3, vector
+from pytensor.tensor.type import matrix
 from tests.link.jax.test_basic import compare_jax_and_py

 jax = pytest.importorskip("jax")


-def test_jax_BatchedDot():
-    # tensor3 . tensor3
-    a = tensor3("a")
-    a.tag.test_value = (
-        np.linspace(-1, 1, 10 * 5 * 3).astype(config.floatX).reshape((10, 5, 3))
-    )
-    b = tensor3("b")
-    b.tag.test_value = (
-        np.linspace(1, -1, 10 * 3 * 2).astype(config.floatX).reshape((10, 3, 2))
-    )
-    out = pt_blas.BatchedDot()(a, b)
-    fgraph = FunctionGraph([a, b], [out])
-    compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])
-
-    # A dimension mismatch should raise a TypeError for compatibility
-    inputs = [get_test_value(a)[:-1], get_test_value(b)]
-    opts = RewriteDatabaseQuery(include=[None], exclude=["cxx_only", "BlasOpt"])
-    jax_mode = Mode(JAXLinker(), opts)
-    pytensor_jax_fn = function(fgraph.inputs, fgraph.outputs, mode=jax_mode)
-    with pytest.raises(TypeError):
-        pytensor_jax_fn(*inputs)
-
-
 def test_jax_basic_multiout():
     rng = np.random.default_rng(213234)
@@ -79,45 +49,6 @@ def test_jax_basic_multiout():
     compare_jax_and_py(out_fg, [X.astype(config.floatX)], assert_fn=assert_fn)


-def test_jax_max_and_argmax():
-    # Test that a single output of a multi-output `Op` can be used as input to
-    # another `Op`
-    x = dvector()
-    mx = Max([0])(x)
-    amx = Argmax([0])(x)
-    out = mx * amx
-    out_fg = FunctionGraph([x], [out])
-    compare_jax_and_py(out_fg, [np.r_[1, 2]])
-
-
-def test_tensor_basics():
-    y = vector("y")
-    y.tag.test_value = np.r_[1.0, 2.0].astype(config.floatX)
-    x = vector("x")
-    x.tag.test_value = np.r_[3.0, 4.0].astype(config.floatX)
-    A = matrix("A")
-    A.tag.test_value = np.empty((2, 2), dtype=config.floatX)
-    alpha = scalar("alpha")
-    alpha.tag.test_value = np.array(3.0, dtype=config.floatX)
-    beta = scalar("beta")
-    beta.tag.test_value = np.array(5.0, dtype=config.floatX)
-
-    # This should be converted into a `Gemv` `Op` when the non-JAX compatible
-    # optimizations are turned on; however, when using JAX mode, it should
-    # leave the expression alone.
-    out = y.dot(alpha * A).dot(x) + beta * y
-    fgraph = FunctionGraph([y, x, A, alpha, beta], [out])
-    compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])
-
-    out = maximum(y, x)
-    fgraph = FunctionGraph([y, x], [out])
-    compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])
-
-    out = pt_max(y)
-    fgraph = FunctionGraph([y], [out])
-    compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])
-
-
 def test_pinv():
     x = matrix("x")
     x_inv = pt_nlinalg.pinv(x)