提交 f41bbab0 作者: Brandon T. Willard 提交者: Brandon T. Willard

Add inplace functionality to the Numba Elemwise implementation

上级 08f49d16
......@@ -373,7 +373,7 @@ def {scalar_op_fn_name}({input_names}):
signature = create_numba_signature(node, force_scalar=True)
return numba.njit(signature)(scalar_op_fn)
return numba.njit(signature, inline="always")(scalar_op_fn)
@numba_funcify.register(Switch)
......@@ -424,9 +424,13 @@ def numba_funcify_Mul(op, node, **kwargs):
return numba.njit(signature)(nary_mul_fn)
@numba_funcify.register(Elemwise)
def numba_funcify_Elemwise(op, node, use_signature=False, identity=None, **kwargs):
scalar_op_fn = numba_funcify(op.scalar_op, node, **kwargs)
def create_vectorize_func(op, node, use_signature=False, identity=None, **kwargs):
scalar_op_fn = numba_funcify(op.scalar_op, node, inline="always", **kwargs)
if len(node.outputs) > 1:
raise NotImplementedError(
"Multi-output Elemwise Ops are not supported by the Numba backend"
)
if use_signature:
signature = [create_numba_signature(node, force_scalar=True)]
......@@ -441,15 +445,44 @@ def numba_funcify_Elemwise(op, node, use_signature=False, identity=None, **kwarg
unique_names = unique_name_generator(
[elemwise_fn_name, "scalar_op", "scalar_op", "numba_vectorize"], suffix_sep="_"
)
input_names = ", ".join([unique_names(v, force_unique=True) for v in node.inputs])
input_names = [unique_names(v, force_unique=True) for v in node.inputs]
input_signature_str = ", ".join(input_names)
elemwise_src = f"""
@numba_vectorize
def {elemwise_fn_name}({input_names}):
return scalar_op({input_names})
"""
def {elemwise_fn_name}({input_signature_str}):
return scalar_op({input_signature_str})
"""
elemwise_fn = compile_function_src(elemwise_src, elemwise_fn_name, global_env)
return elemwise_fn, input_names
@numba_funcify.register(Elemwise)
def numba_funcify_Elemwise(op, node, **kwargs):
    """Return the Numba callable implementing the `Elemwise` node `node`.

    The vectorized function is obtained from `create_vectorize_func`.  When
    `op.inplace_pattern` is empty that function is returned unchanged.
    Otherwise a thin wrapper is generated from source: it takes the same
    arguments and calls the vectorized function with one extra trailing
    argument, namely the input selected by `op.inplace_pattern[0]`.
    """
    # NOTE: `kwargs` is accepted for dispatcher compatibility but is not
    # forwarded to `create_vectorize_func` here.
    vectorized_fn, arg_names = create_vectorize_func(op, node, use_signature=False)
    vectorized_name = vectorized_fn.__name__

    if not op.inplace_pattern:
        return vectorized_fn

    # Only the entry for output 0 is consulted.  The selected input is
    # appended to the call, presumably as the vectorized function's output
    # buffer (TODO confirm against the Numba `vectorize` calling convention).
    target_arg = arg_names[op.inplace_pattern[0]]

    wrapper_name = f"{vectorized_name}_inplace"
    signature_str = ", ".join(arg_names)
    wrapper_src = f"""
def {wrapper_name}({signature_str}):
    return {vectorized_name}({signature_str + ", " + target_arg})
    """

    # The generated wrapper only needs the vectorized function in its globals.
    wrapper_fn = compile_function_src(
        wrapper_src, wrapper_name, {vectorized_name: vectorized_fn}
    )
    return numba.njit(inline="always")(wrapper_fn)
......@@ -591,7 +624,11 @@ def numba_funcify_CAReduce(op, node, **kwargs):
[tensor(np_acc_dtype, [False]) for i in range(scalar_nfunc_spec[1])],
[tensor(np_acc_dtype, [False]) for o in range(scalar_nfunc_spec[2])],
)
elemwise_fn = numba_funcify_Elemwise(op, dummy_node, use_signature=True, **kwargs)
# TODO: Use `scalar_op_identity`?
elemwise_fn, *_ = create_vectorize_func(
op, dummy_node, use_signature=True, **kwargs
)
input_name = get_name_for_object(node.inputs[0])
ndim = node.inputs[0].ndim
......@@ -965,7 +1002,7 @@ def numba_funcify_DimShuffle(op, **kwargs):
# E No match.
# ...(on this line)...
# E shuffle_shape = res.shape[: len(shuffle)]
# A single jit decorator is applied: stacking a plain `@numba.njit` on top of
# `@numba.njit(inline="always")` would try to jit an already-jitted function.
@numba.njit(inline="always")
def dimshuffle(x):
    # Convert the argument with `np.asarray` before delegating to
    # `dimshuffle_inner` together with the closed-over `shuffle`.
    return dimshuffle_inner(np.asarray(x), shuffle)
......
......@@ -792,8 +792,10 @@ second dimension
if nout == 1:
variables = [variables]
i = 0
for variable, storage, nout in zip(variables, output_storage, node.outputs):
for i, (variable, storage, nout) in enumerate(
zip(variables, output_storage, node.outputs)
):
if getattr(variable, "dtype", "") == "object":
# Since numpy 1.6, function created with numpy.frompyfunc
# always return an ndarray with dtype object
......@@ -803,6 +805,7 @@ second dimension
odat = inputs[self.inplace_pattern[i]]
odat[...] = variable
storage[0] = odat
# Sometimes NumPy return a Python type.
# Some Aesara op return a different dtype like floor, ceil,
# trunc, eq, ...
......@@ -821,7 +824,6 @@ second dimension
storage[0] = variable.copy()
else:
storage[0] = variable
i += 1
def infer_shape(self, fgraph, node, i_shapes):
rval = []
......
......@@ -11,6 +11,7 @@ import aesara.scalar.basic as aesb
import aesara.scalar.math as aesm
import aesara.tensor as aet
import aesara.tensor.basic as aetb
import aesara.tensor.inplace as ati
import aesara.tensor.math as aem
import aesara.tensor.nnet.basic as nnetb
from aesara import config, shared
......@@ -60,12 +61,19 @@ class MySingleOut(Op):
class MyMultiOut(Op):
    """Test `Op` with two inputs and two outputs; each output is twice its input."""

    # Number of inputs and outputs of the op.
    nin = 2
    nout = 2

    def make_node(self, a, b):
        """Create an `Apply` node whose outputs have the same types as `a` and `b`."""
        return Apply(self, [a, b], [a.type(), b.type()])

    def impl(self, a, b):
        """Return `[2 * a, 2 * b]`."""
        res1 = 2 * a
        res2 = 2 * b
        return [res1, res2]

    def perform(self, node, inputs, outputs):
        """Store the results of `impl` in the two output storage cells."""
        # The values come from `impl` only; computing them by hand first and
        # then overwriting them was redundant.
        res1, res2 = self.impl(inputs[0], inputs[1])
        outputs[0][0] = res1
        outputs[1][0] = res2
......@@ -273,29 +281,58 @@ def test_create_numba_signature(v, expected, force_scalar):
@pytest.mark.parametrize(
    "inputs, input_vals, output_fn, exc",
    [
        (
            [aet.vector()],
            [rng.randn(100).astype(config.floatX)],
            lambda x: aet.sigmoid(x),
            None,
        ),
        (
            [aet.vector() for i in range(4)],
            [rng.randn(100).astype(config.floatX) for i in range(4)],
            lambda x, y, x1, y1: (x + y) * (x1 + y1) * y,
            None,
        ),
        (
            # This also tests the use of repeated arguments
            [aet.matrix(), aet.scalar()],
            [rng.normal(size=(2, 2)).astype(config.floatX), 0.0],
            lambda a, b: aet.switch(a, b, a),
            None,
        ),
        (
            # In-place `Elemwise`
            [aet.vector(), aet.vector()],
            [
                rng.randn(100).astype(config.floatX),
                rng.randn(100).astype(config.floatX),
            ],
            lambda x, y: ati.add_inplace(x, y),
            None,
        ),
        (
            # Multi-output `Elemwise` is expected to be rejected
            [aet.vector(), aet.vector()],
            [
                rng.randn(100).astype(config.floatX),
                rng.randn(100).astype(config.floatX),
            ],
            lambda x, y: Elemwise(MyMultiOut())(x, y),
            NotImplementedError,
        ),
    ],
)
def test_Elemwise(inputs, input_vals, output_fn, exc):
    """Compare the Numba and Python results of `Elemwise` graphs.

    Each case supplies symbolic `inputs`, concrete `input_vals`, a function
    `output_fn` building the graph output(s) from the inputs, and `exc`: the
    exception type the comparison is expected to raise, or `None` when it
    should succeed.
    """
    outputs = output_fn(*inputs)

    # `output_fn` may return a single variable or a list of variables.
    out_fg = FunctionGraph(
        outputs=[outputs] if not isinstance(outputs, list) else outputs
    )

    # `contextlib.suppress()` with no arguments acts as a no-op context manager.
    cm = contextlib.suppress() if exc is None else pytest.raises(exc)
    with cm:
        compare_numba_and_py(out_fg, input_vals)
@pytest.mark.parametrize(
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论