提交 27f1ede5 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Add Numba conversions for basic dimension manipulation Ops

上级 dcea4e0d
import operator
from functools import reduce, singledispatch from functools import reduce, singledispatch
from textwrap import indent
import numba import numba
import numpy as np import numpy as np
...@@ -6,14 +8,24 @@ import scipy ...@@ -6,14 +8,24 @@ import scipy
import scipy.special import scipy.special
from llvmlite.llvmpy.core import Type as llvm_Type from llvmlite.llvmpy.core import Type as llvm_Type
from numba import types from numba import types
from numba.core.errors import TypingError
from numba.cpython.unsafe.tuple import tuple_setitem
from numba.extending import box from numba.extending import box
from aesara.compile.ops import DeepCopyOp from aesara.compile.ops import DeepCopyOp, ViewOp
from aesara.graph.fg import FunctionGraph from aesara.graph.fg import FunctionGraph
from aesara.graph.type import Type from aesara.graph.type import Type
from aesara.link.utils import compile_function_src, fgraph_to_python from aesara.link.utils import compile_function_src, fgraph_to_python
from aesara.scalar.basic import Composite, ScalarOp from aesara.scalar.basic import Cast, Composite, Identity, ScalarOp, Second
from aesara.tensor.elemwise import Elemwise from aesara.tensor.basic import (
Alloc,
AllocEmpty,
Rebroadcast,
ScalarFromTensor,
TensorFromScalar,
)
from aesara.tensor.elemwise import DimShuffle, Elemwise
from aesara.tensor.shape import Reshape, Shape, Shape_i, SpecifyShape
from aesara.tensor.subtensor import ( from aesara.tensor.subtensor import (
AdvancedIncSubtensor, AdvancedIncSubtensor,
AdvancedIncSubtensor1, AdvancedIncSubtensor1,
...@@ -51,6 +63,49 @@ def box_slice(typ, val, c): ...@@ -51,6 +63,49 @@ def box_slice(typ, val, c):
return slice_val return slice_val
@numba.generated_jit(nopython=True)
def to_scalar(x):
if isinstance(x, (numba.types.Number, numba.types.Boolean)):
return lambda x: x
elif isinstance(x, numba.types.Array):
return lambda x: x.item()
else:
raise TypingError(f"{x} must be a scalar compatible type.")
def create_tuple_creator(f, n):
"""Construct a compile-time ``tuple``-comprehension-like loop.
See https://github.com/numba/numba/issues/2771#issuecomment-414358902
"""
assert n > 0
f = numba.njit(f)
@numba.njit
def creator(args):
return (f(0, *args),)
for i in range(1, n):
@numba.njit
def creator(args, creator=creator, i=i):
return creator(args) + (f(i, *args),)
return numba.njit(lambda *args: creator(args))
def create_tuple_string(x):
args = ", ".join(x + ([""] if len(x) == 1 else []))
return f"({args})"
@numba.extending.overload(operator.contains)
def in_seq_empty_tuple(x, y):
if isinstance(x, types.Tuple) and not x.types:
return lambda x, y: False
@singledispatch @singledispatch
def numba_typify(data, dtype=None, **kwargs): def numba_typify(data, dtype=None, **kwargs):
return data return data
...@@ -66,7 +121,7 @@ def numba_funcify(op, node=None, storage_map=None, **kwargs): ...@@ -66,7 +121,7 @@ def numba_funcify(op, node=None, storage_map=None, **kwargs):
def numba_funcify_FunctionGraph( def numba_funcify_FunctionGraph(
fgraph, fgraph,
node=None, node=None,
fgraph_name="jax_funcified_fgraph", fgraph_name="numba_funcified_fgraph",
**kwargs, **kwargs,
): ):
return fgraph_to_python( return fgraph_to_python(
...@@ -116,11 +171,11 @@ def numba_funcify_Elemwise(op, node, **kwargs): ...@@ -116,11 +171,11 @@ def numba_funcify_Elemwise(op, node, **kwargs):
input_names = ", ".join([v.auto_name for v in node.inputs]) input_names = ", ".join([v.auto_name for v in node.inputs])
global_env = {"scalar_op": scalar_op_fn, "vectorize": numba.vectorize} global_env = {"scalar_op": scalar_op_fn, "numba_vectorize": numba.vectorize}
elemwise_fn_name = f"elemwise_{scalar_op_fn.__name__}" elemwise_fn_name = f"elemwise_{scalar_op_fn.__name__}"
elemwise_src = f""" elemwise_src = f"""
@vectorize @numba_vectorize
def {elemwise_fn_name}({input_names}): def {elemwise_fn_name}({input_names}):
return scalar_op({input_names}) return scalar_op({input_names})
""" """
...@@ -130,7 +185,7 @@ def {elemwise_fn_name}({input_names}): ...@@ -130,7 +185,7 @@ def {elemwise_fn_name}({input_names}):
@numba_funcify.register(Composite) @numba_funcify.register(Composite)
def numba_funcify_Composite(op, vectorize=True, **kwargs): def numba_funcify_Composite(op, node, **kwargs):
numba_impl = numba.njit(numba_funcify(op.fgraph, **kwargs)) numba_impl = numba.njit(numba_funcify(op.fgraph, **kwargs))
@numba.njit @numba.njit
...@@ -284,13 +339,253 @@ def numba_funcify_DeepCopyOp(op, node, **kwargs): ...@@ -284,13 +339,253 @@ def numba_funcify_DeepCopyOp(op, node, **kwargs):
@numba_funcify.register(MakeSlice) @numba_funcify.register(MakeSlice)
def numba_funcify_MakeSlice(op, **kwargs): def numba_funcify_MakeSlice(op, **kwargs):
"""
XXX: This requires a ``slice`` boxing implementation to work with Numba's
object mode.
"""
@numba.njit @numba.njit
def makeslice(*x): def makeslice(*x):
return slice(*x) return slice(*x)
return makeslice return makeslice
@numba_funcify.register(Shape)
def numba_funcify_Shape(op, **kwargs):
@numba.njit
def shape(x):
return np.asarray(np.shape(x))
return shape
@numba_funcify.register(Shape_i)
def numba_funcify_Shape_i(op, **kwargs):
i = op.i
@numba.njit
def shape_i(x):
return np.shape(x)[i]
return shape_i
@numba_funcify.register(TensorFromScalar)
def numba_funcify_TensorFromScalar(op, **kwargs):
@numba.njit
def tensor_from_scalar(x):
return np.array(x)
return tensor_from_scalar
@numba_funcify.register(ScalarFromTensor)
def numba_funcify_ScalarFromTensor(op, **kwargs):
@numba.njit
def scalar_from_tensor(x):
return x.item()
return scalar_from_tensor
@numba_funcify.register(AllocEmpty)
def numba_funcify_AllocEmpty(op, node, **kwargs):
global_env = {"np": np, "to_scalar": to_scalar, "dtype": op.dtype}
shape_var_names = [v.auto_name for v in node.inputs]
shape_var_item_names = [f"{name}_item" for name in shape_var_names]
shapes_to_items_src = indent(
"\n".join(
[
f"{item_name} = to_scalar({shape_name})"
for item_name, shape_name in zip(shape_var_item_names, shape_var_names)
]
),
" " * 4,
)
alloc_def_src = f"""
def allocempty({", ".join(shape_var_names)}):
{shapes_to_items_src}
scalar_shape = {create_tuple_string(shape_var_item_names)}
return np.empty(scalar_shape, dtype)
"""
alloc_fn = compile_function_src(alloc_def_src, "allocempty", global_env)
return numba.njit(alloc_fn)
@numba_funcify.register(Alloc)
def numba_funcify_Alloc(op, node, **kwargs):
global_env = {"np": np, "to_scalar": to_scalar}
shape_var_names = [v.auto_name for v in node.inputs[1:]]
shape_var_item_names = [f"{name}_item" for name in shape_var_names]
shapes_to_items_src = indent(
"\n".join(
[
f"{item_name} = to_scalar({shape_name})"
for item_name, shape_name in zip(shape_var_item_names, shape_var_names)
]
),
" " * 4,
)
alloc_def_src = f"""
def alloc(val, {", ".join(shape_var_names)}):
val_np = np.asarray(val)
{shapes_to_items_src}
scalar_shape = {create_tuple_string(shape_var_item_names)}
res = np.empty(scalar_shape, dtype=val_np.dtype)
res[...] = val_np
return res
"""
alloc_fn = compile_function_src(alloc_def_src, "alloc", global_env)
return numba.njit(alloc_fn)
@numba_funcify.register(Second)
def numba_funcify_Second(op, node, **kwargs):
@numba.njit
def second(x, y):
return y
return second
@numba_funcify.register(DimShuffle)
def numba_funcify_DimShuffle(op, **kwargs):
shuffle = tuple(op.shuffle)
drop = tuple(op.drop)
augment = tuple(op.augment)
inplace = op.inplace
ndim_new_shape = len(shuffle) + len(augment)
create_zeros_tuple = create_tuple_creator(lambda _: 0, ndim_new_shape)
if len(shuffle) > 0:
@numba.njit
def populate_new_shape(i, j, new_shape, shuffle_shape):
if i in augment:
new_shape = tuple_setitem(new_shape, i, 1)
return j, new_shape
else:
new_shape = tuple_setitem(new_shape, i, shuffle_shape[j])
return j + 1, new_shape
else:
# When `len(shuffle) == 0`, the `shuffle_shape[j]` expression above is
# is typed as `getitem(Tuple(), int)`, which has no implementation
# (since getting an item from an empty sequence doesn't make sense).
# To avoid this compile-time error, we omit the expression altogether.
@numba.njit
def populate_new_shape(i, j, new_shape, shuffle_shape):
new_shape = tuple_setitem(new_shape, i, 1)
return j, new_shape
@numba.njit
def dimshuffle_inner(x, shuffle):
res = np.transpose(x, shuffle + drop)
shuffle_shape = res.shape[: len(shuffle)]
new_shape = create_zeros_tuple()
j = 0
for i in range(len(new_shape)):
j, new_shape = populate_new_shape(i, j, new_shape, shuffle_shape)
# FIXME: Numba's `array.reshape` only accepts C arrays.
res_reshape = np.reshape(np.ascontiguousarray(res), new_shape)
if not inplace:
return res_reshape.copy()
else:
return res_reshape
# Without the following wrapper function we would see this error:
# E No implementation of function Function(<built-in function getitem>) found for signature:
# E
# E >>> getitem(UniTuple(int64 x 2), slice<a:b>)
# E
# E There are 22 candidate implementations:
# E - Of which 22 did not match due to:
# E Overload of function 'getitem': File: <numerous>: Line N/A.
# E With argument(s): '(UniTuple(int64 x 2), slice<a:b>)':
# E No match.
# ...(on this line)...
# E shuffle_shape = res.shape[: len(shuffle)]
@numba.njit
def dimshuffle(x):
return dimshuffle_inner(x, shuffle)
return dimshuffle
@numba_funcify.register(Rebroadcast)
def numba_funcify_Rebroadcast(op, **kwargs):
op_axis = tuple(op.axis.items())
@numba.njit
def rebroadcast(x):
for axis, value in numba.literal_unroll(op_axis):
if value and x.shape[axis] != 1:
raise ValueError(
("Dimension in Rebroadcast's input was supposed to be 1")
)
return x
return rebroadcast
@numba_funcify.register(Cast)
def numba_funcify_Cast(op, **kwargs):
dtype = op.o_type.dtype
@numba.njit
def cast(x):
return np.array(x, dtype=dtype)
return cast
@numba_funcify.register(Reshape)
def numba_funcify_Reshape(op, **kwargs):
ndim = op.ndim
# TODO: It might be possible/better to use
# `numba.np.unsafe.ndarray.to_fixed_tuple` here instead
create_zeros_tuple = create_tuple_creator(lambda _: 0, ndim)
@numba.njit
def reshape(x, shape):
new_shape = create_zeros_tuple()
for i in numba.literal_unroll(range(ndim)):
new_shape = tuple_setitem(new_shape, i, shape[i])
return np.reshape(x, new_shape)
return reshape
@numba_funcify.register(SpecifyShape)
def numba_funcify_SpecifyShape(op, **kwargs):
@numba.njit
def specifyshape(x, shape):
assert np.array_equal(x.shape, shape)
return x
return specifyshape
@numba_funcify.register(Identity)
@numba_funcify.register(ViewOp)
def numba_funcify_ViewOp(op, **kwargs):
@numba.njit
def viewop(x):
return x
return viewop
from functools import partial import contextlib
from unittest import mock from unittest import mock
import numpy as np import numpy as np
import pytest import pytest
import aesara.scalar as aes import aesara.scalar as aes
import aesara.scalar.basic as aesb
import aesara.tensor as aet import aesara.tensor as aet
import aesara.tensor.basic as aetb
from aesara import config from aesara import config
from aesara.compile.function import function from aesara.compile.function import function
from aesara.compile.mode import Mode from aesara.compile.mode import Mode
from aesara.compile.ops import ViewOp
from aesara.compile.sharedvalue import SharedVariable from aesara.compile.sharedvalue import SharedVariable
from aesara.graph.basic import Constant
from aesara.graph.fg import FunctionGraph from aesara.graph.fg import FunctionGraph
from aesara.graph.optdb import Query from aesara.graph.optdb import Query
from aesara.link.numba.linker import NumbaLinker from aesara.link.numba.linker import NumbaLinker
from aesara.scalar.basic import Composite from aesara.scalar.basic import Composite
from aesara.tensor import elemwise as aet_elemwise
from aesara.tensor import subtensor as aet_subtensor from aesara.tensor import subtensor as aet_subtensor
from aesara.tensor.elemwise import Elemwise from aesara.tensor.elemwise import Elemwise
from aesara.tensor.type import scalar from aesara.tensor.shape import Reshape, Shape, Shape_i, SpecifyShape
opts = Query(include=[None], exclude=["cxx_only", "BlasOpt"]) opts = Query(include=[None], exclude=["cxx_only", "BlasOpt"])
...@@ -24,6 +29,17 @@ numba_mode = Mode(NumbaLinker(), opts) ...@@ -24,6 +29,17 @@ numba_mode = Mode(NumbaLinker(), opts)
py_mode = Mode("py", opts) py_mode = Mode("py", opts)
def set_test_value(x, v):
x.tag.test_value = v
return x
def compare_shape_dtype(x, y):
(x,) = x
(y,) = y
return x.shape == y.shape and x.dtype == y.dtype
def compare_numba_and_py( def compare_numba_and_py(
fgraph, fgraph,
inputs, inputs,
...@@ -47,9 +63,19 @@ def compare_numba_and_py( ...@@ -47,9 +63,19 @@ def compare_numba_and_py(
""" """
if assert_fn is None: if assert_fn is None:
assert_fn = partial(np.testing.assert_allclose, rtol=1e-4)
def assert_fn(x, y):
return np.testing.assert_allclose(x, y, rtol=1e-4) and compare_shape_dtype(
x, y
)
fn_inputs = [i for i in fgraph.inputs if not isinstance(i, SharedVariable)] fn_inputs = [i for i in fgraph.inputs if not isinstance(i, SharedVariable)]
aesara_py_fn = function(
fn_inputs, fgraph.outputs, mode=py_mode, accept_inplace=True
)
py_res = aesara_py_fn(*inputs)
aesara_numba_fn = function( aesara_numba_fn = function(
fn_inputs, fn_inputs,
fgraph.outputs, fgraph.outputs,
...@@ -76,11 +102,6 @@ def compare_numba_and_py( ...@@ -76,11 +102,6 @@ def compare_numba_and_py(
) )
_ = aesara_numba_fn(*inputs) _ = aesara_numba_fn(*inputs)
aesara_py_fn = function(
fn_inputs, fgraph.outputs, mode=py_mode, accept_inplace=True
)
py_res = aesara_py_fn(*inputs)
if len(fgraph.outputs) > 1: if len(fgraph.outputs) > 1:
for j, p in zip(numba_res, py_res): for j, p in zip(numba_res, py_res):
assert_fn(j, p) assert_fn(j, p)
...@@ -114,7 +135,7 @@ def test_Elemwise(inputs, input_vals, output_fn): ...@@ -114,7 +135,7 @@ def test_Elemwise(inputs, input_vals, output_fn):
"inputs, input_values", "inputs, input_values",
[ [
( (
[scalar("x"), scalar("y")], [aet.scalar("x"), aet.scalar("y")],
[np.array(10).astype(config.floatX), np.array(20).astype(config.floatX)], [np.array(10).astype(config.floatX), np.array(20).astype(config.floatX)],
), ),
], ],
...@@ -296,3 +317,319 @@ def test_AdvancedIncSubtensor(x, y, indices): ...@@ -296,3 +317,319 @@ def test_AdvancedIncSubtensor(x, y, indices):
assert isinstance(out_aet.owner.op, aet_subtensor.AdvancedIncSubtensor) assert isinstance(out_aet.owner.op, aet_subtensor.AdvancedIncSubtensor)
out_fg = FunctionGraph([x_at], [out_aet]) out_fg = FunctionGraph([x_at], [out_aet])
compare_numba_and_py(out_fg, [x.data]) compare_numba_and_py(out_fg, [x.data])
@pytest.mark.parametrize(
"x, i",
[
(np.zeros((20, 3)), 1),
],
)
def test_Shape(x, i):
g = Shape()(aet.as_tensor_variable(x))
g_fg = FunctionGraph([], [g])
compare_numba_and_py(g_fg, [])
g = Shape_i(i)(aet.as_tensor_variable(x))
g_fg = FunctionGraph([], [g])
compare_numba_and_py(g_fg, [])
@pytest.mark.parametrize(
"v, shape",
[
(0.0, (2, 3)),
(1.1, (2, 3)),
(set_test_value(aet.scalar("a"), np.array(10.0, dtype=config.floatX)), (20,)),
(set_test_value(aet.vector("a"), np.ones(10, dtype=config.floatX)), (20, 10)),
],
)
def test_Alloc(v, shape):
g = aet.alloc(v, *shape)
g_fg = FunctionGraph(outputs=[g])
(numba_res,) = compare_numba_and_py(
g_fg,
[
i.tag.test_value
for i in g_fg.inputs
if not isinstance(i, (SharedVariable, Constant))
],
)
assert numba_res.shape == shape
def test_AllocEmpty():
x = aet.empty((2, 3), dtype="float32")
x_fg = FunctionGraph([], [x])
# We need cannot compare the values in the arrays, only the shapes and
# dtypes
compare_numba_and_py(x_fg, [], assert_fn=compare_shape_dtype)
@pytest.mark.parametrize(
"v, new_order, inplace",
[
# `{'drop': [], 'shuffle': [], 'augment': [0, 1]}`
(
set_test_value(
aet.lscalar(name="a"),
np.array(1, dtype=np.int64),
),
("x", "x"),
True,
),
# I.e. `a_aet.T`
# `{'drop': [], 'shuffle': [1, 0], 'augment': []}`
(
set_test_value(
aet.matrix("a"), np.array([[1.0, 2.0], [3.0, 4.0]], dtype=config.floatX)
),
(1, 0),
True,
),
# `{'drop': [], 'shuffle': [0, 1], 'augment': [2]}`
(
set_test_value(
aet.matrix("a"), np.array([[1.0, 2.0], [3.0, 4.0]], dtype=config.floatX)
),
(1, 0, "x"),
True,
),
# `{'drop': [1], 'shuffle': [2, 0], 'augment': [0, 2, 4]}`
(
set_test_value(
aet.tensor(config.floatX, [False, True, False], name="a"),
np.array([[[1.0, 2.0]], [[3.0, 4.0]]], dtype=config.floatX),
),
("x", 2, "x", 0, "x"),
True,
),
# I.e. `a_aet.dimshuffle((0,))`
# `{'drop': [1], 'shuffle': [0], 'augment': []}`
(
set_test_value(
aet.tensor(config.floatX, [False, True], name="a"),
np.array([[1.0], [2.0], [3.0], [4.0]], dtype=config.floatX),
),
(0,),
True,
),
(
set_test_value(
aet.tensor(config.floatX, [False, True], name="a"),
np.array([[1.0], [2.0], [3.0], [4.0]], dtype=config.floatX),
),
(0,),
True,
),
],
)
def test_Dimshuffle(v, new_order, inplace):
g = aet_elemwise.DimShuffle(v.broadcastable, new_order, inplace=inplace)(v)
g_fg = FunctionGraph(outputs=[g])
compare_numba_and_py(
g_fg,
[
i.tag.test_value
for i in g_fg.inputs
if not isinstance(i, (SharedVariable, Constant))
],
)
@pytest.mark.parametrize(
"v", [set_test_value(aes.float64(), np.array(1.0, dtype="float64"))]
)
def test_TensorFromScalar(v):
g = aetb.TensorFromScalar()(v)
g_fg = FunctionGraph(outputs=[g])
compare_numba_and_py(
g_fg,
[
i.tag.test_value
for i in g_fg.inputs
if not isinstance(i, (SharedVariable, Constant))
],
)
@pytest.mark.parametrize(
"v",
[
set_test_value(aet.scalar(), np.array(1.0, dtype=config.floatX)),
],
)
def test_ScalarFromTensor(v):
g = aetb.ScalarFromTensor()(v)
g_fg = FunctionGraph(outputs=[g])
compare_numba_and_py(
g_fg,
[
i.tag.test_value
for i in g_fg.inputs
if not isinstance(i, (SharedVariable, Constant))
],
)
@pytest.mark.parametrize(
"v, axis, fails",
[
(
set_test_value(aet.matrix(), np.array([[1.0]], dtype=config.floatX)),
[(0, True), (1, True)],
False,
),
(
set_test_value(aet.matrix(), np.array([[1.0, 2.0]], dtype=config.floatX)),
[(0, True), (1, False)],
False,
),
(
set_test_value(aet.matrix(), np.array([[1.0, 2.0]], dtype=config.floatX)),
[(0, True), (1, True)],
True,
),
],
)
def test_Rebroadcast(v, axis, fails):
g = aetb.Rebroadcast(*axis)(v)
g_fg = FunctionGraph(outputs=[g])
cm = contextlib.suppress() if not fails else pytest.raises(ValueError)
with cm:
compare_numba_and_py(
g_fg,
[
i.tag.test_value
for i in g_fg.inputs
if not isinstance(i, (SharedVariable, Constant))
],
)
@pytest.mark.parametrize(
"v, dtype",
[
(set_test_value(aet.fscalar(), np.array(1.0, dtype="float32")), aesb.float64),
(set_test_value(aet.dscalar(), np.array(1.0, dtype="float64")), aesb.float32),
],
)
def test_Cast(v, dtype):
g = aesb.Cast(dtype)(v)
g_fg = FunctionGraph(outputs=[g])
compare_numba_and_py(
g_fg,
[
i.tag.test_value
for i in g_fg.inputs
if not isinstance(i, (SharedVariable, Constant))
],
)
@pytest.mark.parametrize(
"v, shape, ndim",
[
(set_test_value(aet.vector(), np.arange(4, dtype=config.floatX)), (2, 2), 2),
(
set_test_value(aet.vector(), np.arange(4, dtype=config.floatX)),
set_test_value(aet.lvector(), np.array([2, 2], dtype="int64")),
2,
),
],
)
def test_Reshape(v, shape, ndim):
g = Reshape(ndim)(v, shape)
g_fg = FunctionGraph(outputs=[g])
compare_numba_and_py(
g_fg,
[
i.tag.test_value
for i in g_fg.inputs
if not isinstance(i, (SharedVariable, Constant))
],
)
@pytest.mark.parametrize(
"v, shape, fails",
[
(
set_test_value(aet.matrix(), np.array([[1.0]], dtype=config.floatX)),
(1, 1),
False,
),
(
set_test_value(aet.matrix(), np.array([[1.0, 2.0]], dtype=config.floatX)),
(1, 1),
True,
),
],
)
def test_SpecifyShape(v, shape, fails):
g = SpecifyShape()(v, shape)
g_fg = FunctionGraph(outputs=[g])
cm = contextlib.suppress() if not fails else pytest.raises(AssertionError)
with cm:
compare_numba_and_py(
g_fg,
[
i.tag.test_value
for i in g_fg.inputs
if not isinstance(i, (SharedVariable, Constant))
],
)
@pytest.mark.parametrize(
"v",
[
set_test_value(aet.vector(), np.arange(4, dtype=config.floatX)),
],
)
def test_ViewOp(v):
g = ViewOp()(v)
g_fg = FunctionGraph(outputs=[g])
compare_numba_and_py(
g_fg,
[
i.tag.test_value
for i in g_fg.inputs
if not isinstance(i, (SharedVariable, Constant))
],
)
@pytest.mark.parametrize(
"x, y",
[
(
set_test_value(aet.lvector(), np.arange(4, dtype="int64")),
set_test_value(aet.dvector(), np.arange(4, dtype="float64")),
),
(
set_test_value(
aet.dmatrix(), np.arange(4, dtype="float64").reshape((2, 2))
),
set_test_value(aet.lscalar(), np.array(4, dtype="int64")),
),
],
)
def test_Second(x, y):
# We use the `Elemwise`-wrapped version of `Second`
g = aet.second(x, y)
g_fg = FunctionGraph(outputs=[g])
compare_numba_and_py(
g_fg,
[
i.tag.test_value
for i in g_fg.inputs
if not isinstance(i, (SharedVariable, Constant))
],
)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论