提交 18d1a7a1 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Implement CAReduce conversions for Numba

上级 487ce550
......@@ -40,7 +40,7 @@ from aesara.tensor.basic import (
ScalarFromTensor,
TensorFromScalar,
)
from aesara.tensor.elemwise import DimShuffle, Elemwise
from aesara.tensor.elemwise import CAReduce, DimShuffle, Elemwise
from aesara.tensor.shape import Reshape, Shape, Shape_i, SpecifyShape
from aesara.tensor.subtensor import (
AdvancedIncSubtensor,
......@@ -50,7 +50,7 @@ from aesara.tensor.subtensor import (
IncSubtensor,
Subtensor,
)
from aesara.tensor.type import TensorType
from aesara.tensor.type import TensorType, tensor
from aesara.tensor.type_other import MakeSlice
......@@ -220,12 +220,18 @@ def {scalar_op_fn_name}({input_names}):
@numba_funcify.register(Elemwise)
def numba_funcify_Elemwise(op, node, **kwargs):
def numba_funcify_Elemwise(op, node, use_signature=False, identity=None, **kwargs):
scalar_op_fn = numba_funcify(op.scalar_op, node, **kwargs)
input_names = ", ".join([v.auto_name for v in node.inputs])
global_env = {"scalar_op": scalar_op_fn, "numba_vectorize": numba.vectorize}
if use_signature:
signature = [create_numba_signature(node, force_scalar=True)]
else:
signature = []
numba_vectorize = numba.vectorize(signature, identity=identity)
global_env = {"scalar_op": scalar_op_fn, "numba_vectorize": numba_vectorize}
elemwise_fn_name = f"elemwise_{get_name_for_object(scalar_op_fn)}"
elemwise_src = f"""
......@@ -238,6 +244,95 @@ def {elemwise_fn_name}({input_names}):
return elemwise_fn
@numba_funcify.register(CAReduce)
def numba_funcify_CAReduce(op, node, **kwargs):
axes = op.axis
if axes is None:
axes = list(range(node.inputs[0].ndim))
if hasattr(op, "acc_dtype") and op.acc_dtype is not None:
acc_dtype = op.acc_dtype
else:
acc_dtype = node.outputs[0].type.dtype
np_acc_dtype = np.dtype(acc_dtype)
scalar_op_identity = np.asarray(op.scalar_op.identity, dtype=np_acc_dtype)
acc_dtype = numba.np.numpy_support.from_dtype(np_acc_dtype)
scalar_nfunc_spec = op.scalar_op.nfunc_spec
# We construct a dummy `Apply` that has the minimum required number of
# inputs for the scalar `Op`. Without this, we would get a scalar function
# with too few arguments.
dummy_node = Apply(
op,
[tensor(acc_dtype, [False]) for i in range(scalar_nfunc_spec[1])],
[tensor(acc_dtype, [False]) for o in range(scalar_nfunc_spec[2])],
)
elemwise_fn = numba_funcify_Elemwise(op, dummy_node, use_signature=True, **kwargs)
def create_careduce_axis(axis, ndim):
if ndim > 1:
res_shape_tuple_ctor = create_tuple_creator(
lambda i, shape: shape[i] if i < axis else shape[i + 1], ndim - 1
)
reaxis_first = (axis,) + tuple(i for i in range(ndim) if i != axis)
@numba.njit(boundscheck=False)
def careduce_axis(x):
res_shape = res_shape_tuple_ctor(x.shape)
x_axis_first = x.transpose(reaxis_first)
res = np.full(res_shape, scalar_op_identity.item(), dtype=acc_dtype)
for m in range(x.shape[axis]):
elemwise_fn(res, x_axis_first[m], res)
return res
else:
@numba.njit(boundscheck=False)
def careduce_axis(x):
res = scalar_op_identity.item()
for val in x:
res = elemwise_fn(res, val)
return res
return careduce_axis
careduce_fn_name = f"careduce_{get_name_for_object(elemwise_fn)}"
ndim = node.inputs[0].ndim
careduce_axes_fns = ()
to_reduce = reversed(sorted(axes))
careduce_lines_src = []
input_name = get_name_for_object(node.inputs[0])
var_name = input_name
for i, axis in enumerate(to_reduce):
careduce_axes_fns += (create_careduce_axis(axis - i, ndim),)
ndim -= 1
last_var_name = var_name
var_name = f"axis_{i}_res"
careduce_lines_src.append(
f"{var_name} = careduce_axes_fns[{i}]({last_var_name})"
)
careduce_assign_lines = indent("\n".join(careduce_lines_src), " " * 4)
careduce_def_src = f"""
def {careduce_fn_name}({input_name}):
{careduce_assign_lines}
return {var_name}
"""
global_env = {"careduce_axes_fns": careduce_axes_fns}
careduce_fn = compile_function_src(careduce_def_src, careduce_fn_name, global_env)
return numba.njit(careduce_fn)
@numba_funcify.register(Composite)
def numba_funcify_Composite(op, node, **kwargs):
numba_impl = numba.njit(numba_funcify(op.fgraph, **kwargs))
......
......@@ -864,3 +864,74 @@ def test_ARange(start, stop, step, dtype):
if not isinstance(i, (SharedVariable, Constant))
],
)
@pytest.mark.parametrize(
"careduce_fn, axis, v",
[
(aet.sum, 0, set_test_value(aet.vector(), np.arange(3, dtype=config.floatX))),
(aet.all, 0, set_test_value(aet.vector(), np.arange(3, dtype=config.floatX))),
(
aet.sum,
0,
set_test_value(
aet.matrix(), np.arange(3 * 2, dtype=config.floatX).reshape((3, 2))
),
),
(
aet.sum,
(0, 1),
set_test_value(
aet.matrix(), np.arange(3 * 2, dtype=config.floatX).reshape((3, 2))
),
),
(
aet.sum,
(1, 0),
set_test_value(
aet.matrix(), np.arange(3 * 2, dtype=config.floatX).reshape((3, 2))
),
),
(
aet.sum,
None,
set_test_value(
aet.matrix(), np.arange(3 * 2, dtype=config.floatX).reshape((3, 2))
),
),
(
aet.sum,
1,
set_test_value(
aet.matrix(), np.arange(3 * 2, dtype=config.floatX).reshape((3, 2))
),
),
(aet.prod, 0, set_test_value(aet.vector(), np.arange(3, dtype=config.floatX))),
(
aet.prod,
0,
set_test_value(
aet.matrix(), np.arange(3 * 2, dtype=config.floatX).reshape((3, 2))
),
),
(
aet.prod,
1,
set_test_value(
aet.matrix(), np.arange(3 * 2, dtype=config.floatX).reshape((3, 2))
),
),
],
)
def test_CAReduce(careduce_fn, axis, v):
g = careduce_fn(v, axis=axis)
g_fg = FunctionGraph(outputs=[g])
compare_numba_and_py(
g_fg,
[
i.tag.test_value
for i in g_fg.inputs
if not isinstance(i, (SharedVariable, Constant))
],
)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论