Commit fc7922c0 authored and committed by Brandon T. Willard

Add Numba conversions for math Ops

Parent e0ab0b46
import operator
import warnings
from functools import reduce, singledispatch
from numbers import Number
from textwrap import indent
from typing import Union
import numba
import numpy as np
@@ -32,6 +34,7 @@ from aesara.scalar.basic import (
ScalarOp,
Second,
)
from aesara.scalar.basic_scipy import Softplus
from aesara.tensor.basic import (
Alloc,
AllocDiag,
@@ -45,6 +48,7 @@ from aesara.tensor.basic import (
ScalarFromTensor,
TensorFromScalar,
)
from aesara.tensor.blas import BatchedDot
from aesara.tensor.elemwise import CAReduce, DimShuffle, Elemwise
from aesara.tensor.extra_ops import (
Bartlett,
@@ -58,7 +62,11 @@ from aesara.tensor.extra_ops import (
Unique,
UnravelIndex,
)
from aesara.tensor.math import Dot, MaxAndArgmax
from aesara.tensor.nlinalg import SVD, Det, Eig, Eigh, MatrixInverse, QRFull
from aesara.tensor.nnet.basic import LogSoftmax, Softmax
from aesara.tensor.shape import Reshape, Shape, Shape_i, SpecifyShape
from aesara.tensor.slinalg import Cholesky, Solve
from aesara.tensor.subtensor import (
AdvancedIncSubtensor,
AdvancedIncSubtensor1,
@@ -382,74 +390,99 @@ def {elemwise_fn_name}({input_names}):
return elemwise_fn
def create_axis_reducer(
    reduce_fn: numba.np.ufunc.dufunc.DUFunc,
    identity: Union[np.ndarray, Number],
    axis: int,
    ndim: int,
    dtype: numba.types.Type,
    keepdims: bool = False,
) -> numba.core.dispatcher.Dispatcher:
    """Create a Numba JITed function that performs a NumPy reduction on a given axis.

    Parameters
    ==========
    reduce_fn:
        The Numba ``ufunc`` representing a binary op that can perform the
        reduction on arbitrary ``ndarray``s.
    identity:
        The identity value for the reduction.
    axis:
        The axis to reduce.
    ndim:
        The number of dimensions of the result.
    dtype:
        The data type of the result.
    keepdims:
        Determines whether or not the reduced dimension is retained.

    """
    if ndim > 1:
        if keepdims:

            @numba.njit(inline="always")
            def set_out_dims(x):
                return np.expand_dims(x, axis)

        else:

            @numba.njit(inline="always")
            def set_out_dims(x):
                return x

        res_shape_tuple_ctor = create_tuple_creator(
            lambda i, shape: shape[i] if i < axis else shape[i + 1], ndim - 1
        )

        reaxis_first = (axis,) + tuple(i for i in range(ndim) if i != axis)

        @numba.njit(boundscheck=False)
        def careduce_axis(x):
            res_shape = res_shape_tuple_ctor(x.shape)
            x_axis_first = x.transpose(reaxis_first)

            res = np.full(res_shape, to_scalar(identity), dtype=dtype)
            for m in range(x.shape[axis]):
                reduce_fn(res, x_axis_first[m], res)

            return set_out_dims(res)

    else:
        if keepdims:

            @numba.njit(inline="always")
            def set_out_dims(x):
                return np.array([x], dtype)

        else:

            @numba.njit(inline="always")
            def set_out_dims(x):
                return direct_cast(x, dtype)

        @numba.njit(boundscheck=False)
        def careduce_axis(x):
            res = to_scalar(identity)
            for val in x:
                res = reduce_fn(res, val)
            return set_out_dims(res)

    return careduce_axis
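For review convenience, here is a minimal standalone sketch (not part of the commit) of the kind of function `create_axis_reducer` produces, written out by hand for a sum over axis 0 of a 2D array; the name `sum_axis0_2d` is illustrative:

import numba
import numpy as np

@numba.njit(boundscheck=False)
def sum_axis0_2d(x):
    # Accumulate each row into a result initialized with the reduction
    # identity (0.0 for addition), mirroring the `np.full` + loop above.
    res = np.full((x.shape[1],), 0.0)
    for m in range(x.shape[0]):
        res += x[m]
    return res

x = np.arange(6.0).reshape(3, 2)
assert np.array_equal(sum_axis0_2d(x), x.sum(axis=0))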
def create_multiaxis_reducer(
    reduce_fn, identity, axes, ndim, dtype, input_name="input"
):
    careduce_fn_name = f"careduce_{get_name_for_object(reduce_fn)}"

    careduce_axes_fns = ()
    # Reduce the lowest axis first, so that each remaining axis shifts down
    # by one per step (hence the `axis - i` below).
    to_reduce = sorted(axes)
    careduce_lines_src = []
    var_name = input_name

    for i, axis in enumerate(to_reduce):
        careduce_axes_fns += (
            create_axis_reducer(reduce_fn, identity, axis - i, ndim, dtype),
        )
        ndim -= 1
        last_var_name = var_name
        var_name = f"axis_{i}_res"
@@ -467,6 +500,43 @@ def {careduce_fn_name}({input_name}):
global_env = {"careduce_axes_fns": careduce_axes_fns}
careduce_fn = compile_function_src(careduce_def_src, careduce_fn_name, global_env)
return careduce_fn
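The hidden middle of this function joins `careduce_lines_src` into the body of the generated `def {careduce_fn_name}({input_name}):` shown in the hunk header, then compiles it with `compile_function_src`. A self-contained sketch of that source-generation pattern, with a plain `exec` standing in for `compile_function_src` and hypothetical names:

import numpy as np

def make_reduction_chain(fns, input_name="input"):
    # Generate source that threads the input through each reducer in turn.
    lines = []
    var_name = input_name
    for i in range(len(fns)):
        last_var_name, var_name = var_name, f"axis_{i}_res"
        lines.append(f"    {var_name} = fns[{i}]({last_var_name})")
    src = f"def careduce({input_name}):\n" + "\n".join(lines) + f"\n    return {var_name}"
    global_env = {"fns": fns}
    exec(src, global_env)
    return global_env["careduce"]

# Two axis-0 sums emulate reducing axes (0, 1) in ascending order.
careduce = make_reduction_chain([lambda a: a.sum(axis=0), lambda a: a.sum(axis=0)])
x = np.arange(24.0).reshape(2, 3, 4)
assert np.array_equal(careduce(x), x.sum(axis=(0, 1)))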
@numba_funcify.register(CAReduce)
def numba_funcify_CAReduce(op, node, **kwargs):
axes = op.axis
if axes is None:
axes = list(range(node.inputs[0].ndim))
if hasattr(op, "acc_dtype") and op.acc_dtype is not None:
acc_dtype = op.acc_dtype
else:
acc_dtype = node.outputs[0].type.dtype
np_acc_dtype = np.dtype(acc_dtype)
scalar_op_identity = np.asarray(op.scalar_op.identity, dtype=np_acc_dtype)
acc_dtype = numba.np.numpy_support.from_dtype(np_acc_dtype)
scalar_nfunc_spec = op.scalar_op.nfunc_spec
# We construct a dummy `Apply` that has the minimum required number of
# inputs for the scalar `Op`. Without this, we would get a scalar function
# with too few arguments.
dummy_node = Apply(
op,
[tensor(acc_dtype, [False]) for i in range(scalar_nfunc_spec[1])],
[tensor(acc_dtype, [False]) for o in range(scalar_nfunc_spec[2])],
)
elemwise_fn = numba_funcify_Elemwise(op, dummy_node, use_signature=True, **kwargs)
input_name = get_name_for_object(node.inputs[0])
ndim = node.inputs[0].ndim
careduce_fn = create_multiaxis_reducer(
elemwise_fn, scalar_op_identity, axes, ndim, acc_dtype, input_name=input_name
)
return numba.njit(careduce_fn)
@@ -850,7 +920,14 @@ def numba_funcify_Rebroadcast(op, **kwargs):
@numba.extending.intrinsic
def direct_cast(typingctx, val, typ):
    if isinstance(typ, numba.types.TypeRef):
        casted = typ.instance_type
    elif isinstance(typ, numba.types.DTypeSpec):
        casted = typ.dtype
    else:
        casted = typ
sig = casted(casted, typ)
def codegen(context, builder, signature, args):
@@ -862,7 +939,7 @@ def direct_cast(typingctx, val, typ):
@numba_funcify.register(Cast)
def numba_funcify_Cast(op, node, **kwargs):
dtype = np.dtype(op.o_type.dtype)
dtype = numba.np.numpy_support.from_dtype(dtype)
@@ -1315,3 +1392,415 @@ def numba_funcify_Searchsorted(op, node, **kwargs):
return np.searchsorted(a, v, side)
return searchsorted
def int_to_float_fn(inputs, out_dtype):
"""Create a Numba function that converts integer and boolean ``ndarray``s to floats."""
if any(i.type.numpy_dtype.kind in "ib" for i in inputs):
args_dtype = np.dtype(f"f{out_dtype.itemsize}")
@numba.njit(inline="always")
def inputs_cast(x):
return x.astype(args_dtype)
else:
@numba.njit(inline="always")
def inputs_cast(x):
return x
return inputs_cast
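A quick illustrative check of the helper's dtype choice (not part of the commit): integer and boolean inputs are upcast to a float dtype with the same itemsize as the output so the LAPACK-backed routines can run under Numba:

import numpy as np

# An int64 output dtype maps to float64 ("f8"), int32 to float32 ("f4"), etc.
out_dtype = np.dtype("int64")
args_dtype = np.dtype(f"f{out_dtype.itemsize}")
assert args_dtype == np.dtype("float64")

x = np.array([1, 2, 3], dtype="int64")
assert x.astype(args_dtype).dtype == args_dtype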
@numba_funcify.register(Dot)
def numba_funcify_Dot(op, node, **kwargs):
# Numba's `np.dot` does not support integer dtypes, so we need to cast to
# float.
out_dtype = node.outputs[0].type.numpy_dtype
inputs_cast = int_to_float_fn(node.inputs, out_dtype)
@numba.njit
def dot(x, y):
return np.dot(inputs_cast(x), inputs_cast(y)).astype(out_dtype)
return dot
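In practice this means an integer `Dot` is computed in float and cast back; a small NumPy sketch of the same round trip (illustrative):

import numpy as np

x = np.array([[1, 2], [3, 4]], dtype="int64")
y = np.array([5, 6], dtype="int64")

# Cast to float for the dot, then cast the result back to the output dtype.
res = np.dot(x.astype("float64"), y.astype("float64")).astype("int64")
assert np.array_equal(res, np.dot(x, y))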
@numba_funcify.register(Softmax)
def numba_funcify_Softmax(op, node, **kwargs):
x_at = node.inputs[0]
x_dtype = x_at.type.numpy_dtype
x_dtype = numba.np.numpy_support.from_dtype(x_dtype)
# np.max(x, axis=1)
reduce_max = create_axis_reducer(np.maximum, -np.inf, 1, x_at.ndim, x_dtype)
# np.sum(x, axis=1)
reduce_sum = create_axis_reducer(np.add, 0.0, 1, x_at.ndim, x_dtype)
@numba.njit
def softmax(x):
z = np.expand_dims(reduce_max(x), -1)
e_x = np.exp(x - z)
w = np.expand_dims(reduce_sum(e_x), -1)
sm = e_x / w
return sm
return softmax
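The max-subtraction is the standard numerically stable softmax; a plain NumPy sketch of the same computation for a 2D input (for comparison only):

import numpy as np

def softmax_2d(x):
    # Subtracting the row-wise max keeps np.exp from overflowing.
    z = x - x.max(axis=1, keepdims=True)
    e_x = np.exp(z)
    return e_x / e_x.sum(axis=1, keepdims=True)

x = np.array([[1000.0, 1001.0]])
assert np.all(np.isfinite(softmax_2d(x)))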
@numba_funcify.register(LogSoftmax)
def numba_funcify_LogSoftmax(op, node, **kwargs):
x_at = node.inputs[0]
x_dtype = x_at.type.numpy_dtype
x_dtype = numba.np.numpy_support.from_dtype(x_dtype)
# np.max(x, axis=1)
reduce_max = create_axis_reducer(np.maximum, -np.inf, 1, x_at.ndim, x_dtype)
# np.sum(x, axis=1, keepdims=True)
reduce_sum = create_axis_reducer(np.add, 0.0, 1, x_at.ndim, x_dtype, keepdims=True)
@numba.njit
def log_softmax(x):
xdev = x - np.expand_dims(reduce_max(x), -1)
lsm = xdev - np.log(reduce_sum(np.exp(xdev)))
return lsm
return log_softmax
@numba_funcify.register(Softplus)
def numba_funcify_Softplus(op, node, **kwargs):
x_dtype = np.dtype(node.inputs[0].dtype)
@numba.njit
def softplus(x):
if x < -37.0:
return direct_cast(np.exp(x), x_dtype)
elif x < 18.0:
return direct_cast(np.log1p(np.exp(x)), x_dtype)
elif x < 33.3:
return direct_cast(x + np.exp(-x), x_dtype)
else:
return direct_cast(x, x_dtype)
return softplus
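The branch thresholds implement the usual piecewise-stable evaluation of log1p(exp(x)). A scalar NumPy reference with the same thresholds (illustrative, not part of the commit):

import numpy as np

def softplus_ref(x):
    # Same thresholds as the Numba kernel above.
    if x < -37.0:
        return np.exp(x)            # log1p(exp(x)) ~= exp(x) for very negative x
    elif x < 18.0:
        return np.log1p(np.exp(x))  # the exact formula is safe here
    elif x < 33.3:
        return x + np.exp(-x)       # log1p(exp(x)) = x + log1p(exp(-x)) ~= x + exp(-x)
    else:
        return x                    # exp(-x) underflows relative to x

for v in (-40.0, -32.0, 0.0, 32.0, 40.0):
    assert np.isfinite(softplus_ref(v))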
def create_axis_apply_fn(fn, axis, ndim, dtype):
reaxis_first = tuple(i for i in range(ndim) if i != axis) + (axis,)
@numba.njit(boundscheck=False)
def axis_apply_fn(x):
x_reaxis = x.transpose(reaxis_first)
res = np.zeros(x_reaxis.shape[:-1], dtype=dtype)
for m in np.ndindex(res.shape):
v = fn(x_reaxis[m])
res[m] = v
return res
return axis_apply_fn
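`create_axis_apply_fn` moves the target axis last and applies `fn` to each 1D slice via `np.ndindex`. A standalone sketch of the same pattern, hand-written for `np.argmax` over axis 0 of a 2D array (the name is illustrative):

import numba
import numpy as np

@numba.njit(boundscheck=False)
def argmax_axis0_2d(x):
    # Move axis 0 last, then apply np.argmax to every 1-D slice.
    x_reaxis = x.transpose((1, 0))
    res = np.zeros(x_reaxis.shape[:-1], dtype=np.int64)
    for m in np.ndindex(res.shape):
        res[m] = np.argmax(x_reaxis[m])
    return res

x = np.array([[1.0, 5.0], [3.0, 2.0]])
assert np.array_equal(argmax_axis0_2d(x), np.argmax(x, axis=0))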
@numba_funcify.register(MaxAndArgmax)
def numba_funcify_MaxAndArgmax(op, node, **kwargs):
axis = op.axis
x_at = node.inputs[0]
x_dtype = x_at.type.numpy_dtype
x_dtype = numba.np.numpy_support.from_dtype(x_dtype)
x_ndim = x_at.ndim
if x_ndim == 0:
@numba.njit(inline="always")
def maxandargmax(x):
return x, 0
else:
axes = tuple(int(ax) for ax in axis)
# NumPy does not support multiple axes for argmax; this is a
# work-around
keep_axes = tuple(i for i in range(x_ndim) if i not in axes)
reduce_max = numba.njit(
create_multiaxis_reducer(np.maximum, -np.inf, axes, x_ndim, x_dtype)
)
reduced_x_ndim = x_ndim - len(axes) + 1
argmax_axis = create_axis_apply_fn(
np.argmax, reduced_x_ndim - 1, reduced_x_ndim, np.int64
)
reaxis_order = keep_axes + axes
sl1 = slice(None, len(keep_axes))
sl2 = slice(len(keep_axes), None)
@numba.njit
def maxandargmax(x):
max_res = reduce_max(x)
# Not-reduced axes in front
transposed_x = np.ascontiguousarray(np.transpose(x, reaxis_order))
kept_shape = transposed_x.shape[sl1]
reduced_shape = transposed_x.shape[sl2]
reduced_size = 1
for s in reduced_shape:
reduced_size *= s
            # The product is computed with a loop above because `np.prod`
            # returns the float `1.0` for an empty sequence, and `reshape`
            # would reject a float in the shape
new_shape = kept_shape + (reduced_size,)
reshaped_x = transposed_x.reshape(new_shape)
max_idx_res = argmax_axis(reshaped_x)
return max_res, max_idx_res
return maxandargmax
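The multi-axis argmax workaround flattens the reduced axes into one trailing dimension and takes a single argmax over it; the same trick in plain NumPy (illustrative):

import numpy as np

x = np.random.random(size=(3, 2, 4))
axes, keep_axes = (1, 2), (0,)

# Kept axes to the front, reduced axes collapsed into a single trailing one.
transposed = np.transpose(x, keep_axes + axes)
reshaped = transposed.reshape(transposed.shape[: len(keep_axes)] + (-1,))
flat_argmax = np.argmax(reshaped, axis=-1)

assert np.array_equal(reshaped.max(axis=-1), x.max(axis=axes))
assert flat_argmax.shape == (3,)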
@numba_funcify.register(Cholesky)
def numba_funcify_Cholesky(op, node, **kwargs):
lower = op.lower
out_dtype = node.outputs[0].type.numpy_dtype
if lower:
inputs_cast = int_to_float_fn(node.inputs, out_dtype)
@numba.njit
def cholesky(a):
return np.linalg.cholesky(inputs_cast(a)).astype(out_dtype)
else:
# TODO: Use SciPy's BLAS/LAPACK Cython wrappers.
warnings.warn(
(
"Numba will use object mode to allow the "
"`lower` argument to `scipy.linalg.cholesky`."
),
UserWarning,
)
ret_sig = get_numba_type(node.outputs[0].type)
@numba.njit
def cholesky(a):
with numba.objmode(ret=ret_sig):
ret = scipy.linalg.cholesky(a, lower=lower).astype(out_dtype)
return ret
return cholesky
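The object-mode escape hatch used here (and in the fallbacks below) pins the output type with an explicit Numba signature. A minimal self-contained sketch of the pattern, using `scipy.linalg.cholesky` directly and an illustrative function name:

import numba
import numpy as np
import scipy.linalg

@numba.njit
def cholesky_upper(a):
    # Code inside `numba.objmode` runs in the regular Python interpreter;
    # `ret` must be given an explicit Numba type on re-entry.
    with numba.objmode(ret="float64[:, :]"):
        ret = scipy.linalg.cholesky(a, lower=False)
    return ret

a = np.eye(3)
assert np.allclose(cholesky_upper(a), np.linalg.cholesky(a).T)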
@numba_funcify.register(Solve)
def numba_funcify_Solve(op, node, **kwargs):
if op.A_structure == "lower_triangular" or op.A_structure == "upper_triangular":
lower = op.A_structure == "lower_triangular"
        warnings.warn(
            (
                "Numba will use object mode to execute "
                "`scipy.linalg.solve_triangular`."
            ),
            UserWarning,
        )
ret_sig = get_numba_type(node.outputs[0].type)
@numba.njit
def solve(a, b):
with numba.objmode(ret=ret_sig):
ret = scipy.linalg.solve_triangular(a, b, lower=lower)
return ret
else:
out_dtype = node.outputs[0].type.numpy_dtype
inputs_cast = int_to_float_fn(node.inputs, out_dtype)
@numba.njit
def solve(a, b):
return np.linalg.solve(inputs_cast(a), inputs_cast(b)).astype(out_dtype)
return solve
@numba_funcify.register(Det)
def numba_funcify_Det(op, node, **kwargs):
out_dtype = node.outputs[0].type.numpy_dtype
inputs_cast = int_to_float_fn(node.inputs, out_dtype)
@numba.njit
def det(x):
return direct_cast(np.linalg.det(inputs_cast(x)), out_dtype)
return det
@numba_funcify.register(Eig)
def numba_funcify_Eig(op, node, **kwargs):
out_dtype_1 = node.outputs[0].type.numpy_dtype
out_dtype_2 = node.outputs[1].type.numpy_dtype
inputs_cast = int_to_float_fn(node.inputs, out_dtype_1)
@numba.njit
def eig(x):
out = np.linalg.eig(inputs_cast(x))
return (out[0].astype(out_dtype_1), out[1].astype(out_dtype_2))
return eig
@numba_funcify.register(Eigh)
def numba_funcify_Eigh(op, node, **kwargs):
uplo = op.UPLO
if uplo != "L":
warnings.warn(
(
"Numba will use object mode to allow the "
"`UPLO` argument to `numpy.linalg.eigh`."
),
UserWarning,
)
out_dtypes = tuple(o.type.numpy_dtype for o in node.outputs)
ret_sig = numba.types.Tuple(
[get_numba_type(node.outputs[0].type), get_numba_type(node.outputs[1].type)]
)
@numba.njit
def eigh(x):
with numba.objmode(ret=ret_sig):
out = np.linalg.eigh(x, UPLO=uplo)
ret = (out[0].astype(out_dtypes[0]), out[1].astype(out_dtypes[1]))
return ret
else:
@numba.njit
def eigh(x):
return np.linalg.eigh(x)
return eigh
@numba_funcify.register(MatrixInverse)
def numba_funcify_MatrixInverse(op, node, **kwargs):
out_dtype = node.outputs[0].type.numpy_dtype
inputs_cast = int_to_float_fn(node.inputs, out_dtype)
@numba.njit
def matrix_inverse(x):
return np.linalg.inv(inputs_cast(x)).astype(out_dtype)
return matrix_inverse
@numba_funcify.register(QRFull)
def numba_funcify_QRFull(op, node, **kwargs):
mode = op.mode
if mode != "reduced":
warnings.warn(
(
"Numba will use object mode to allow the "
"`mode` argument to `numpy.linalg.qr`."
),
UserWarning,
)
if len(node.outputs) > 1:
ret_sig = numba.types.Tuple([get_numba_type(o.type) for o in node.outputs])
else:
ret_sig = get_numba_type(node.outputs[0].type)
@numba.njit
def qr_full(x):
with numba.objmode(ret=ret_sig):
ret = np.linalg.qr(x, mode=mode)
return ret
else:
out_dtype = node.outputs[0].type.numpy_dtype
inputs_cast = int_to_float_fn(node.inputs, out_dtype)
@numba.njit
def qr_full(x):
res = np.linalg.qr(inputs_cast(x))
return res
return qr_full
@numba_funcify.register(SVD)
def numba_funcify_SVD(op, node, **kwargs):
full_matrices = op.full_matrices
compute_uv = op.compute_uv
if not compute_uv:
warnings.warn(
(
"Numba will use object mode to allow the "
"`compute_uv` argument to `numpy.linalg.svd`."
),
UserWarning,
)
ret_sig = get_numba_type(node.outputs[0].type)
@numba.njit
def svd(x):
with numba.objmode(ret=ret_sig):
ret = np.linalg.svd(x, full_matrices, compute_uv)
return ret
else:
out_dtype = node.outputs[0].type.numpy_dtype
inputs_cast = int_to_float_fn(node.inputs, out_dtype)
@numba.njit
def svd(x):
return np.linalg.svd(inputs_cast(x), full_matrices)
return svd
@numba_funcify.register(BatchedDot)
def numba_funcify_BatchedDot(op, node, **kwargs):
dtype = node.outputs[0].type.numpy_dtype
@numba.njit
def batched_dot(x, y):
shape = x.shape[:-1] + y.shape[2:]
z0 = np.empty(shape, dtype=dtype)
for i in range(z0.shape[0]):
z0[i] = np.dot(x[i], y[i])
return z0
return batched_dot
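For reference, the loop is equivalent to a batched matrix product over the leading dimension; a quick NumPy check (illustrative):

import numpy as np

x = np.random.random((5, 3, 2))
y = np.random.random((5, 2, 4))

z = np.empty(x.shape[:-1] + y.shape[2:], dtype=x.dtype)
for i in range(z.shape[0]):
    z[i] = np.dot(x[i], y[i])

assert np.allclose(z, np.matmul(x, y))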
# NOTE: The remaining `aesara.tensor.blas` `Op`s appear unnecessary, because
# they're only used to optimize basic `Dot` nodes, and those GEMV and GEMM
# optimizations are apparently already performed by Numba
@@ -8,8 +8,11 @@ import pytest
import aesara.scalar as aes
import aesara.scalar.basic as aesb
import aesara.scalar.basic_scipy as aes_sci
import aesara.tensor as aet
import aesara.tensor.basic as aetb
import aesara.tensor.math as aem
import aesara.tensor.nnet.basic as nnetb
from aesara import config
from aesara.compile.function import function
from aesara.compile.mode import Mode
@@ -23,8 +26,9 @@ from aesara.graph.type import Type
from aesara.link.numba.dispatch import create_numba_signature, get_numba_type
from aesara.link.numba.linker import NumbaLinker
from aesara.scalar.basic import Composite
from aesara.tensor import blas
from aesara.tensor import elemwise as aet_elemwise
from aesara.tensor import extra_ops
from aesara.tensor import extra_ops, nlinalg, slinalg
from aesara.tensor import subtensor as aet_subtensor
from aesara.tensor.elemwise import Elemwise
from aesara.tensor.shape import Reshape, Shape, Shape_i, SpecifyShape
@@ -70,6 +74,8 @@ opts = Query(include=[None], exclude=["cxx_only", "BlasOpt"])
numba_mode = Mode(NumbaLinker(), opts)
py_mode = Mode("py", opts)
np.random.seed(42849)
def set_test_value(x, v):
    x.tag.test_value = v
    return x
@@ -1667,3 +1673,642 @@ def test_BroadcastTo(x, shape, exc):
if not isinstance(i, (SharedVariable, Constant))
],
)
@pytest.mark.parametrize(
"x, y, exc",
[
(
set_test_value(
aet.matrix(), np.random.random(size=(3, 2)).astype(config.floatX)
),
set_test_value(
aet.vector(), np.random.random(size=(2,)).astype(config.floatX)
),
None,
),
(
set_test_value(aet.lmatrix(), np.random.poisson(size=(3, 2))),
set_test_value(
aet.fvector(), np.random.random(size=(2,)).astype("float32")
),
None,
),
],
)
def test_Dot(x, y, exc):
g = aem.Dot()(x, y)
g_fg = FunctionGraph(outputs=[g])
cm = contextlib.suppress() if exc is None else pytest.warns(exc)
with cm:
compare_numba_and_py(
g_fg,
[
i.tag.test_value
for i in g_fg.inputs
if not isinstance(i, (SharedVariable, Constant))
],
)
@pytest.mark.parametrize(
"x, exc",
[
(
set_test_value(
aet.vector(), np.random.random(size=(2,)).astype(config.floatX)
),
None,
),
(
set_test_value(
aet.matrix(), np.random.random(size=(2, 3)).astype(config.floatX)
),
None,
),
],
)
def test_Softmax(x, exc):
g = nnetb.Softmax()(x)
g_fg = FunctionGraph(outputs=[g])
cm = contextlib.suppress() if exc is None else pytest.warns(exc)
with cm:
compare_numba_and_py(
g_fg,
[
i.tag.test_value
for i in g_fg.inputs
if not isinstance(i, (SharedVariable, Constant))
],
)
@pytest.mark.parametrize(
"x, exc",
[
(
set_test_value(
aet.vector(), np.random.random(size=(2,)).astype(config.floatX)
),
None,
),
(
set_test_value(
aet.matrix(), np.random.random(size=(2, 3)).astype(config.floatX)
),
None,
),
],
)
def test_LogSoftmax(x, exc):
g = nnetb.LogSoftmax()(x)
g_fg = FunctionGraph(outputs=[g])
cm = contextlib.suppress() if exc is None else pytest.warns(exc)
with cm:
compare_numba_and_py(
g_fg,
[
i.tag.test_value
for i in g_fg.inputs
if not isinstance(i, (SharedVariable, Constant))
],
)
@pytest.mark.parametrize(
"x, exc",
[
(
set_test_value(aes.float64(), np.array(0.0, dtype="float64")),
None,
),
(
set_test_value(aes.float64(), np.array(-32.0, dtype="float64")),
None,
),
(
set_test_value(aes.float64(), np.array(-40.0, dtype="float64")),
None,
),
(
set_test_value(aes.float64(), np.array(32.0, dtype="float64")),
None,
),
(
set_test_value(aes.float64(), np.array(40.0, dtype="float64")),
None,
),
(
set_test_value(aes.int64(), np.array(32, dtype="int64")),
None,
),
],
)
def test_Softplus(x, exc):
g = aes_sci.Softplus(aes.upgrade_to_float)(x)
g_fg = FunctionGraph(outputs=[g])
cm = contextlib.suppress() if exc is None else pytest.warns(exc)
with cm:
compare_numba_and_py(
g_fg,
[
i.tag.test_value
for i in g_fg.inputs
if not isinstance(i, (SharedVariable, Constant))
],
)
@pytest.mark.parametrize(
"x, axes, exc",
[
(
set_test_value(aet.dscalar(), np.array(0.0, dtype="float64")),
[],
None,
),
(
set_test_value(
aet.dvector(), np.random.random(size=(3,)).astype("float64")
),
[0],
None,
),
(
set_test_value(
aet.dmatrix(), np.random.random(size=(3, 2)).astype("float64")
),
[0],
None,
),
(
set_test_value(
aet.dmatrix(), np.random.random(size=(3, 2)).astype("float64")
),
[0, 1],
None,
),
],
)
def test_MaxAndArgmax(x, axes, exc):
g = aem.MaxAndArgmax(axes)(x)
if isinstance(g, list):
g_fg = FunctionGraph(outputs=g)
else:
g_fg = FunctionGraph(outputs=[g])
cm = contextlib.suppress() if exc is None else pytest.warns(exc)
with cm:
compare_numba_and_py(
g_fg,
[
i.tag.test_value
for i in g_fg.inputs
if not isinstance(i, (SharedVariable, Constant))
],
)
@pytest.mark.parametrize(
"x, lower, exc",
[
(
set_test_value(
aet.dmatrix(),
(lambda x: x.T.dot(x))(np.random.random(size=(3, 3)).astype("float64")),
),
True,
None,
),
(
set_test_value(
aet.lmatrix(),
(lambda x: x.T.dot(x))(
np.random.randint(1, 10, size=(3, 3)).astype("int64")
),
),
True,
None,
),
(
set_test_value(
aet.dmatrix(),
(lambda x: x.T.dot(x))(np.random.random(size=(3, 3)).astype("float64")),
),
False,
UserWarning,
),
],
)
def test_Cholesky(x, lower, exc):
g = slinalg.Cholesky(lower)(x)
if isinstance(g, list):
g_fg = FunctionGraph(outputs=g)
else:
g_fg = FunctionGraph(outputs=[g])
cm = contextlib.suppress() if exc is None else pytest.warns(exc)
with cm:
compare_numba_and_py(
g_fg,
[
i.tag.test_value
for i in g_fg.inputs
if not isinstance(i, (SharedVariable, Constant))
],
)
@pytest.mark.parametrize(
"A, x, lower, exc",
[
(
set_test_value(
aet.dmatrix(),
(lambda x: x.T.dot(x))(np.random.random(size=(3, 3)).astype("float64")),
),
set_test_value(
aet.dvector(), np.random.random(size=(3,)).astype("float64")
),
"general",
None,
),
(
set_test_value(
aet.lmatrix(),
(lambda x: x.T.dot(x))(
np.random.randint(1, 10, size=(3, 3)).astype("int64")
),
),
set_test_value(
aet.dvector(), np.random.random(size=(3,)).astype("float64")
),
"general",
None,
),
(
set_test_value(
aet.dmatrix(),
(lambda x: x.T.dot(x))(np.random.random(size=(3, 3)).astype("float64")),
),
set_test_value(
aet.dvector(), np.random.random(size=(3,)).astype("float64")
),
"lower_triangular",
UserWarning,
),
],
)
def test_Solve(A, x, A_structure, exc):
    g = slinalg.Solve(A_structure)(A, x)
if isinstance(g, list):
g_fg = FunctionGraph(outputs=g)
else:
g_fg = FunctionGraph(outputs=[g])
cm = contextlib.suppress() if exc is None else pytest.warns(exc)
with cm:
compare_numba_and_py(
g_fg,
[
i.tag.test_value
for i in g_fg.inputs
if not isinstance(i, (SharedVariable, Constant))
],
)
@pytest.mark.parametrize(
"x, exc",
[
(
set_test_value(
aet.dmatrix(),
(lambda x: x.T.dot(x))(np.random.random(size=(3, 3)).astype("float64")),
),
None,
),
(
set_test_value(
aet.lmatrix(),
(lambda x: x.T.dot(x))(np.random.poisson(size=(3, 3)).astype("int64")),
),
None,
),
],
)
def test_Det(x, exc):
g = nlinalg.Det()(x)
g_fg = FunctionGraph(outputs=[g])
cm = contextlib.suppress() if exc is None else pytest.warns(exc)
with cm:
compare_numba_and_py(
g_fg,
[
i.tag.test_value
for i in g_fg.inputs
if not isinstance(i, (SharedVariable, Constant))
],
)
@pytest.mark.parametrize(
"x, exc",
[
(
set_test_value(
aet.dmatrix(),
(lambda x: x.T.dot(x))(np.random.random(size=(3, 3)).astype("float64")),
),
None,
),
(
set_test_value(
aet.lmatrix(),
(lambda x: x.T.dot(x))(
np.random.randint(1, 10, size=(3, 3)).astype("int64")
),
),
None,
),
],
)
def test_Eig(x, exc):
g = nlinalg.Eig()(x)
if isinstance(g, list):
g_fg = FunctionGraph(outputs=g)
else:
g_fg = FunctionGraph(outputs=[g])
cm = contextlib.suppress() if exc is None else pytest.warns(exc)
with cm:
compare_numba_and_py(
g_fg,
[
i.tag.test_value
for i in g_fg.inputs
if not isinstance(i, (SharedVariable, Constant))
],
)
@pytest.mark.parametrize(
"x, uplo, exc",
[
(
set_test_value(
aet.dmatrix(),
(lambda x: x.T.dot(x))(np.random.random(size=(3, 3)).astype("float64")),
),
"L",
None,
),
(
set_test_value(
aet.lmatrix(),
(lambda x: x.T.dot(x))(
np.random.randint(1, 10, size=(3, 3)).astype("int64")
),
),
"U",
UserWarning,
),
],
)
def test_Eigh(x, uplo, exc):
g = nlinalg.Eigh(uplo)(x)
if isinstance(g, list):
g_fg = FunctionGraph(outputs=g)
else:
g_fg = FunctionGraph(outputs=[g])
cm = contextlib.suppress() if exc is None else pytest.warns(exc)
with cm:
compare_numba_and_py(
g_fg,
[
i.tag.test_value
for i in g_fg.inputs
if not isinstance(i, (SharedVariable, Constant))
],
)
@pytest.mark.parametrize(
"x, exc",
[
(
set_test_value(
aet.dmatrix(),
(lambda x: x.T.dot(x))(np.random.random(size=(3, 3)).astype("float64")),
),
None,
),
(
set_test_value(
aet.lmatrix(),
(lambda x: x.T.dot(x))(
np.random.randint(1, 10, size=(3, 3)).astype("int64")
),
),
None,
),
],
)
def test_MatrixInverse(x, exc):
g = nlinalg.MatrixInverse()(x)
g_fg = FunctionGraph(outputs=[g])
cm = contextlib.suppress() if exc is None else pytest.warns(exc)
with cm:
compare_numba_and_py(
g_fg,
[
i.tag.test_value
for i in g_fg.inputs
if not isinstance(i, (SharedVariable, Constant))
],
)
@pytest.mark.parametrize(
"x, mode, exc",
[
(
set_test_value(
aet.dmatrix(),
(lambda x: x.T.dot(x))(np.random.random(size=(3, 3)).astype("float64")),
),
"reduced",
None,
),
(
set_test_value(
aet.dmatrix(),
(lambda x: x.T.dot(x))(np.random.random(size=(3, 3)).astype("float64")),
),
"r",
None,
),
(
set_test_value(
aet.lmatrix(),
(lambda x: x.T.dot(x))(
np.random.randint(1, 10, size=(3, 3)).astype("int64")
),
),
"reduced",
None,
),
(
set_test_value(
aet.lmatrix(),
(lambda x: x.T.dot(x))(
np.random.randint(1, 10, size=(3, 3)).astype("int64")
),
),
"complete",
UserWarning,
),
],
)
def test_QRFull(x, mode, exc):
g = nlinalg.QRFull(mode)(x)
if isinstance(g, list):
g_fg = FunctionGraph(outputs=g)
else:
g_fg = FunctionGraph(outputs=[g])
cm = contextlib.suppress() if exc is None else pytest.warns(exc)
with cm:
compare_numba_and_py(
g_fg,
[
i.tag.test_value
for i in g_fg.inputs
if not isinstance(i, (SharedVariable, Constant))
],
)
@pytest.mark.parametrize(
"x, full_matrices, compute_uv, exc",
[
(
set_test_value(
aet.dmatrix(),
(lambda x: x.T.dot(x))(np.random.random(size=(3, 3)).astype("float64")),
),
True,
True,
None,
),
(
set_test_value(
aet.dmatrix(),
(lambda x: x.T.dot(x))(np.random.random(size=(3, 3)).astype("float64")),
),
False,
True,
None,
),
(
set_test_value(
aet.lmatrix(),
(lambda x: x.T.dot(x))(
np.random.randint(1, 10, size=(3, 3)).astype("int64")
),
),
True,
True,
None,
),
(
set_test_value(
aet.lmatrix(),
(lambda x: x.T.dot(x))(
np.random.randint(1, 10, size=(3, 3)).astype("int64")
),
),
True,
False,
UserWarning,
),
],
)
def test_SVD(x, full_matrices, compute_uv, exc):
g = nlinalg.SVD(full_matrices, compute_uv)(x)
if isinstance(g, list):
g_fg = FunctionGraph(outputs=g)
else:
g_fg = FunctionGraph(outputs=[g])
cm = contextlib.suppress() if exc is None else pytest.warns(exc)
with cm:
compare_numba_and_py(
g_fg,
[
i.tag.test_value
for i in g_fg.inputs
if not isinstance(i, (SharedVariable, Constant))
],
)
@pytest.mark.parametrize(
"x, y, exc",
[
(
set_test_value(
aet.dmatrix(),
np.random.random(size=(3, 3)).astype("float64"),
),
set_test_value(
aet.dmatrix(),
np.random.random(size=(3, 3)).astype("float64"),
),
None,
),
(
set_test_value(
aet.dmatrix(),
np.random.random(size=(3, 3)).astype("float64"),
),
set_test_value(
aet.lmatrix(),
np.random.poisson(size=(3, 3)).astype("int64"),
),
None,
),
],
)
def test_BatchedDot(x, y, exc):
g = blas.BatchedDot()(x, y)
if isinstance(g, list):
g_fg = FunctionGraph(outputs=g)
else:
g_fg = FunctionGraph(outputs=[g])
cm = contextlib.suppress() if exc is None else pytest.warns(exc)
with cm:
compare_numba_and_py(
g_fg,
[
i.tag.test_value
for i in g_fg.inputs
if not isinstance(i, (SharedVariable, Constant))
],
)