提交 a1fcb77c authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Cleanup JAX Scalar dispatch

上级 2c03ecfb
......@@ -37,7 +37,7 @@ def try_import_tfp_jax_op(op: ScalarOp, jax_op_name: Optional[str] = None) -> Ca
return typing.cast(Callable, getattr(tfp_jax_math, jax_op_name))
def check_if_inputs_scalars(node):
def all_inputs_are_scalar(node):
"""Check whether all the inputs of an `Elemwise` are scalar values.
`jax.lax` or `jax.numpy` functions systematically return `TracedArrays`,
......@@ -62,54 +62,68 @@ def check_if_inputs_scalars(node):
@jax_funcify.register(ScalarOp)
def jax_funcify_ScalarOp(op, node, **kwargs):
"""Return JAX function that implements the same computation as the Scalar Op.
This dispatch is expected to return a JAX function that works on Array inputs as Elemwise does,
even though it's dispatched on the Scalar Op.
"""
# We dispatch some PyTensor operators to Python operators
# whenever the inputs are all scalars.
are_inputs_scalars = check_if_inputs_scalars(node)
if are_inputs_scalars:
elemwise = elemwise_scalar(op)
if elemwise is not None:
return elemwise
func_name = op.nfunc_spec[0]
if all_inputs_are_scalar(node):
jax_func = jax_funcify_scalar_op_via_py_operators(op)
if jax_func is not None:
return jax_func
nfunc_spec = getattr(op, "nfunc_spec", None)
if nfunc_spec is None:
raise NotImplementedError(f"Dispatch not implemented for Scalar Op {op}")
func_name = nfunc_spec[0]
if "." in func_name:
jnp_func = functools.reduce(getattr, [jax] + func_name.split("."))
else:
jnp_func = getattr(jnp, func_name)
if hasattr(op, "nfunc_variadic"):
# These are special cases that handle invalid arities due to the broken
# PyTensor `Op` type contract (e.g. binary `Op`s that also function as
# their own variadic counterparts--even when those counterparts already
# exist as independent `Op`s).
jax_variadic_func = getattr(jnp, op.nfunc_variadic)
def elemwise(*args):
if len(args) > op.nfunc_spec[1]:
return jax_variadic_func(
jnp.stack(jnp.broadcast_arrays(*args), axis=0), axis=0
)
else:
return jnp_func(*args)
return elemwise
jax_func = functools.reduce(getattr, [jax] + func_name.split("."))
else:
return jnp_func
jax_func = getattr(jnp, func_name)
if len(node.inputs) > op.nfunc_spec[1]:
# Some Scalar Ops accept multiple number of inputs, behaving as a variadic function,
# even though the base Op from `func_name` is specified as a binary Op.
# This happens with `Add`, which can work as a `Sum` for multiple scalars.
jax_variadic_func = getattr(jnp, op.nfunc_variadic, None)
if not jax_variadic_func:
raise NotImplementedError(
f"Dispatch not implemented for Scalar Op {op} with {len(node.inputs)} inputs"
)
def jax_func(*args):
return jax_variadic_func(
jnp.stack(jnp.broadcast_arrays(*args), axis=0), axis=0
)
return jax_func
@functools.singledispatch
def elemwise_scalar(op):
def jax_funcify_scalar_op_via_py_operators(op):
"""Specialized JAX dispatch for Elemwise operations where all inputs are Scalar arrays.
Scalar (constant) arrays in the JAX backend get lowered to the native types (int, floats),
which can perform better with Python operators, and more importantly, avoid upcasting to array types
not supported by some JAX functions.
"""
return None
@elemwise_scalar.register(Add)
def elemwise_scalar_add(op):
@jax_funcify_scalar_op_via_py_operators.register(Add)
def jax_funcify_scalar_Add(op):
def elemwise(*inputs):
return sum(inputs)
return elemwise
@elemwise_scalar.register(Mul)
def elemwise_scalar_mul(op):
@jax_funcify_scalar_op_via_py_operators.register(Mul)
def jax_funcify_scalar_Mul(op):
import operator
from functools import reduce
......@@ -119,24 +133,24 @@ def elemwise_scalar_mul(op):
return elemwise
@elemwise_scalar.register(Sub)
def elemwise_scalar_sub(op):
@jax_funcify_scalar_op_via_py_operators.register(Sub)
def jax_funcify_scalar_Sub(op):
def elemwise(x, y):
return x - y
return elemwise
@elemwise_scalar.register(IntDiv)
def elemwise_scalar_intdiv(op):
@jax_funcify_scalar_op_via_py_operators.register(IntDiv)
def jax_funcify_scalar_IntDiv(op):
def elemwise(x, y):
return x // y
return elemwise
@elemwise_scalar.register(Mod)
def elemwise_scalar_mod(op):
@jax_funcify_scalar_op_via_py_operators.register(Mod)
def jax_funcify_scalar_Mod(op):
def elemwise(x, y):
return x % y
......
......@@ -23,6 +23,7 @@ from pytensor.tensor.math import (
psi,
sigmoid,
softplus,
tri_gamma,
)
from pytensor.tensor.type import matrix, scalar, vector
from tests.link.jax.test_basic import compare_jax_and_py
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论