Commit 3b425ec7 authored by Ricardo Vieira, committed by Ricardo Vieira

Implement JAX dot with sparse constants

Non-constant sparse inputs can't be handled because JAX does not allow Scipy sparse matrices as inputs. We could implement a BCOO type explicitly but this would be JAX exclusive, and the user would need to use it from the get go, meaning such graphs would not be compatible with other backends.
Parent 98de2462
...@@ -12,5 +12,6 @@ import pytensor.link.jax.dispatch.slinalg
import pytensor.link.jax.dispatch.random
import pytensor.link.jax.dispatch.elemwise
import pytensor.link.jax.dispatch.scan
import pytensor.link.jax.dispatch.sparse
# isort: on
import jax.experimental.sparse as jsp
from scipy.sparse import spmatrix
from pytensor.graph.basic import Constant
from pytensor.link.jax.dispatch import jax_funcify, jax_typify
from pytensor.sparse.basic import Dot, StructuredDot
from pytensor.sparse.type import SparseTensorType
@jax_typify.register(spmatrix)
def jax_typify_spmatrix(matrix, dtype=None, **kwargs):
    """Convert a Scipy sparse matrix into a JAX ``BCOO`` sparse array.

    Note: this changes the representation of the constants from CSR/CSC to
    BCOO. We could add BCOO as a PyTensor type, but that would only be useful
    for JAX graphs and would break the premise of one graph -> multiple
    backends. The same situation happens with RandomGenerators...
    """
    return jsp.BCOO.from_scipy_sparse(matrix)
@jax_funcify.register(Dot)
@jax_funcify.register(StructuredDot)
def jax_funcify_sparse_dot(op, node, **kwargs):
    """Return a JAX implementation of (structured) sparse matrix multiplication.

    Only *constant* sparse inputs are supported: JAX does not accept Scipy
    sparse matrices as runtime inputs, so a non-constant sparse variable has
    no valid JAX representation (constants are converted to ``BCOO`` by
    ``jax_typify_spmatrix``). Only dense outputs are supported.

    Raises
    ------
    NotImplementedError
        If any sparse input is not a ``Constant``, or if the output of the
        node is itself sparse.
    """
    # NOTE: `inp` (not `input`) to avoid shadowing the builtin.
    for inp in node.inputs:
        if isinstance(inp.type, SparseTensorType) and not isinstance(inp, Constant):
            raise NotImplementedError(
                "JAX sparse dot only implemented for constant sparse inputs"
            )

    if isinstance(node.outputs[0].type, SparseTensorType):
        raise NotImplementedError("JAX sparse dot only implemented for dense outputs")

    @jsp.sparsify
    def sparse_dot(x, y):
        out = x @ y
        # `sparsify` may keep the result sparse; densify so the runtime value
        # matches the dense output type of the PyTensor node.
        if isinstance(out, jsp.BCOO):
            out = out.todense()
        return out

    return sparse_dot
import numpy as np
import pytest
import scipy.sparse
import pytensor.sparse as ps
import pytensor.tensor as pt
from pytensor import function
from pytensor.graph import FunctionGraph
from tests.link.jax.test_basic import compare_jax_and_py
@pytest.mark.parametrize(
    "op, x_type, y_type",
    [
        (ps.dot, pt.vector, ps.matrix),
        (ps.dot, pt.matrix, ps.matrix),
        (ps.dot, ps.matrix, pt.vector),
        (ps.dot, ps.matrix, pt.matrix),
        # structured_dot only allows matrix @ matrix
        (ps.structured_dot, pt.matrix, ps.matrix),
        (ps.structured_dot, ps.matrix, pt.matrix),
    ],
)
def test_sparse_dot_constant_sparse(x_type, y_type, op):
    """Dot with constant sparse operands matches the Python backend in JAX."""

    def _operand(var_type, name, shape, sp_format):
        # Returns (variable, test_value). Sparse operands become constants
        # (test_value is None); dense ones are symbolic function inputs.
        if var_type is ps.matrix:
            sp_val = scipy.sparse.random(
                *shape, density=0.25, format=sp_format, dtype="float32"
            )
            return ps.as_sparse_variable(sp_val, name=name), None
        var = var_type(name, dtype="float32")
        if var.ndim == 1:
            value = np.arange(40, dtype="float32")
        else:
            value = np.arange(shape[0] * shape[1], dtype="float32").reshape(shape)
        return var, value

    operands = [
        _operand(x_type, "x", (5, 40), "csr"),
        _operand(y_type, "y", (40, 3), "csc"),
    ]
    inputs = [var for var, value in operands if value is not None]
    test_values = [value for _, value in operands if value is not None]

    dot_pt = op(*(var for var, _ in operands))
    fgraph = FunctionGraph(inputs, [dot_pt])
    compare_jax_and_py(fgraph, test_values)
def test_sparse_dot_non_const_raises():
    """Non-constant sparse inputs are rejected by the JAX backend."""
    x_pt = pt.vector("x")
    y_sp = scipy.sparse.random(40, 3, density=0.25, format="csc", dtype="float32")

    msg = "JAX sparse dot only implemented for constant sparse inputs"

    # Case 1: a plain symbolic sparse variable.
    y_pt = ps.as_sparse_variable(y_sp, name="y").type()
    with pytest.raises(NotImplementedError, match=msg):
        function([x_pt, y_pt], ps.dot(x_pt, y_pt), mode="JAX")

    # Case 2: a shared sparse variable is also not a Constant.
    y_pt_shared = ps.shared(y_sp, name="y")
    with pytest.raises(NotImplementedError, match=msg):
        function([x_pt], ps.dot(x_pt, y_pt_shared), mode="JAX")
Markdown is supported
0%
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Register or sign in to comment