提交 613ccaf3 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Support Blockwise in JAX backend

上级 a5d54c8e
......@@ -13,5 +13,6 @@ import pytensor.link.jax.dispatch.random
import pytensor.link.jax.dispatch.elemwise
import pytensor.link.jax.dispatch.scan
import pytensor.link.jax.dispatch.sparse
import pytensor.link.jax.dispatch.blockwise
# isort: on
import jax.numpy as jnp
from pytensor.graph import FunctionGraph
from pytensor.link.jax.dispatch import jax_funcify
from pytensor.tensor.blockwise import Blockwise
@jax_funcify.register(Blockwise)
def funcify_Blockwise(op: Blockwise, node, *args, **kwargs):
signature = op.signature
core_node = op._create_dummy_core_node(node.inputs)
core_fgraph = FunctionGraph(inputs=core_node.inputs, outputs=core_node.outputs)
tuple_core_fn = jax_funcify(core_fgraph)
if len(node.outputs) == 1:
def core_fn(*inputs):
return tuple_core_fn(*inputs)[0]
else:
core_fn = tuple_core_fn
vect_fn = jnp.vectorize(core_fn, signature=signature)
def blockwise_fn(*inputs):
op._check_runtime_broadcast(node, inputs)
return vect_fn(*inputs)
return blockwise_fn
import numpy as np
import pytest
from pytensor import config
from pytensor.graph import FunctionGraph
from pytensor.tensor import tensor
from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.math import Dot, matmul
from tests.link.jax.test_basic import compare_jax_and_py
from tests.tensor.test_blockwise import check_blockwise_runtime_broadcasting
jax = pytest.importorskip("jax")
def test_runtime_broadcasting():
check_blockwise_runtime_broadcasting("JAX")
# Equivalent blockwise to matmul but with dumb signature
odd_matmul = Blockwise(Dot(), signature="(i00,i01),(i10,i11)->(o00,o01)")
@pytest.mark.parametrize("matmul_op", (matmul, odd_matmul))
def test_matmul(matmul_op):
rng = np.random.default_rng(14)
a = tensor("a", shape=(2, 3, 5))
b = tensor("b", shape=(2, 5, 3))
test_values = [
rng.normal(size=(inp.type.shape)).astype(config.floatX) for inp in (a, b)
]
out = matmul_op(a, b)
assert isinstance(out.owner.op, Blockwise)
fg = FunctionGraph([a, b], [out])
fn, _ = compare_jax_and_py(fg, test_values)
# Check we are not adding any unnecessary stuff
jaxpr = str(jax.make_jaxpr(fn.vm.jit_fn)(*test_values))
jaxpr = jaxpr.replace("name=jax_funcified_fgraph", "name=matmul")
expected_jaxpr = str(jax.make_jaxpr(jax.jit(jax.numpy.matmul))(*test_values))
assert jaxpr == expected_jaxpr
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论