Commit 74db7999 authored by: Rémi Louf    Committed by: Brandon T. Willard

Refactor the JAX dispatcher

Parent 1d899ee9
# isort: off
from aesara.link.jax.dispatch.basic import jax_funcify, jax_typify
# Load dispatch specializations
import aesara.link.jax.dispatch.scalar
import aesara.link.jax.dispatch.tensor_basic
import aesara.link.jax.dispatch.subtensor
import aesara.link.jax.dispatch.shape
import aesara.link.jax.dispatch.extra_ops
import aesara.link.jax.dispatch.nlinalg
import aesara.link.jax.dispatch.slinalg
import aesara.link.jax.dispatch.random
import aesara.link.jax.dispatch.elemwise
import aesara.link.jax.dispatch.scan
# isort: on
import warnings
from functools import singledispatch
import jax
import jax.numpy as jnp
import numpy as np
from aesara.compile.ops import DeepCopyOp, ViewOp
from aesara.configdefaults import config
from aesara.graph.fg import FunctionGraph
from aesara.ifelse import IfElse
from aesara.link.utils import fgraph_to_python
from aesara.raise_op import CheckAndRaise
if config.floatX == "float64":
jax.config.update("jax_enable_x64", True)
else:
jax.config.update("jax_enable_x64", False)
# XXX: Enabling this will break some shape-based functionality, and severely
# limit the types of graphs that can be converted.
# See https://github.com/google/jax/blob/4d556837cc9003492f674c012689efc3d68fdf5f/design_notes/omnistaging.md
# Older versions < 0.2.0 do not have this flag so we don't need to set it.
try:
jax.config.disable_omnistaging()
except AttributeError:
pass
except Exception as e:
# The version might be >= 0.2.12, which means that omnistaging can't be
# disabled
warnings.warn(f"JAX omnistaging couldn't be disabled: {e}")
@singledispatch
def jax_typify(data, dtype=None, **kwargs):
r"""Convert instances of Aesara `Type`\s to JAX types."""
if dtype is None:
return data
else:
return jnp.array(data, dtype=dtype)
@jax_typify.register(np.ndarray)
def jax_typify_ndarray(data, dtype=None, **kwargs):
return jnp.array(data, dtype=dtype)
@singledispatch
def jax_funcify(op, node=None, storage_map=None, **kwargs):
"""Create a JAX compatible function from an Aesara `Op`."""
raise NotImplementedError(f"No JAX conversion for the given `Op`: {op}")
@jax_funcify.register(FunctionGraph)
def jax_funcify_FunctionGraph(
fgraph,
node=None,
fgraph_name="jax_funcified_fgraph",
**kwargs,
):
return fgraph_to_python(
fgraph,
jax_funcify,
type_conversion_fn=jax_typify,
fgraph_name=fgraph_name,
**kwargs,
)
@jax_funcify.register(IfElse)
def jax_funcify_IfElse(op, **kwargs):
n_outs = op.n_outs
def ifelse(cond, *args, n_outs=n_outs):
res = jax.lax.cond(
cond, lambda _: args[:n_outs], lambda _: args[n_outs:], operand=None
)
return res if n_outs > 1 else res[0]
return ifelse
@jax_funcify.register(CheckAndRaise)
def jax_funcify_CheckAndRaise(op, **kwargs):
raise NotImplementedError(
f"""This exception is raised because you tried to convert an aesara graph with a `CheckAndRaise` Op (message: {op.msg}) to JAX.
JAX uses tracing to jit-compile functions, and assertions typically
don't do well with tracing. The appropriate workaround depends on what
you intended to do with the assertions in the first place.
Note that all assertions can be removed from the graph by adding
`local_remove_all_assert` to the rewrites."""
)
def jnp_safe_copy(x):
try:
res = jnp.copy(x)
except NotImplementedError:
warnings.warn(
"`jnp.copy` is not implemented yet. " "Using the object's `copy` method."
)
if hasattr(x, "copy"):
res = jnp.array(x.copy())
else:
warnings.warn(f"Object has no `copy` method: {x}")
res = x
return res
@jax_funcify.register(DeepCopyOp)
def jax_funcify_DeepCopyOp(op, **kwargs):
def deepcopyop(x):
return jnp_safe_copy(x)
return deepcopyop
@jax_funcify.register(ViewOp)
def jax_funcify_ViewOp(op, **kwargs):
def viewop(x):
return x
return viewop
import jax
import jax.numpy as jnp
from aesara.link.jax.dispatch.basic import jax_funcify, jnp_safe_copy
from aesara.tensor.elemwise import CAReduce, DimShuffle, Elemwise
from aesara.tensor.nnet.basic import LogSoftmax, Softmax, SoftmaxGrad
@jax_funcify.register(Elemwise)
def jax_funcify_Elemwise(op, **kwargs):
scalar_op = op.scalar_op
return jax_funcify(scalar_op, **kwargs)
@jax_funcify.register(CAReduce)
def jax_funcify_CAReduce(op, **kwargs):
axis = op.axis
op_nfunc_spec = getattr(op, "nfunc_spec", None)
scalar_nfunc_spec = getattr(op.scalar_op, "nfunc_spec", None)
scalar_op_name = getattr(op.scalar_op, "name", None)
scalar_op_identity = getattr(op.scalar_op, "identity", None)
acc_dtype = getattr(op, "acc_dtype", None)
def careduce(x):
nonlocal axis, op_nfunc_spec, scalar_nfunc_spec, scalar_op_name, scalar_op_identity, acc_dtype
if axis is None:
axis = list(range(x.ndim))
if acc_dtype is None:
acc_dtype = x.dtype.type
if op_nfunc_spec:
jax_op = getattr(jnp, op_nfunc_spec[0])
return jax_op(x, axis=axis).astype(acc_dtype)
# The Aesara `Op` didn't tell us which NumPy equivalent to use (or
# there isn't one), so we use this fallback approach
if scalar_nfunc_spec:
scalar_fn_name = scalar_nfunc_spec[0]
elif scalar_op_name:
scalar_fn_name = scalar_op_name
to_reduce = reversed(sorted(axis))
if to_reduce:
# In this case, we need to use the `jax.lax` function (if there
# is one), and not the `jnp` version.
jax_op = getattr(jax.lax, scalar_fn_name)
init_value = jnp.array(scalar_op_identity, dtype=acc_dtype)
return jax.lax.reduce(x, init_value, jax_op, to_reduce).astype(acc_dtype)
else:
return x
return careduce
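# A minimal sketch (not part of the dispatcher) of the `jax.lax.reduce` fallback
# above, assuming a sum reduction: the scalar op's identity is 0, so the
# reduction is seeded with that identity and combined with `jax.lax.add`.
def _careduce_sum_sketch(x, axis=0):
    init_value = jnp.array(0, dtype=x.dtype)
    return jax.lax.reduce(x, init_value, jax.lax.add, (axis,))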
@jax_funcify.register(DimShuffle)
def jax_funcify_DimShuffle(op, **kwargs):
def dimshuffle(x):
res = jnp.transpose(x, op.transposition)
shape = list(res.shape[: len(op.shuffle)])
for augm in op.augment:
shape.insert(augm, 1)
res = jnp.reshape(res, shape)
if not op.inplace:
res = jnp_safe_copy(res)
return res
return dimshuffle
@jax_funcify.register(Softmax)
def jax_funcify_Softmax(op, **kwargs):
axis = op.axis
def softmax(x):
return jax.nn.softmax(x, axis=axis)
return softmax
@jax_funcify.register(SoftmaxGrad)
def jax_funcify_SoftmaxGrad(op, **kwargs):
axis = op.axis
def softmax_grad(dy, sm):
dy_times_sm = dy * sm
return dy_times_sm - jnp.sum(dy_times_sm, axis=axis, keepdims=True) * sm
return softmax_grad
@jax_funcify.register(LogSoftmax)
def jax_funcify_LogSoftmax(op, **kwargs):
axis = op.axis
def log_softmax(x):
return jax.nn.log_softmax(x, axis=axis)
return log_softmax
import warnings
import jax.numpy as jnp
from aesara.link.jax.dispatch.basic import jax_funcify
from aesara.tensor.extra_ops import (
Bartlett,
BroadcastTo,
CumOp,
FillDiagonal,
FillDiagonalOffset,
RavelMultiIndex,
Repeat,
Unique,
UnravelIndex,
)
@jax_funcify.register(Bartlett)
def jax_funcify_Bartlett(op, **kwargs):
def bartlett(x):
return jnp.bartlett(x)
return bartlett
@jax_funcify.register(CumOp)
def jax_funcify_CumOp(op, **kwargs):
axis = op.axis
mode = op.mode
def cumop(x, axis=axis, mode=mode):
if mode == "add":
return jnp.cumsum(x, axis=axis)
else:
return jnp.cumprod(x, axis=axis)
return cumop
@jax_funcify.register(Repeat)
def jax_funcify_Repeat(op, **kwargs):
axis = op.axis
def repeatop(x, repeats, axis=axis):
return jnp.repeat(x, repeats, axis=axis)
return repeatop
@jax_funcify.register(Unique)
def jax_funcify_Unique(op, **kwargs):
axis = op.axis
if axis is not None:
raise NotImplementedError(
"jax.numpy.unique is not implemented for the axis argument"
)
return_index = op.return_index
return_inverse = op.return_inverse
return_counts = op.return_counts
def unique(
x,
return_index=return_index,
return_inverse=return_inverse,
return_counts=return_counts,
axis=axis,
):
ret = jnp.lax_numpy._unique1d(x, return_index, return_inverse, return_counts)
if len(ret) == 1:
return ret[0]
else:
return ret
return unique
@jax_funcify.register(UnravelIndex)
def jax_funcify_UnravelIndex(op, **kwargs):
order = op.order
warnings.warn("JAX ignores the `order` parameter in `unravel_index`.")
def unravelindex(indices, dims, order=order):
return jnp.unravel_index(indices, dims)
return unravelindex
@jax_funcify.register(RavelMultiIndex)
def jax_funcify_RavelMultiIndex(op, **kwargs):
mode = op.mode
order = op.order
def ravelmultiindex(*inp, mode=mode, order=order):
multi_index, dims = inp[:-1], inp[-1]
return jnp.ravel_multi_index(multi_index, dims, mode=mode, order=order)
return ravelmultiindex
@jax_funcify.register(BroadcastTo)
def jax_funcify_BroadcastTo(op, **kwargs):
def broadcast_to(x, *shape):
return jnp.broadcast_to(x, shape)
return broadcast_to
@jax_funcify.register(FillDiagonal)
def jax_funcify_FillDiagonal(op, **kwargs):
def filldiagonal(value, diagonal):
i, j = jnp.diag_indices(min(value.shape[-2:]))
return value.at[..., i, j].set(diagonal)
return filldiagonal
@jax_funcify.register(FillDiagonalOffset)
def jax_funcify_FillDiagonalOffset(op, **kwargs):
# def filldiagonaloffset(a, val, offset):
# height, width = a.shape
#
# if offset >= 0:
# start = offset
# num_of_step = min(min(width, height), width - offset)
# else:
# start = -offset * a.shape[1]
# num_of_step = min(min(width, height), height + offset)
#
# step = a.shape[1] + 1
# end = start + step * num_of_step
# a.flat[start:end:step] = val
#
# return a
#
# return filldiagonaloffset
raise NotImplementedError("flatiter not implemented in JAX")
import jax.numpy as jnp
from aesara.link.jax.dispatch import jax_funcify
from aesara.tensor.blas import BatchedDot
from aesara.tensor.math import Dot, MaxAndArgmax
from aesara.tensor.nlinalg import SVD, Det, Eig, Eigh, MatrixInverse, QRFull
@jax_funcify.register(SVD)
def jax_funcify_SVD(op, **kwargs):
full_matrices = op.full_matrices
compute_uv = op.compute_uv
def svd(x, full_matrices=full_matrices, compute_uv=compute_uv):
return jnp.linalg.svd(x, full_matrices=full_matrices, compute_uv=compute_uv)
return svd
@jax_funcify.register(Det)
def jax_funcify_Det(op, **kwargs):
def det(x):
return jnp.linalg.det(x)
return det
@jax_funcify.register(Eig)
def jax_funcify_Eig(op, **kwargs):
def eig(x):
return jnp.linalg.eig(x)
return eig
@jax_funcify.register(Eigh)
def jax_funcify_Eigh(op, **kwargs):
uplo = op.UPLO
def eigh(x, uplo=uplo):
return jnp.linalg.eigh(x, UPLO=uplo)
return eigh
@jax_funcify.register(MatrixInverse)
def jax_funcify_MatrixInverse(op, **kwargs):
def matrix_inverse(x):
return jnp.linalg.inv(x)
return matrix_inverse
@jax_funcify.register(QRFull)
def jax_funcify_QRFull(op, **kwargs):
mode = op.mode
def qr_full(x, mode=mode):
return jnp.linalg.qr(x, mode=mode)
return qr_full
@jax_funcify.register(Dot)
def jax_funcify_Dot(op, **kwargs):
def dot(x, y):
return jnp.dot(x, y)
return dot
@jax_funcify.register(BatchedDot)
def jax_funcify_BatchedDot(op, **kwargs):
def batched_dot(a, b):
if a.shape[0] != b.shape[0]:
raise TypeError("Shapes must match in the 0-th dimension")
if a.ndim == 2 or b.ndim == 2:
return jnp.einsum("n...j,nj...->n...", a, b)
return jnp.einsum("nij,njk->nik", a, b)
return batched_dot
@jax_funcify.register(MaxAndArgmax)
def jax_funcify_MaxAndArgmax(op, **kwargs):
axis = op.axis
def maxandargmax(x, axis=axis):
if axis is None:
axes = tuple(range(x.ndim))
else:
axes = tuple(int(ax) for ax in axis)
max_res = jnp.max(x, axis)
# NumPy does not support multiple axes for argmax; this is a
# work-around
keep_axes = jnp.array(
[i for i in range(x.ndim) if i not in axes], dtype="int64"
)
# Not-reduced axes in front
transposed_x = jnp.transpose(
x, jnp.concatenate((keep_axes, jnp.array(axes, dtype="int64")))
)
kept_shape = transposed_x.shape[: len(keep_axes)]
reduced_shape = transposed_x.shape[len(keep_axes) :]
# NumPy's `prod` returns 1.0 when its argument is empty, so we cast it to int64;
# otherwise `reshape` would complain about a float argument
new_shape = kept_shape + (
jnp.prod(jnp.array(reduced_shape, dtype="int64"), dtype="int64"),
)
reshaped_x = transposed_x.reshape(new_shape)
max_idx_res = jnp.argmax(reshaped_x, axis=-1).astype("int64")
return max_res, max_idx_res
return maxandargmax
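# A small sketch, for intuition only, of the multi-axis argmax workaround above:
# to reduce over the trailing axes of a (2, 3, 4) array, the kept axis 0 stays
# in front, the two reduced axes are flattened into one axis of length 12, and
# a single-axis argmax over that flattened axis gives the flat index of the
# maximum within each (3, 4) slice.
def _argmax_over_trailing_axes_sketch(x):
    kept = x.shape[:1]
    flattened = x.reshape(kept + (-1,))
    return jnp.argmax(flattened, axis=-1)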
import jax
import jax.numpy as jnp
from numpy.random import Generator, RandomState
from numpy.random.bit_generator import _coerce_to_uint32_array
from aesara.link.jax.dispatch.basic import jax_funcify, jax_typify
from aesara.tensor.random.op import RandomVariable
numpy_bit_gens = {"MT19937": 0, "PCG64": 1, "Philox": 2, "SFC64": 3}
@jax_typify.register(RandomState)
def jax_typify_RandomState(state, **kwargs):
state = state.get_state(legacy=False)
state["bit_generator"] = numpy_bit_gens[state["bit_generator"]]
# XXX: Is this a reasonable approach?
state["jax_state"] = state["state"]["key"][0:2]
return state
@jax_typify.register(Generator)
def jax_typify_Generator(rng, **kwargs):
state = rng.__getstate__()
state["bit_generator"] = numpy_bit_gens[state["bit_generator"]]
# XXX: Is this a reasonable approach?
state["jax_state"] = _coerce_to_uint32_array(state["state"]["state"])[0:2]
# The "state" and "inc" values in a NumPy `Generator` are 128 bits, which
# JAX can't handle, so we split these values into arrays of 32-bit integers
# and then combine the first two into a single 64-bit integer.
#
# XXX: Depending on how we expect these values to be used, is this approach
# reasonable?
#
# TODO: We might as well remove these altogether, since this conversion
# should only occur once (e.g. when the graph is converted/JAX-compiled),
# and, from then on, we use the custom "jax_state" value.
inc_32 = _coerce_to_uint32_array(state["state"]["inc"])
state_32 = _coerce_to_uint32_array(state["state"]["state"])
state["state"]["inc"] = inc_32[0] << 32 | inc_32[1]
state["state"]["state"] = state_32[0] << 32 | state_32[1]
return state
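# A hedged illustration of the recombination above: `_coerce_to_uint32_array`
# splits a wide integer into uint32 words (least-significant word first), and
# `words[0] << 32 | words[1]` packs the two lowest words into one 64-bit value.
def _pack_lowest_words_sketch(value):
    words = _coerce_to_uint32_array(value)
    return int(words[0]) << 32 | int(words[1])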
@jax_funcify.register(RandomVariable)
def jax_funcify_RandomVariable(op, node, **kwargs):
name = op.name
# TODO: Make sure there's a 1-to-1 correspondence with names
if not hasattr(jax.random, name):
raise NotImplementedError(
f"No JAX conversion for the given distribution: {name}"
)
dtype = node.outputs[1].dtype
def random_variable(rng, size, dtype_num, *args):
if not op.inplace:
rng = rng.copy()
prng = rng["jax_state"]
data = getattr(jax.random, name)(key=prng, shape=size)
smpl_value = jnp.array(data, dtype=dtype)
rng["jax_state"] = jax.random.split(prng, num=1)[0]
return (rng, smpl_value)
return random_variable
import functools
import jax
import jax.numpy as jnp
from aesara.link.jax.dispatch.basic import jax_funcify
from aesara.scalar import Softplus
from aesara.scalar.basic import Cast, Clip, Composite, Identity, ScalarOp, Second
from aesara.scalar.math import Erf, Erfc, Erfinv, Log1mexp, Psi
@jax_funcify.register(ScalarOp)
def jax_funcify_ScalarOp(op, **kwargs):
func_name = op.nfunc_spec[0]
if "." in func_name:
jnp_func = functools.reduce(getattr, [jax] + func_name.split("."))
else:
jnp_func = getattr(jnp, func_name)
if hasattr(op, "nfunc_variadic"):
# These are special cases that handle invalid arities due to the broken
# Aesara `Op` type contract (e.g. binary `Op`s that also function as
# their own variadic counterparts--even when those counterparts already
# exist as independent `Op`s).
jax_variadic_func = getattr(jnp, op.nfunc_variadic)
def elemwise(*args):
if len(args) > op.nfunc_spec[1]:
return jax_variadic_func(
jnp.stack(jnp.broadcast_arrays(*args), axis=0), axis=0
)
else:
return jnp_func(*args)
return elemwise
else:
return jnp_func
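# A hedged illustration of the variadic branch above, assuming an `Add`-like op
# with `nfunc_spec = ("add", 2, 1)` and `nfunc_variadic = "sum"`: when called
# with more than two arguments, the inputs are broadcast, stacked along a new
# leading axis, and reduced with `jnp.sum` over that axis, which matches
# chained binary additions.
def _variadic_add_sketch(a, b, c):
    return jnp.sum(jnp.stack(jnp.broadcast_arrays(a, b, c), axis=0), axis=0)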
@jax_funcify.register(Cast)
def jax_funcify_Cast(op, **kwargs):
def cast(x):
return jnp.array(x).astype(op.o_type.dtype)
return cast
@jax_funcify.register(Identity)
def jax_funcify_Identity(op, **kwargs):
def identity(x):
return x
return identity
@jax_funcify.register(Clip)
def jax_funcify_Clip(op, **kwargs):
def clip(x, min, max):
return jnp.where(x < min, min, jnp.where(x > max, max, x))
return clip
@jax_funcify.register(Composite)
def jax_funcify_Composite(op, vectorize=True, **kwargs):
jax_impl = jax_funcify(op.fgraph)
def composite(*args):
return jax_impl(*args)[0]
return jnp.vectorize(composite)
@jax_funcify.register(Second)
def jax_funcify_Second(op, **kwargs):
def second(x, y):
return jnp.broadcast_to(y, x.shape)
return second
@jax_funcify.register(Erf)
def jax_funcify_Erf(op, node, **kwargs):
def erf(x):
return jax.scipy.special.erf(x)
return erf
@jax_funcify.register(Erfc)
def jax_funcify_Erfc(op, **kwargs):
def erfc(x):
return jax.scipy.special.erfc(x)
return erfc
@jax_funcify.register(Erfinv)
def jax_funcify_Erfinv(op, **kwargs):
def erfinv(x):
return jax.scipy.special.erfinv(x)
return erfinv
@jax_funcify.register(Log1mexp)
def jax_funcify_Log1mexp(op, node, **kwargs):
def log1mexp(x):
return jnp.where(
x < jnp.log(0.5), jnp.log1p(-jnp.exp(x)), jnp.log(-jnp.expm1(x))
)
return log1mexp
@jax_funcify.register(Psi)
def jax_funcify_Psi(op, node, **kwargs):
def psi(x):
return jax.scipy.special.digamma(x)
return psi
@jax_funcify.register(Softplus)
def jax_funcify_Softplus(op, **kwargs):
def softplus(x):
# This expression is numerically equivalent to the Aesara one;
# it just contains one fewer "speed" optimization than the Aesara counterpart.
return jnp.where(
x < -37.0, jnp.exp(x), jnp.where(x > 33.3, x, jnp.log1p(jnp.exp(x)))
)
return softplus
import jax
import jax.numpy as jnp
from aesara.graph.fg import FunctionGraph
from aesara.link.jax.dispatch.basic import jax_funcify
from aesara.scan.op import Scan
from aesara.scan.utils import ScanArgs
@jax_funcify.register(Scan)
def jax_funcify_Scan(op, **kwargs):
inner_fg = FunctionGraph(op.inputs, op.outputs)
jax_at_inner_func = jax_funcify(inner_fg, **kwargs)
def scan(*outer_inputs):
scan_args = ScanArgs(
list(outer_inputs), [None] * op.info.n_outs, op.inputs, op.outputs, op.info
)
# `outer_inputs` is a list with the following composite form:
# [n_steps]
# + outer_in_seqs
# + outer_in_mit_mot
# + outer_in_mit_sot
# + outer_in_sit_sot
# + outer_in_shared
# + outer_in_nit_sot
# + outer_in_non_seqs
n_steps = scan_args.n_steps
seqs = scan_args.outer_in_seqs
# TODO: mit_mots
mit_mot_in_slices = []
mit_sot_in_slices = []
for tap, seq in zip(scan_args.mit_sot_in_slices, scan_args.outer_in_mit_sot):
neg_taps = [abs(t) for t in tap if t < 0]
pos_taps = [abs(t) for t in tap if t > 0]
max_neg = max(neg_taps) if neg_taps else 0
max_pos = max(pos_taps) if pos_taps else 0
init_slice = seq[: max_neg + max_pos]
mit_sot_in_slices.append(init_slice)
sit_sot_in_slices = [seq[0] for seq in scan_args.outer_in_sit_sot]
init_carry = (
mit_mot_in_slices,
mit_sot_in_slices,
sit_sot_in_slices,
scan_args.outer_in_shared,
scan_args.outer_in_non_seqs,
)
def jax_args_to_inner_scan(op, carry, x):
# `carry` contains all inner-output taps, non_seqs, and shared
# terms
(
inner_in_mit_mot,
inner_in_mit_sot,
inner_in_sit_sot,
inner_in_shared,
inner_in_non_seqs,
) = carry
# `x` contains the in_seqs
inner_in_seqs = x
# `inner_scan_inputs` is a list with the following composite form:
# inner_in_seqs
# + sum(inner_in_mit_mot, [])
# + sum(inner_in_mit_sot, [])
# + inner_in_sit_sot
# + inner_in_shared
# + inner_in_non_seqs
inner_in_mit_sot_flatten = []
for array, index in zip(inner_in_mit_sot, scan_args.mit_sot_in_slices):
inner_in_mit_sot_flatten.extend(array[jnp.array(index)])
inner_scan_inputs = sum(
[
inner_in_seqs,
inner_in_mit_mot,
inner_in_mit_sot_flatten,
inner_in_sit_sot,
inner_in_shared,
inner_in_non_seqs,
],
[],
)
return inner_scan_inputs
def inner_scan_outs_to_jax_outs(
op,
old_carry,
inner_scan_outs,
):
(
inner_in_mit_mot,
inner_in_mit_sot,
inner_in_sit_sot,
inner_in_shared,
inner_in_non_seqs,
) = old_carry
def update_mit_sot(mit_sot, new_val):
return jnp.concatenate([mit_sot[1:], new_val[None, ...]], axis=0)
inner_out_mit_sot = [
update_mit_sot(mit_sot, new_val)
for mit_sot, new_val in zip(inner_in_mit_sot, inner_scan_outs)
]
# This should contain all inner-output taps, non_seqs, and shared
# terms
if not inner_in_sit_sot:
inner_out_sit_sot = []
else:
inner_out_sit_sot = inner_scan_outs
new_carry = (
inner_in_mit_mot,
inner_out_mit_sot,
inner_out_sit_sot,
inner_in_shared,
inner_in_non_seqs,
)
return new_carry
def jax_inner_func(carry, x):
inner_args = jax_args_to_inner_scan(op, carry, x)
inner_scan_outs = list(jax_at_inner_func(*inner_args))
new_carry = inner_scan_outs_to_jax_outs(op, carry, inner_scan_outs)
return new_carry, inner_scan_outs
_, scan_out = jax.lax.scan(jax_inner_func, init_carry, seqs, length=n_steps)
# We need to prepend the initial values so that the JAX output will
# match the raw `Scan` `Op` output and, thus, work with a downstream
# `Subtensor` `Op` introduced by the `scan` helper function.
def append_scan_out(scan_in_part, scan_out_part):
return jnp.concatenate([scan_in_part[:-n_steps], scan_out_part], axis=0)
if scan_args.outer_in_mit_sot:
scan_out_final = [
append_scan_out(init, out)
for init, out in zip(scan_args.outer_in_mit_sot, scan_out)
]
elif scan_args.outer_in_sit_sot:
scan_out_final = [
append_scan_out(init, out)
for init, out in zip(scan_args.outer_in_sit_sot, scan_out)
]
if len(scan_out_final) == 1:
scan_out_final = scan_out_final[0]
return scan_out_final
return scan
import jax.numpy as jnp
from aesara.graph import Constant
from aesara.link.jax.dispatch.basic import jax_funcify
from aesara.tensor.shape import Reshape, Shape, Shape_i, SpecifyShape, Unbroadcast
@jax_funcify.register(Reshape)
def jax_funcify_Reshape(op, node, **kwargs):
# JAX reshape only works with constant inputs, otherwise JIT fails
shape = node.inputs[1]
if isinstance(shape, Constant):
constant_shape = shape.data
def reshape(x, shape):
return jnp.reshape(x, constant_shape)
else:
def reshape(x, shape):
return jnp.reshape(x, shape)
return reshape
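# A brief sketch of the constant-shape requirement above: under `jax.jit` a
# traced `shape` input is abstract, so only a shape known at trace time (e.g. a
# Python tuple constant) is guaranteed to work with `jnp.reshape`.
def _reshape_constant_sketch(x):
    # The target shape is a Python constant here, so this is safe to `jit`.
    return jnp.reshape(x, (2, -1))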
@jax_funcify.register(Shape)
def jax_funcify_Shape(op, **kwargs):
def shape(x):
return jnp.shape(x)
return shape
@jax_funcify.register(Shape_i)
def jax_funcify_Shape_i(op, **kwargs):
i = op.i
def shape_i(x):
return jnp.shape(x)[i]
return shape_i
@jax_funcify.register(SpecifyShape)
def jax_funcify_SpecifyShape(op, **kwargs):
def specifyshape(x, *shape):
assert x.ndim == len(shape)
assert jnp.all(x.shape == tuple(shape)), (
"got shape",
x.shape,
"expected",
shape,
)
return x
return specifyshape
@jax_funcify.register(Unbroadcast)
def jax_funcify_Unbroadcast(op, **kwargs):
def unbroadcast(x):
return x
return unbroadcast
import jax
from aesara.link.jax.dispatch.basic import jax_funcify
from aesara.tensor.slinalg import Cholesky, Solve, SolveTriangular
@jax_funcify.register(Cholesky)
def jax_funcify_Cholesky(op, **kwargs):
lower = op.lower
def cholesky(a, lower=lower):
return jax.scipy.linalg.cholesky(a, lower=lower).astype(a.dtype)
return cholesky
@jax_funcify.register(Solve)
def jax_funcify_Solve(op, **kwargs):
if op.assume_a != "gen" and op.lower:
lower = True
else:
lower = False
def solve(a, b, lower=lower):
return jax.scipy.linalg.solve(a, b, lower=lower)
return solve
@jax_funcify.register(SolveTriangular)
def jax_funcify_SolveTriangular(op, **kwargs):
lower = op.lower
trans = op.trans
unit_diagonal = op.unit_diagonal
check_finite = op.check_finite
def solve_triangular(A, b):
return jax.scipy.linalg.solve_triangular(
A,
b,
lower=lower,
trans=trans,
unit_diagonal=unit_diagonal,
check_finite=check_finite,
)
return solve_triangular
import jax
from aesara.link.jax.dispatch.basic import jax_funcify
from aesara.tensor.subtensor import (
AdvancedIncSubtensor,
AdvancedIncSubtensor1,
AdvancedSubtensor,
AdvancedSubtensor1,
IncSubtensor,
Subtensor,
indices_from_subtensor,
)
from aesara.tensor.type_other import MakeSlice
@jax_funcify.register(Subtensor)
@jax_funcify.register(AdvancedSubtensor)
@jax_funcify.register(AdvancedSubtensor1)
def jax_funcify_Subtensor(op, **kwargs):
idx_list = getattr(op, "idx_list", None)
def subtensor(x, *ilists):
indices = indices_from_subtensor(ilists, idx_list)
if len(indices) == 1:
indices = indices[0]
return x.__getitem__(indices)
return subtensor
@jax_funcify.register(IncSubtensor)
@jax_funcify.register(AdvancedIncSubtensor1)
def jax_funcify_IncSubtensor(op, **kwargs):
idx_list = getattr(op, "idx_list", None)
if getattr(op, "set_instead_of_inc", False):
jax_fn = getattr(jax.ops, "index_update", None)
if jax_fn is None:
def jax_fn(x, indices, y):
return x.at[indices].set(y)
else:
jax_fn = getattr(jax.ops, "index_add", None)
if jax_fn is None:
def jax_fn(x, indices, y):
return x.at[indices].add(y)
def incsubtensor(x, y, *ilist, jax_fn=jax_fn, idx_list=idx_list):
indices = indices_from_subtensor(ilist, idx_list)
if len(indices) == 1:
indices = indices[0]
return jax_fn(x, indices, y)
return incsubtensor
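# A minimal sketch of the functional index-update semantics used above: JAX
# arrays are immutable, so `x.at[indices].set(y)` and `x.at[indices].add(y)`
# return new arrays instead of modifying `x` in place; the older
# `jax.ops.index_update`/`index_add` helpers, when available, behave the same way.
def _set_and_add_sketch(x, y):
    x_set = x.at[0].set(y)  # functional counterpart of ``x[0] = y``
    x_add = x.at[0].add(y)  # functional counterpart of ``x[0] += y``
    return x_set, x_add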
@jax_funcify.register(AdvancedIncSubtensor)
def jax_funcify_AdvancedIncSubtensor(op, **kwargs):
if getattr(op, "set_instead_of_inc", False):
jax_fn = getattr(jax.ops, "index_update", None)
if jax_fn is None:
def jax_fn(x, indices, y):
return x.at[indices].set(y)
else:
jax_fn = getattr(jax.ops, "index_add", None)
if jax_fn is None:
def jax_fn(x, indices, y):
return x.at[indices].add(y)
def advancedincsubtensor(x, y, *ilist, jax_fn=jax_fn):
return jax_fn(x, ilist, y)
return advancedincsubtensor
@jax_funcify.register(MakeSlice)
def jax_funcify_MakeSlice(op, **kwargs):
def makeslice(*x):
return slice(*x)
return makeslice
import jax.numpy as jnp
from aesara.link.jax.dispatch.basic import jax_funcify
from aesara.tensor.basic import (
Alloc,
AllocDiag,
AllocEmpty,
ARange,
ExtractDiag,
Eye,
Join,
MakeVector,
ScalarFromTensor,
TensorFromScalar,
)
@jax_funcify.register(AllocDiag)
def jax_funcify_AllocDiag(op, **kwargs):
offset = op.offset
def allocdiag(v, offset=offset):
return jnp.diag(v, k=offset)
return allocdiag
@jax_funcify.register(AllocEmpty)
def jax_funcify_AllocEmpty(op, **kwargs):
def allocempty(*shape):
return jnp.empty(shape, dtype=op.dtype)
return allocempty
@jax_funcify.register(Alloc)
def jax_funcify_Alloc(op, **kwargs):
def alloc(x, *shape):
res = jnp.broadcast_to(x, shape)
return res
return alloc
@jax_funcify.register(ARange)
def jax_funcify_ARange(op, **kwargs):
# XXX: This currently requires concrete arguments.
def arange(start, stop, step):
return jnp.arange(start, stop, step, dtype=op.dtype)
return arange
@jax_funcify.register(Join)
def jax_funcify_Join(op, **kwargs):
def join(axis, *tensors):
# tensors could also be tuples, in which case they don't have an `ndim`
tensors = [jnp.asarray(tensor) for tensor in tensors]
view = op.view
if (view != -1) and all(
tensor.shape[axis] == 0 for tensor in tensors[0:view] + tensors[view + 1 :]
):
return tensors[view]
else:
return jnp.concatenate(tensors, axis=axis)
return join
@jax_funcify.register(ExtractDiag)
def jax_funcify_ExtractDiag(op, **kwargs):
offset = op.offset
axis1 = op.axis1
axis2 = op.axis2
def extract_diag(x, offset=offset, axis1=axis1, axis2=axis2):
return jnp.diagonal(x, offset=offset, axis1=axis1, axis2=axis2)
return extract_diag
@jax_funcify.register(Eye)
def jax_funcify_Eye(op, **kwargs):
dtype = op.dtype
def eye(N, M, k):
return jnp.eye(N, M, k, dtype=dtype)
return eye
@jax_funcify.register(MakeVector)
def jax_funcify_MakeVector(op, **kwargs):
def makevector(*x):
return jnp.array(x, dtype=op.dtype)
return makevector
@jax_funcify.register(TensorFromScalar)
def jax_funcify_TensorFromScalar(op, **kwargs):
def tensor_from_scalar(x):
return jnp.array(x)
return tensor_from_scalar
@jax_funcify.register(ScalarFromTensor)
def jax_funcify_ScalarFromTensor(op, **kwargs):
def scalar_from_tensor(x):
return jnp.array(x).flatten()[0]
return scalar_from_tensor
import jax
import numpy as np
import pytest
from jax._src.errors import NonConcreteBooleanIndexError
from packaging.version import parse as version_parse
import aesara.tensor as at
from aesara.configdefaults import config
from aesara.graph.fg import FunctionGraph
from aesara.tensor import subtensor as at_subtensor
from tests.link.jax.test_basic import compare_jax_and_py
def test_jax_Subtensors():
# Basic indices
x_at = at.as_tensor(np.arange(3 * 4 * 5).reshape((3, 4, 5)))
out_at = x_at[1, 2, 0]
assert isinstance(out_at.owner.op, at_subtensor.Subtensor)
out_fg = FunctionGraph([], [out_at])
compare_jax_and_py(out_fg, [])
out_at = x_at[1:2, 1, :]
assert isinstance(out_at.owner.op, at_subtensor.Subtensor)
out_fg = FunctionGraph([], [out_at])
compare_jax_and_py(out_fg, [])
# Advanced indexing
out_at = at_subtensor.advanced_subtensor1(x_at, [1, 2])
assert isinstance(out_at.owner.op, at_subtensor.AdvancedSubtensor1)
out_fg = FunctionGraph([], [out_at])
compare_jax_and_py(out_fg, [])
out_at = x_at[[1, 2], [2, 3]]
assert isinstance(out_at.owner.op, at_subtensor.AdvancedSubtensor)
out_fg = FunctionGraph([], [out_at])
compare_jax_and_py(out_fg, [])
# Advanced and basic indexing
out_at = x_at[[1, 2], :]
assert isinstance(out_at.owner.op, at_subtensor.AdvancedSubtensor)
out_fg = FunctionGraph([], [out_at])
compare_jax_and_py(out_fg, [])
out_at = x_at[[1, 2], :, [3, 4]]
assert isinstance(out_at.owner.op, at_subtensor.AdvancedSubtensor)
out_fg = FunctionGraph([], [out_at])
compare_jax_and_py(out_fg, [])
@pytest.mark.xfail(
version_parse(jax.__version__) >= version_parse("0.2.12"),
reason="Omnistaging cannot be disabled",
)
def test_jax_Subtensors_omni():
x_at = at.arange(3 * 4 * 5).reshape((3, 4, 5))
# Boolean indices
out_at = x_at[x_at < 0]
assert isinstance(out_at.owner.op, at_subtensor.AdvancedSubtensor)
out_fg = FunctionGraph([], [out_at])
compare_jax_and_py(out_fg, [])
def test_jax_IncSubtensor():
rng = np.random.default_rng(213234)
x_np = rng.uniform(-1, 1, size=(3, 4, 5)).astype(config.floatX)
x_at = at.constant(np.arange(3 * 4 * 5).reshape((3, 4, 5)).astype(config.floatX))
# "Set" basic indices
st_at = at.as_tensor_variable(np.array(-10.0, dtype=config.floatX))
out_at = at_subtensor.set_subtensor(x_at[1, 2, 3], st_at)
assert isinstance(out_at.owner.op, at_subtensor.IncSubtensor)
out_fg = FunctionGraph([], [out_at])
compare_jax_and_py(out_fg, [])
st_at = at.as_tensor_variable(np.r_[-1.0, 0.0].astype(config.floatX))
out_at = at_subtensor.set_subtensor(x_at[:2, 0, 0], st_at)
assert isinstance(out_at.owner.op, at_subtensor.IncSubtensor)
out_fg = FunctionGraph([], [out_at])
compare_jax_and_py(out_fg, [])
out_at = at_subtensor.set_subtensor(x_at[0, 1:3, 0], st_at)
assert isinstance(out_at.owner.op, at_subtensor.IncSubtensor)
out_fg = FunctionGraph([], [out_at])
compare_jax_and_py(out_fg, [])
# "Set" advanced indices
st_at = at.as_tensor_variable(
rng.uniform(-1, 1, size=(2, 4, 5)).astype(config.floatX)
)
out_at = at_subtensor.set_subtensor(x_at[np.r_[0, 2]], st_at)
assert isinstance(out_at.owner.op, at_subtensor.AdvancedIncSubtensor)
out_fg = FunctionGraph([], [out_at])
compare_jax_and_py(out_fg, [])
st_at = at.as_tensor_variable(np.r_[-1.0, 0.0].astype(config.floatX))
out_at = at_subtensor.set_subtensor(x_at[[0, 2], 0, 0], st_at)
assert isinstance(out_at.owner.op, at_subtensor.AdvancedIncSubtensor)
out_fg = FunctionGraph([], [out_at])
compare_jax_and_py(out_fg, [])
# "Set" boolean indices
mask_at = at.constant(x_np > 0)
out_at = at_subtensor.set_subtensor(x_at[mask_at], 0.0)
assert isinstance(out_at.owner.op, at_subtensor.AdvancedIncSubtensor)
out_fg = FunctionGraph([], [out_at])
compare_jax_and_py(out_fg, [])
# "Increment" basic indices
st_at = at.as_tensor_variable(np.array(-10.0, dtype=config.floatX))
out_at = at_subtensor.inc_subtensor(x_at[1, 2, 3], st_at)
assert isinstance(out_at.owner.op, at_subtensor.IncSubtensor)
out_fg = FunctionGraph([], [out_at])
compare_jax_and_py(out_fg, [])
st_at = at.as_tensor_variable(np.r_[-1.0, 0.0].astype(config.floatX))
out_at = at_subtensor.inc_subtensor(x_at[:2, 0, 0], st_at)
assert isinstance(out_at.owner.op, at_subtensor.IncSubtensor)
out_fg = FunctionGraph([], [out_at])
compare_jax_and_py(out_fg, [])
out_at = at_subtensor.set_subtensor(x_at[0, 1:3, 0], st_at)
assert isinstance(out_at.owner.op, at_subtensor.IncSubtensor)
out_fg = FunctionGraph([], [out_at])
compare_jax_and_py(out_fg, [])
# "Increment" advanced indices
st_at = at.as_tensor_variable(
rng.uniform(-1, 1, size=(2, 4, 5)).astype(config.floatX)
)
out_at = at_subtensor.inc_subtensor(x_at[np.r_[0, 2]], st_at)
assert isinstance(out_at.owner.op, at_subtensor.AdvancedIncSubtensor)
out_fg = FunctionGraph([], [out_at])
compare_jax_and_py(out_fg, [])
st_at = at.as_tensor_variable(np.r_[-1.0, 0.0].astype(config.floatX))
out_at = at_subtensor.inc_subtensor(x_at[[0, 2], 0, 0], st_at)
assert isinstance(out_at.owner.op, at_subtensor.AdvancedIncSubtensor)
out_fg = FunctionGraph([], [out_at])
compare_jax_and_py(out_fg, [])
# "Increment" boolean indices
mask_at = at.constant(x_np > 0)
out_at = at_subtensor.set_subtensor(x_at[mask_at], 1.0)
assert isinstance(out_at.owner.op, at_subtensor.AdvancedIncSubtensor)
out_fg = FunctionGraph([], [out_at])
compare_jax_and_py(out_fg, [])
def test_jax_IncSubtensors_unsupported():
rng = np.random.default_rng(213234)
x_np = rng.uniform(-1, 1, size=(3, 4, 5)).astype(config.floatX)
x_at = at.constant(np.arange(3 * 4 * 5).reshape((3, 4, 5)).astype(config.floatX))
mask_at = at.as_tensor(x_np) > 0
out_at = at_subtensor.set_subtensor(x_at[mask_at], 0.0)
assert isinstance(out_at.owner.op, at_subtensor.AdvancedIncSubtensor)
out_fg = FunctionGraph([], [out_at])
with pytest.raises(
NonConcreteBooleanIndexError, match="Array boolean indices must be concrete"
):
compare_jax_and_py(out_fg, [])
mask_at = at.as_tensor_variable(x_np) > 0
out_at = at_subtensor.set_subtensor(x_at[mask_at], 1.0)
assert isinstance(out_at.owner.op, at_subtensor.AdvancedIncSubtensor)
out_fg = FunctionGraph([], [out_at])
with pytest.raises(
NonConcreteBooleanIndexError, match="Array boolean indices must be concrete"
):
compare_jax_and_py(out_fg, [])
st_at = at.as_tensor_variable(x_np[[0, 2], 0, :3])
out_at = at_subtensor.set_subtensor(x_at[[0, 2], 0, :3], st_at)
assert isinstance(out_at.owner.op, at_subtensor.AdvancedIncSubtensor)
out_fg = FunctionGraph([], [out_at])
with pytest.raises(IndexError, match="Array slice indices must have static"):
compare_jax_and_py(out_fg, [])
st_at = at.as_tensor_variable(x_np[[0, 2], 0, :3])
out_at = at_subtensor.inc_subtensor(x_at[[0, 2], 0, :3], st_at)
assert isinstance(out_at.owner.op, at_subtensor.AdvancedIncSubtensor)
out_fg = FunctionGraph([], [out_at])
with pytest.raises(IndexError, match="Array slice indices must have static"):
compare_jax_and_py(out_fg, [])
...
@@ -3440,7 +3440,7 @@ def profile_printer(
     )
-@op_debug_information.register(Scan)  # type: ignore
+@op_debug_information.register(Scan)
 def _op_debug_information_Scan(op, node):
     from typing import Sequence
...
from functools import partial
from typing import Callable, Iterable, Optional
import numpy as np
import pytest
from aesara.compile.function import function
from aesara.compile.mode import Mode
from aesara.compile.sharedvalue import SharedVariable, shared
from aesara.configdefaults import config
from aesara.graph.basic import Apply
from aesara.graph.fg import FunctionGraph
from aesara.graph.op import Op, get_test_value
from aesara.graph.rewriting.db import RewriteDatabaseQuery
from aesara.ifelse import ifelse
from aesara.link.jax import JAXLinker
from aesara.raise_op import assert_op
from aesara.tensor.type import dscalar, scalar, vector
@pytest.fixture(scope="module", autouse=True)
def set_aesara_flags():
with config.change_flags(cxx="", compute_test_value="ignore"):
yield
jax = pytest.importorskip("jax")
opts = RewriteDatabaseQuery(include=[None], exclude=["cxx_only", "BlasOpt"])
jax_mode = Mode(JAXLinker(), opts)
py_mode = Mode("py", opts)
def compare_jax_and_py(
fgraph: FunctionGraph,
test_inputs: Iterable,
assert_fn: Optional[Callable] = None,
must_be_device_array: bool = True,
):
"""Function to compare python graph output and jax compiled output for testing equality
In the tests below computational graphs are defined in Aesara. These graphs are then passed to
this function which then compiles the graphs in both jax and python, runs the calculation
in both and checks if the results are the same
Parameters
----------
fgraph: FunctionGraph
Aesara function Graph object
test_inputs: iter
Numerical inputs for testing the function graph
assert_fn: func, opt
Assert function used to check for equality between python and jax. If not
provided uses np.testing.assert_allclose
must_be_device_array: Bool
Checks for instance of jax.interpreters.xla.DeviceArray. For testing purposes
if this device array is found it indicates if the result was computed by jax
Returns
-------
jax_res
"""
if assert_fn is None:
assert_fn = partial(np.testing.assert_allclose, rtol=1e-4)
fn_inputs = [i for i in fgraph.inputs if not isinstance(i, SharedVariable)]
aesara_jax_fn = function(fn_inputs, fgraph.outputs, mode=jax_mode)
jax_res = aesara_jax_fn(*test_inputs)
if must_be_device_array:
if isinstance(jax_res, list):
assert all(
isinstance(res, jax.interpreters.xla.DeviceArray) for res in jax_res
)
else:
assert isinstance(jax_res, jax.interpreters.xla.DeviceArray)
aesara_py_fn = function(fn_inputs, fgraph.outputs, mode=py_mode)
py_res = aesara_py_fn(*test_inputs)
if len(fgraph.outputs) > 1:
for j, p in zip(jax_res, py_res):
assert_fn(j, p)
else:
assert_fn(jax_res, py_res)
return jax_res
def test_jax_FunctionGraph_names():
import inspect
from aesara.link.jax.dispatch import jax_funcify
x = scalar("1x")
y = scalar("_")
z = scalar()
q = scalar("def")
out_fg = FunctionGraph([x, y, z, q], [x, y, z, q], clone=False)
out_jx = jax_funcify(out_fg)
sig = inspect.signature(out_jx)
assert (x.auto_name, "_", z.auto_name, q.auto_name) == tuple(sig.parameters.keys())
assert (1, 2, 3, 4) == out_jx(1, 2, 3, 4)
def test_jax_FunctionGraph_once():
"""Make sure that an output is only computed once when it's referenced multiple times."""
from aesara.link.jax.dispatch import jax_funcify
x = vector("x")
y = vector("y")
class TestOp(Op):
def __init__(self):
self.called = 0
def make_node(self, *args):
return Apply(self, list(args), [x.type() for x in args])
def perform(self, inputs, outputs):
for i, inp in enumerate(inputs):
outputs[i][0] = inp[0]
@jax_funcify.register(TestOp)
def jax_funcify_TestOp(op, **kwargs):
def func(*args, op=op):
op.called += 1
return list(args)
return func
op1 = TestOp()
op2 = TestOp()
q, r = op1(x, y)
outs = op2(q + r, q + r)
out_fg = FunctionGraph([x, y], outs, clone=False)
assert len(out_fg.outputs) == 2
out_jx = jax_funcify(out_fg)
x_val = np.r_[1, 2].astype(config.floatX)
y_val = np.r_[2, 3].astype(config.floatX)
res = out_jx(x_val, y_val)
assert len(res) == 2
assert op1.called == 1
assert op2.called == 1
res = out_jx(x_val, y_val)
assert len(res) == 2
assert op1.called == 2
assert op2.called == 2
def test_shared():
a = shared(np.array([1, 2, 3], dtype=config.floatX))
aesara_jax_fn = function([], a, mode="JAX")
jax_res = aesara_jax_fn()
assert isinstance(jax_res, jax.interpreters.xla.DeviceArray)
np.testing.assert_allclose(jax_res, a.get_value())
aesara_jax_fn = function([], a * 2, mode="JAX")
jax_res = aesara_jax_fn()
assert isinstance(jax_res, jax.interpreters.xla.DeviceArray)
np.testing.assert_allclose(jax_res, a.get_value() * 2)
# Change the shared value and make sure that the JAX-compiled
# function's output also changes.
new_a_value = np.array([3, 4, 5], dtype=config.floatX)
a.set_value(new_a_value)
jax_res = aesara_jax_fn()
assert isinstance(jax_res, jax.interpreters.xla.DeviceArray)
np.testing.assert_allclose(jax_res, new_a_value * 2)
def test_jax_ifelse():
true_vals = np.r_[1, 2, 3]
false_vals = np.r_[-1, -2, -3]
x = ifelse(np.array(True), true_vals, false_vals)
x_fg = FunctionGraph([], [x])
compare_jax_and_py(x_fg, [])
a = dscalar("a")
a.tag.test_value = np.array(0.2, dtype=config.floatX)
x = ifelse(a < 0.5, true_vals, false_vals)
x_fg = FunctionGraph([a], [x]) # I.e. False
compare_jax_and_py(x_fg, [get_test_value(i) for i in x_fg.inputs])
def test_jax_checkandraise():
p = scalar()
p.tag.test_value = 0
res = assert_op(p, p < 1.0)
res_fg = FunctionGraph([p], [res])
with pytest.raises(NotImplementedError):
compare_jax_and_py(res_fg, [1.0])
import numpy as np
import pytest
from aesara.configdefaults import config
from aesara.graph.fg import FunctionGraph
from aesara.graph.op import get_test_value
from aesara.tensor import elemwise as at_elemwise
from aesara.tensor import nnet as at_nnet
from aesara.tensor.math import all as at_all
from aesara.tensor.math import prod
from aesara.tensor.math import sum as at_sum
from aesara.tensor.nnet.basic import SoftmaxGrad
from aesara.tensor.type import matrix, tensor, vector
from tests.link.jax.test_basic import compare_jax_and_py
def test_jax_Dimshuffle():
a_at = matrix("a")
x = a_at.T
x_fg = FunctionGraph([a_at], [x])
compare_jax_and_py(x_fg, [np.c_[[1.0, 2.0], [3.0, 4.0]].astype(config.floatX)])
x = a_at.dimshuffle([0, 1, "x"])
x_fg = FunctionGraph([a_at], [x])
compare_jax_and_py(x_fg, [np.c_[[1.0, 2.0], [3.0, 4.0]].astype(config.floatX)])
a_at = tensor(dtype=config.floatX, shape=[False, True])
x = a_at.dimshuffle((0,))
x_fg = FunctionGraph([a_at], [x])
compare_jax_and_py(x_fg, [np.c_[[1.0, 2.0, 3.0, 4.0]].astype(config.floatX)])
a_at = tensor(dtype=config.floatX, shape=[False, True])
x = at_elemwise.DimShuffle([False, True], (0,))(a_at)
x_fg = FunctionGraph([a_at], [x])
compare_jax_and_py(x_fg, [np.c_[[1.0, 2.0, 3.0, 4.0]].astype(config.floatX)])
def test_jax_CAReduce():
a_at = vector("a")
a_at.tag.test_value = np.r_[1, 2, 3].astype(config.floatX)
x = at_sum(a_at, axis=None)
x_fg = FunctionGraph([a_at], [x])
compare_jax_and_py(x_fg, [np.r_[1, 2, 3].astype(config.floatX)])
a_at = matrix("a")
a_at.tag.test_value = np.c_[[1, 2, 3], [1, 2, 3]].astype(config.floatX)
x = at_sum(a_at, axis=0)
x_fg = FunctionGraph([a_at], [x])
compare_jax_and_py(x_fg, [np.c_[[1, 2, 3], [1, 2, 3]].astype(config.floatX)])
x = at_sum(a_at, axis=1)
x_fg = FunctionGraph([a_at], [x])
compare_jax_and_py(x_fg, [np.c_[[1, 2, 3], [1, 2, 3]].astype(config.floatX)])
a_at = matrix("a")
a_at.tag.test_value = np.c_[[1, 2, 3], [1, 2, 3]].astype(config.floatX)
x = prod(a_at, axis=0)
x_fg = FunctionGraph([a_at], [x])
compare_jax_and_py(x_fg, [np.c_[[1, 2, 3], [1, 2, 3]].astype(config.floatX)])
x = at_all(a_at)
x_fg = FunctionGraph([a_at], [x])
compare_jax_and_py(x_fg, [np.c_[[1, 2, 3], [1, 2, 3]].astype(config.floatX)])
@pytest.mark.parametrize("axis", [None, 0, 1])
def test_softmax(axis):
x = matrix("x")
x.tag.test_value = np.arange(6, dtype=config.floatX).reshape(2, 3)
out = at_nnet.softmax(x, axis=axis)
fgraph = FunctionGraph([x], [out])
compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])
@pytest.mark.parametrize("axis", [None, 0, 1])
def test_logsoftmax(axis):
x = matrix("x")
x.tag.test_value = np.arange(6, dtype=config.floatX).reshape(2, 3)
out = at_nnet.logsoftmax(x, axis=axis)
fgraph = FunctionGraph([x], [out])
compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])
@pytest.mark.parametrize("axis", [None, 0, 1])
def test_softmax_grad(axis):
dy = matrix("dy")
dy.tag.test_value = np.array([[1, 1, 1], [0, 0, 0]], dtype=config.floatX)
sm = matrix("sm")
sm.tag.test_value = np.arange(6, dtype=config.floatX).reshape(2, 3)
out = SoftmaxGrad(axis=axis)(dy, sm)
fgraph = FunctionGraph([dy, sm], [out])
compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])
import numpy as np
import pytest
from packaging.version import parse as version_parse
import aesara.tensor.basic as at
from aesara.configdefaults import config
from aesara.graph.fg import FunctionGraph
from aesara.graph.op import get_test_value
from aesara.tensor import extra_ops as at_extra_ops
from aesara.tensor.type import matrix, vector
from tests.link.jax.test_basic import compare_jax_and_py
jax = pytest.importorskip("jax")
def set_test_value(x, v):
x.tag.test_value = v
return x
def test_extra_ops():
a = matrix("a")
a.tag.test_value = np.arange(6, dtype=config.floatX).reshape((3, 2))
out = at_extra_ops.cumsum(a, axis=0)
fgraph = FunctionGraph([a], [out])
compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])
out = at_extra_ops.cumprod(a, axis=1)
fgraph = FunctionGraph([a], [out])
compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])
out = at_extra_ops.diff(a, n=2, axis=1)
fgraph = FunctionGraph([a], [out])
compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])
out = at_extra_ops.repeat(a, (3, 3), axis=1)
fgraph = FunctionGraph([a], [out])
compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])
c = at.as_tensor(5)
out = at_extra_ops.fill_diagonal(a, c)
fgraph = FunctionGraph([a], [out])
compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])
with pytest.raises(NotImplementedError):
out = at_extra_ops.fill_diagonal_offset(a, c, c)
fgraph = FunctionGraph([a], [out])
compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])
with pytest.raises(NotImplementedError):
out = at_extra_ops.Unique(axis=1)(a)
fgraph = FunctionGraph([a], [out])
compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])
indices = np.arange(np.product((3, 4)))
out = at_extra_ops.unravel_index(indices, (3, 4), order="C")
fgraph = FunctionGraph([], out)
compare_jax_and_py(
fgraph, [get_test_value(i) for i in fgraph.inputs], must_be_device_array=False
)
@pytest.mark.parametrize(
"x, shape",
[
(
set_test_value(
vector("x"), np.random.random(size=(2,)).astype(config.floatX)
),
[at.as_tensor(3, dtype=np.int64), at.as_tensor(2, dtype=np.int64)],
),
(
set_test_value(
vector("x"), np.random.random(size=(2,)).astype(config.floatX)
),
[at.as_tensor(3, dtype=np.int8), at.as_tensor(2, dtype=np.int64)],
),
],
)
def test_BroadcastTo(x, shape):
out = at_extra_ops.broadcast_to(x, shape)
fgraph = FunctionGraph(outputs=[out])
compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])
@pytest.mark.xfail(
version_parse(jax.__version__) >= version_parse("0.2.12"),
reason="Omnistaging cannot be disabled",
)
def test_extra_ops_omni():
a = matrix("a")
a.tag.test_value = np.arange(6, dtype=config.floatX).reshape((3, 2))
# This function also cannot take symbolic input.
c = at.as_tensor(5)
out = at_extra_ops.bartlett(c)
fgraph = FunctionGraph([], [out])
compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])
multi_index = np.unravel_index(np.arange(np.product((3, 4))), (3, 4))
out = at_extra_ops.ravel_multi_index(multi_index, (3, 4))
fgraph = FunctionGraph([], [out])
compare_jax_and_py(
fgraph, [get_test_value(i) for i in fgraph.inputs], must_be_device_array=False
)
# The inputs are "concrete", yet it still has problems?
out = at_extra_ops.Unique()(
at.as_tensor(np.arange(6, dtype=config.floatX).reshape((3, 2)))
)
fgraph = FunctionGraph([], [out])
compare_jax_and_py(fgraph, [])
@pytest.mark.xfail(reason="jax.numpy.arange requires concrete inputs")
def test_unique_nonconcrete():
a = matrix("a")
a.tag.test_value = np.arange(6, dtype=config.floatX).reshape((3, 2))
out = at_extra_ops.Unique()(a)
fgraph = FunctionGraph([a], [out])
compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])
import numpy as np
import pytest
from packaging.version import parse as version_parse
from aesara.compile.function import function
from aesara.compile.mode import Mode
from aesara.configdefaults import config
from aesara.graph.fg import FunctionGraph
from aesara.graph.op import get_test_value
from aesara.graph.rewriting.db import RewriteDatabaseQuery
from aesara.link.jax import JAXLinker
from aesara.tensor import blas as at_blas
from aesara.tensor import nlinalg as at_nlinalg
from aesara.tensor.math import MaxAndArgmax
from aesara.tensor.math import max as at_max
from aesara.tensor.math import maximum
from aesara.tensor.type import dvector, matrix, scalar, tensor3, vector
from tests.link.jax.test_basic import compare_jax_and_py
jax = pytest.importorskip("jax")
def test_jax_BatchedDot():
# tensor3 . tensor3
a = tensor3("a")
a.tag.test_value = (
np.linspace(-1, 1, 10 * 5 * 3).astype(config.floatX).reshape((10, 5, 3))
)
b = tensor3("b")
b.tag.test_value = (
np.linspace(1, -1, 10 * 3 * 2).astype(config.floatX).reshape((10, 3, 2))
)
out = at_blas.BatchedDot()(a, b)
fgraph = FunctionGraph([a, b], [out])
compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])
# A dimension mismatch should raise a TypeError for compatibility
inputs = [get_test_value(a)[:-1], get_test_value(b)]
opts = RewriteDatabaseQuery(include=[None], exclude=["cxx_only", "BlasOpt"])
jax_mode = Mode(JAXLinker(), opts)
aesara_jax_fn = function(fgraph.inputs, fgraph.outputs, mode=jax_mode)
with pytest.raises(TypeError):
aesara_jax_fn(*inputs)
# matrix . matrix
a = matrix("a")
a.tag.test_value = np.linspace(-1, 1, 5 * 3).astype(config.floatX).reshape((5, 3))
b = matrix("b")
b.tag.test_value = np.linspace(1, -1, 5 * 3).astype(config.floatX).reshape((5, 3))
out = at_blas.BatchedDot()(a, b)
fgraph = FunctionGraph([a, b], [out])
compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])
def test_jax_basic_multiout():
rng = np.random.default_rng(213234)
M = rng.normal(size=(3, 3))
X = M.dot(M.T)
x = matrix("x")
outs = at_nlinalg.eig(x)
out_fg = FunctionGraph([x], outs)
def assert_fn(x, y):
np.testing.assert_allclose(x.astype(config.floatX), y, rtol=1e-3)
compare_jax_and_py(out_fg, [X.astype(config.floatX)], assert_fn=assert_fn)
outs = at_nlinalg.eigh(x)
out_fg = FunctionGraph([x], outs)
compare_jax_and_py(out_fg, [X.astype(config.floatX)], assert_fn=assert_fn)
outs = at_nlinalg.qr(x, mode="full")
out_fg = FunctionGraph([x], outs)
compare_jax_and_py(out_fg, [X.astype(config.floatX)], assert_fn=assert_fn)
outs = at_nlinalg.qr(x, mode="reduced")
out_fg = FunctionGraph([x], outs)
compare_jax_and_py(out_fg, [X.astype(config.floatX)], assert_fn=assert_fn)
outs = at_nlinalg.svd(x)
out_fg = FunctionGraph([x], outs)
compare_jax_and_py(out_fg, [X.astype(config.floatX)], assert_fn=assert_fn)
@pytest.mark.xfail(
version_parse(jax.__version__) >= version_parse("0.2.12"),
reason="Omnistaging cannot be disabled",
)
def test_jax_basic_multiout_omni():
# Test that a single output of a multi-output `Op` can be used as input to
# another `Op`
x = dvector()
mx, amx = MaxAndArgmax([0])(x)
out = mx * amx
out_fg = FunctionGraph([x], [out])
compare_jax_and_py(out_fg, [np.r_[1, 2]])
@pytest.mark.xfail(
version_parse(jax.__version__) >= version_parse("0.2.12"),
reason="Omnistaging cannot be disabled",
)
def test_tensor_basics():
y = vector("y")
y.tag.test_value = np.r_[1.0, 2.0].astype(config.floatX)
x = vector("x")
x.tag.test_value = np.r_[3.0, 4.0].astype(config.floatX)
A = matrix("A")
A.tag.test_value = np.empty((2, 2), dtype=config.floatX)
alpha = scalar("alpha")
alpha.tag.test_value = np.array(3.0, dtype=config.floatX)
beta = scalar("beta")
beta.tag.test_value = np.array(5.0, dtype=config.floatX)
# This should be converted into a `Gemv` `Op` when the non-JAX compatible
# optimizations are turned on; however, when using JAX mode, it should
# leave the expression alone.
out = y.dot(alpha * A).dot(x) + beta * y
fgraph = FunctionGraph([y, x, A, alpha, beta], [out])
compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])
out = maximum(y, x)
fgraph = FunctionGraph([y, x], [out])
compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])
out = at_max(y)
fgraph = FunctionGraph([y], [out])
compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])
import numpy as np
import pytest
from packaging.version import parse as version_parse
import aesara.tensor as at
from aesara.compile.function import function
from aesara.compile.sharedvalue import shared
from aesara.configdefaults import config
from aesara.graph.fg import FunctionGraph
from aesara.tensor.random.basic import RandomVariable
from aesara.tensor.random.utils import RandomStream
from tests.link.jax.test_basic import compare_jax_and_py, jax_mode
jax = pytest.importorskip("jax")
@pytest.mark.xfail(
version_parse(jax.__version__) >= version_parse("0.2.26"),
reason="JAX samplers require concrete/static shape values?",
)
@pytest.mark.parametrize(
"at_dist, dist_params, rng, size",
[
(
at.random.normal,
(),
shared(np.random.RandomState(123)),
10000,
),
(
at.random.normal,
(),
shared(np.random.default_rng(123)),
10000,
),
],
)
def test_random_stats(at_dist, dist_params, rng, size):
# The RNG states are not 1:1, so the best we can do is check some summary
# statistics of the samples
out = at.random.normal(*dist_params, rng=rng, size=size)
fgraph = FunctionGraph([out.owner.inputs[0]], [out], clone=False)
def assert_fn(x, y):
(x,) = x
(y,) = y
assert x.dtype.kind == y.dtype.kind
d = 2 if config.floatX == "float64" else 1
np.testing.assert_array_almost_equal(np.abs(x.mean()), np.abs(y.mean()), d)
compare_jax_and_py(fgraph, [], assert_fn=assert_fn)
def test_random_unimplemented():
class NonExistentRV(RandomVariable):
name = "non-existent"
ndim_supp = 0
ndims_params = []
dtype = "floatX"
def __call__(self, size=None, **kwargs):
return super().__call__(size=size, **kwargs)
def rng_fn(cls, rng, size):
return 0
nonexistentrv = NonExistentRV()
rng = shared(np.random.RandomState(123))
out = nonexistentrv(rng=rng)
fgraph = FunctionGraph([out.owner.inputs[0]], [out], clone=False)
with pytest.raises(NotImplementedError):
compare_jax_and_py(fgraph, [])
def test_RandomStream():
srng = RandomStream(seed=123)
out = srng.normal() - srng.normal()
fn = function([], out, mode=jax_mode)
jax_res_1 = fn()
jax_res_2 = fn()
assert np.array_equal(jax_res_1, jax_res_2)
import numpy as np
import pytest
import aesara.scalar.basic as aes
import aesara.tensor as at
from aesara.configdefaults import config
from aesara.graph.fg import FunctionGraph
from aesara.graph.op import get_test_value
from aesara.scalar.basic import Composite
from aesara.tensor import nnet as at_nnet
from aesara.tensor.elemwise import Elemwise
from aesara.tensor.math import all as at_all
from aesara.tensor.math import (
cosh,
erf,
erfc,
erfinv,
log,
log1mexp,
psi,
sigmoid,
softplus,
)
from aesara.tensor.type import matrix, scalar, vector
from tests.link.jax.test_basic import compare_jax_and_py
jax = pytest.importorskip("jax")
def test_second():
a0 = scalar("a0")
b = scalar("b")
out = aes.second(a0, b)
fgraph = FunctionGraph([a0, b], [out])
compare_jax_and_py(fgraph, [10.0, 5.0])
a1 = vector("a1")
out = at.second(a1, b)
fgraph = FunctionGraph([a1, b], [out])
compare_jax_and_py(fgraph, [np.zeros([5], dtype=config.floatX), 5.0])
def test_identity():
a = scalar("a")
a.tag.test_value = 10
out = aes.identity(a)
fgraph = FunctionGraph([a], [out])
compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])
@pytest.mark.parametrize(
"x, y, x_val, y_val",
[
(scalar("x"), scalar("y"), np.array(10), np.array(20)),
(scalar("x"), vector("y"), np.array(10), np.arange(10, 20)),
(
matrix("x"),
vector("y"),
np.arange(10 * 20).reshape((20, 10)),
np.arange(10, 20),
),
],
)
def test_jax_Composite(x, y, x_val, y_val):
x_s = aes.float64("x")
y_s = aes.float64("y")
comp_op = Elemwise(Composite([x_s, y_s], [x_s + y_s * 2 + aes.exp(x_s - y_s)]))
out = comp_op(x, y)
out_fg = FunctionGraph([x, y], [out])
test_input_vals = [
x_val.astype(config.floatX),
y_val.astype(config.floatX),
]
_ = compare_jax_and_py(out_fg, test_input_vals)
def test_erf():
x = scalar("x")
out = erf(x)
fg = FunctionGraph([x], [out])
compare_jax_and_py(fg, [1.0])
def test_erfc():
x = scalar("x")
out = erfc(x)
fg = FunctionGraph([x], [out])
compare_jax_and_py(fg, [1.0])
def test_erfinv():
x = scalar("x")
out = erfinv(x)
fg = FunctionGraph([x], [out])
compare_jax_and_py(fg, [1.0])
def test_psi():
x = scalar("x")
out = psi(x)
fg = FunctionGraph([x], [out])
compare_jax_and_py(fg, [3.0])
def test_log1mexp():
x = vector("x")
out = log1mexp(x)
fg = FunctionGraph([x], [out])
compare_jax_and_py(fg, [[-1.0, -0.75, -0.5, -0.25]])
def test_nnet():
x = vector("x")
x.tag.test_value = np.r_[1.0, 2.0].astype(config.floatX)
out = sigmoid(x)
fgraph = FunctionGraph([x], [out])
compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])
out = at_nnet.ultra_fast_sigmoid(x)
fgraph = FunctionGraph([x], [out])
compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])
out = softplus(x)
fgraph = FunctionGraph([x], [out])
compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])
def test_jax_variadic_Scalar():
mu = vector("mu", dtype=config.floatX)
mu.tag.test_value = np.r_[0.1, 1.1].astype(config.floatX)
tau = vector("tau", dtype=config.floatX)
tau.tag.test_value = np.r_[1.0, 2.0].astype(config.floatX)
res = -tau * mu
fgraph = FunctionGraph([mu, tau], [res])
compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])
res = -tau * (tau - mu) ** 2
fgraph = FunctionGraph([mu, tau], [res])
compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])
def test_jax_multioutput():
x = vector("x")
x.tag.test_value = np.r_[1.0, 2.0].astype(config.floatX)
y = vector("y")
y.tag.test_value = np.r_[3.0, 4.0].astype(config.floatX)
w = cosh(x**2 + y / 3.0)
v = cosh(x / 3.0 + y**2)
fgraph = FunctionGraph([x, y], [w, v])
compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])
def test_jax_logp():
mu = vector("mu")
mu.tag.test_value = np.r_[0.0, 0.0].astype(config.floatX)
tau = vector("tau")
tau.tag.test_value = np.r_[1.0, 1.0].astype(config.floatX)
sigma = vector("sigma")
sigma.tag.test_value = (1.0 / get_test_value(tau)).astype(config.floatX)
value = vector("value")
value.tag.test_value = np.r_[0.1, -10].astype(config.floatX)
logp = (-tau * (value - mu) ** 2 + log(tau / np.pi / 2.0)) / 2.0
conditions = [sigma > 0]
alltrue = at_all([at_all(1 * val) for val in conditions])
normal_logp = at.switch(alltrue, logp, -np.inf)
fgraph = FunctionGraph([mu, tau, sigma, value], [normal_logp])
compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])
import numpy as np
import pytest
from packaging.version import parse as version_parse
import aesara.tensor as at
from aesara.configdefaults import config
from aesara.graph.fg import FunctionGraph
from aesara.scan.basic import scan
from aesara.tensor.math import gammaln, log
from aesara.tensor.type import ivector, lscalar, scalar
from tests.link.jax.test_basic import compare_jax_and_py
jax = pytest.importorskip("jax")
@pytest.mark.xfail(
version_parse(jax.__version__) >= version_parse("0.2.12"),
reason="Omnistaging cannot be disabled",
)
def test_jax_scan_multiple_output():
"""Test a scan implementation of a SEIR model.
SEIR model definition:
S[t+1] = S[t] - B[t]
E[t+1] = E[t] +B[t] - C[t]
I[t+1] = I[t+1] + C[t] - D[t]
B[t] ~ Binom(S[t], beta)
C[t] ~ Binom(E[t], gamma)
D[t] ~ Binom(I[t], delta)
"""
def binomln(n, k):
return gammaln(n + 1) - gammaln(k + 1) - gammaln(n - k + 1)
def binom_log_prob(n, p, value):
return binomln(n, value) + value * log(p) + (n - value) * log(1 - p)
# sequences
at_C = ivector("C_t")
at_D = ivector("D_t")
# outputs_info (initial conditions)
st0 = lscalar("s_t0")
et0 = lscalar("e_t0")
it0 = lscalar("i_t0")
logp_c = scalar("logp_c")
logp_d = scalar("logp_d")
# non_sequences
beta = scalar("beta")
gamma = scalar("gamma")
delta = scalar("delta")
# TODO: Use random streams when their JAX conversions are implemented.
# trng = aesara.tensor.random.RandomStream(1234)
def seir_one_step(ct0, dt0, st0, et0, it0, logp_c, logp_d, beta, gamma, delta):
# bt0 = trng.binomial(n=st0, p=beta)
bt0 = st0 * beta
bt0 = bt0.astype(st0.dtype)
logp_c1 = binom_log_prob(et0, gamma, ct0).astype(logp_c.dtype)
logp_d1 = binom_log_prob(it0, delta, dt0).astype(logp_d.dtype)
st1 = st0 - bt0
et1 = et0 + bt0 - ct0
it1 = it0 + ct0 - dt0
return st1, et1, it1, logp_c1, logp_d1
(st, et, it, logp_c_all, logp_d_all), _ = scan(
fn=seir_one_step,
sequences=[at_C, at_D],
outputs_info=[st0, et0, it0, logp_c, logp_d],
non_sequences=[beta, gamma, delta],
)
st.name = "S_t"
et.name = "E_t"
it.name = "I_t"
logp_c_all.name = "C_t_logp"
logp_d_all.name = "D_t_logp"
out_fg = FunctionGraph(
[at_C, at_D, st0, et0, it0, logp_c, logp_d, beta, gamma, delta],
[st, et, it, logp_c_all, logp_d_all],
)
s0, e0, i0 = 100, 50, 25
logp_c0 = np.array(0.0, dtype=config.floatX)
logp_d0 = np.array(0.0, dtype=config.floatX)
beta_val, gamma_val, delta_val = [
np.array(val, dtype=config.floatX) for val in [0.277792, 0.135330, 0.108753]
]
C = np.array([3, 5, 8, 13, 21, 26, 10, 3], dtype=np.int32)
D = np.array([1, 2, 3, 7, 9, 11, 5, 1], dtype=np.int32)
test_input_vals = [
C,
D,
s0,
e0,
i0,
logp_c0,
logp_d0,
beta_val,
gamma_val,
delta_val,
]
compare_jax_and_py(out_fg, test_input_vals)
@pytest.mark.xfail(
version_parse(jax.__version__) >= version_parse("0.2.12"),
reason="Omnistaging cannot be disabled",
)
def test_jax_scan_tap_output():
a_at = scalar("a")
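# `taps=[-1, -3]` means each step receives y[t-1] and y[t-3], so the initial
# output below must provide (at least) three values.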
def input_step_fn(y_tm1, y_tm3, a):
y_tm1.name = "y_tm1"
y_tm3.name = "y_tm3"
res = (y_tm1 + y_tm3) * a
res.name = "y_t"
return res
y_scan_at, _ = scan(
fn=input_step_fn,
outputs_info=[
{
"initial": at.as_tensor_variable(
np.r_[-1.0, 1.3, 0.0].astype(config.floatX)
),
"taps": [-1, -3],
},
],
non_sequences=[a_at],
n_steps=10,
name="y_scan",
)
y_scan_at.name = "y"
y_scan_at.owner.inputs[0].name = "y_all"
out_fg = FunctionGraph([a_at], [y_scan_at])
test_input_vals = [np.array(10.0).astype(config.floatX)]
compare_jax_and_py(out_fg, test_input_vals)
import jax
import numpy as np
import pytest
from packaging.version import parse as version_parse
import aesara.tensor as at
from aesara.compile.ops import DeepCopyOp, ViewOp
from aesara.configdefaults import config
from aesara.graph.fg import FunctionGraph
from aesara.tensor.shape import Shape, Shape_i, SpecifyShape, Unbroadcast, reshape
from aesara.tensor.type import iscalar, vector
from tests.link.jax.test_basic import compare_jax_and_py
def test_jax_shape_ops():
x_np = np.zeros((20, 3))
x = Shape()(at.as_tensor_variable(x_np))
x_fg = FunctionGraph([], [x])
compare_jax_and_py(x_fg, [], must_be_device_array=False)
x = Shape_i(1)(at.as_tensor_variable(x_np))
x_fg = FunctionGraph([], [x])
compare_jax_and_py(x_fg, [], must_be_device_array=False)
@pytest.mark.xfail(
version_parse(jax.__version__) >= version_parse("0.2.12"),
reason="Omnistaging cannot be disabled",
)
def test_jax_specify_shape():
x_np = np.zeros((20, 3))
x = SpecifyShape()(at.as_tensor_variable(x_np), (20, 3))
x_fg = FunctionGraph([], [x])
compare_jax_and_py(x_fg, [])
with config.change_flags(compute_test_value="off"):
x = SpecifyShape()(at.as_tensor_variable(x_np), *(2, 3))
x_fg = FunctionGraph([], [x])
with pytest.raises(AssertionError):
compare_jax_and_py(x_fg, [])
def test_jax_Reshape():
a = vector("a")
x = reshape(a, (2, 2))
x_fg = FunctionGraph([a], [x])
compare_jax_and_py(x_fg, [np.r_[1.0, 2.0, 3.0, 4.0].astype(config.floatX)])
# Test breaking "omnistaging" changes in JAX.
# See https://github.com/tensorflow/probability/commit/782d0c64eb774b9aac54a1c8488e4f1f96fbbc68
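# Under JAX tracing, `a.shape[0] // 2` is a traced value rather than a concrete
# integer, so `jnp.reshape` rejects it as a static shape and raises a `TypeError`.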
x = reshape(a, (a.shape[0] // 2, a.shape[0] // 2))
x_fg = FunctionGraph([a], [x])
with pytest.raises(
TypeError,
match="Shapes must be 1D sequences of concrete values of integer type",
):
compare_jax_and_py(x_fg, [np.r_[1.0, 2.0, 3.0, 4.0].astype(config.floatX)])
b = iscalar("b")
x = reshape(a, (b, b))
x_fg = FunctionGraph([a, b], [x])
with pytest.raises(
TypeError,
match="Shapes must be 1D sequences of concrete values of integer type",
):
compare_jax_and_py(x_fg, [np.r_[1.0, 2.0, 3.0, 4.0].astype(config.floatX), 2])
def test_jax_compile_ops():
x = DeepCopyOp()(at.as_tensor_variable(1.1))
x_fg = FunctionGraph([], [x])
compare_jax_and_py(x_fg, [])
x_np = np.zeros((20, 1, 1))
x = Unbroadcast(0, 2)(at.as_tensor_variable(x_np))
x_fg = FunctionGraph([], [x])
compare_jax_and_py(x_fg, [])
x = ViewOp()(at.as_tensor_variable(x_np))
x_fg = FunctionGraph([], [x])
compare_jax_and_py(x_fg, [])
import numpy as np
import pytest
import aesara.tensor as at
from aesara.configdefaults import config
from aesara.graph.fg import FunctionGraph
from aesara.tensor import nlinalg as at_nlinalg
from aesara.tensor import slinalg as at_slinalg
from aesara.tensor import subtensor as at_subtensor
from aesara.tensor.math import clip, cosh
from aesara.tensor.type import matrix, vector
from tests.link.jax.test_basic import compare_jax_and_py
def test_jax_basic():
rng = np.random.default_rng(28494)
x = matrix("x")
y = matrix("y")
b = vector("b")
# `ScalarOp`
z = cosh(x**2 + y / 3.0)
# `[Inc]Subtensor`
out = at_subtensor.set_subtensor(z[0], -10.0)
out = at_subtensor.inc_subtensor(out[0, 1], 2.0)
out = out[:5, :3]
out_fg = FunctionGraph([x, y], [out])
test_input_vals = [
np.tile(np.arange(10), (10, 1)).astype(config.floatX),
np.tile(np.arange(10, 20), (10, 1)).astype(config.floatX),
]
(jax_res,) = compare_jax_and_py(out_fg, test_input_vals)
# Confirm that the `Subtensor` slice operations are correct
assert jax_res.shape == (5, 3)
# Confirm that the `IncSubtensor` operations are correct
assert jax_res[0, 0] == -10.0
assert jax_res[0, 1] == -8.0
out = clip(x, y, 5)
out_fg = FunctionGraph([x, y], [out])
compare_jax_and_py(out_fg, test_input_vals)
out = at.diagonal(x, 0)
out_fg = FunctionGraph([x], [out])
compare_jax_and_py(
out_fg, [np.arange(10 * 10).reshape((10, 10)).astype(config.floatX)]
)
out = at_slinalg.cholesky(x)
out_fg = FunctionGraph([x], [out])
compare_jax_and_py(
out_fg,
[
(np.eye(10) + rng.standard_normal(size=(10, 10)) * 0.01).astype(
config.floatX
)
],
)
# not sure why this isn't working yet with lower=False
out = at_slinalg.Cholesky(lower=False)(x)
out_fg = FunctionGraph([x], [out])
compare_jax_and_py(
out_fg,
[
(np.eye(10) + rng.standard_normal(size=(10, 10)) * 0.01).astype(
config.floatX
)
],
)
out = at_slinalg.solve(x, b)
out_fg = FunctionGraph([x, b], [out])
compare_jax_and_py(
out_fg,
[
np.eye(10).astype(config.floatX),
np.arange(10).astype(config.floatX),
],
)
out = at.diag(b)
out_fg = FunctionGraph([b], [out])
compare_jax_and_py(out_fg, [np.arange(10).astype(config.floatX)])
out = at_nlinalg.det(x)
out_fg = FunctionGraph([x], [out])
compare_jax_and_py(
out_fg, [np.arange(10 * 10).reshape((10, 10)).astype(config.floatX)]
)
out = at_nlinalg.matrix_inverse(x)
out_fg = FunctionGraph([x], [out])
compare_jax_and_py(
out_fg,
[
(np.eye(10) + rng.standard_normal(size=(10, 10)) * 0.01).astype(
config.floatX
)
],
)
@pytest.mark.parametrize("check_finite", [False, True])
@pytest.mark.parametrize("lower", [False, True])
@pytest.mark.parametrize("trans", [0, 1, 2])
def test_jax_SolveTriangular(trans, lower, check_finite):
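# `trans` follows the SciPy convention: 0 solves a @ x = b, 1 solves a.T @ x = b,
# and 2 solves a.conj().T @ x = b.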
x = matrix("x")
b = vector("b")
out = at_slinalg.solve_triangular(
x,
b,
trans=trans,
lower=lower,
check_finite=check_finite,
)
out_fg = FunctionGraph([x, b], [out])
compare_jax_and_py(
out_fg,
[
np.eye(10).astype(config.floatX),
np.arange(10).astype(config.floatX),
],
)
import numpy as np
import pytest
import aesara.tensor.basic as at
from aesara.configdefaults import config
from aesara.graph.fg import FunctionGraph
from aesara.graph.op import get_test_value
from aesara.tensor.type import matrix, scalar, vector
from tests.link.jax.test_basic import compare_jax_and_py
def test_jax_Alloc():
x = at.alloc(0.0, 2, 3)
x_fg = FunctionGraph([], [x])
(jax_res,) = compare_jax_and_py(x_fg, [])
assert jax_res.shape == (2, 3)
x = at.alloc(1.1, 2, 3)
x_fg = FunctionGraph([], [x])
compare_jax_and_py(x_fg, [])
x = at.AllocEmpty("float32")(2, 3)
x_fg = FunctionGraph([], [x])
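# `AllocEmpty` returns uninitialized memory, so only the shape and dtype of the
# result are meaningful to compare.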
def compare_shape_dtype(x, y):
(x,) = x
(y,) = y
return x.shape == y.shape and x.dtype == y.dtype
compare_jax_and_py(x_fg, [], assert_fn=compare_shape_dtype)
a = scalar("a")
x = at.alloc(a, 20)
x_fg = FunctionGraph([a], [x])
compare_jax_and_py(x_fg, [10.0])
a = vector("a")
x = at.alloc(a, 20, 10)
x_fg = FunctionGraph([a], [x])
compare_jax_and_py(x_fg, [np.ones(10, dtype=config.floatX)])
def test_jax_MakeVector():
x = at.make_vector(1, 2, 3)
x_fg = FunctionGraph([], [x])
compare_jax_and_py(x_fg, [])
@pytest.mark.xfail(reason="jax.numpy.arange requires concrete inputs")
def test_arange_nonconcrete():
a = scalar("a")
a.tag.test_value = 10
out = at.arange(a)
fgraph = FunctionGraph([a], [out])
compare_jax_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])
def test_jax_Join():
a = matrix("a")
b = matrix("b")
x = at.join(0, a, b)
x_fg = FunctionGraph([a, b], [x])
compare_jax_and_py(
x_fg,
[
np.c_[[1.0, 2.0, 3.0]].astype(config.floatX),
np.c_[[4.0, 5.0, 6.0]].astype(config.floatX),
],
)
compare_jax_and_py(
x_fg,
[
np.c_[[1.0, 2.0, 3.0]].astype(config.floatX),
np.c_[[4.0, 5.0]].astype(config.floatX),
],
)
x = at.join(1, a, b)
x_fg = FunctionGraph([a, b], [x])
compare_jax_and_py(
x_fg,
[
np.c_[[1.0, 2.0, 3.0]].astype(config.floatX),
np.c_[[4.0, 5.0, 6.0]].astype(config.floatX),
],
)
compare_jax_and_py(
x_fg,
[
np.c_[[1.0, 2.0], [3.0, 4.0]].astype(config.floatX),
np.c_[[5.0, 6.0]].astype(config.floatX),
],
)
def test_jax_eye():
"""Tests jaxification of the Eye operator"""
out = at.eye(3)
out_fg = FunctionGraph([], [out])
compare_jax_and_py(out_fg, [])