Commit f41bbab0, authored by Brandon T. Willard (committed by Brandon T. Willard)

Add inplace functionality to the Numba Elemwise implementation

Parent: 08f49d16
...@@ -373,7 +373,7 @@ def {scalar_op_fn_name}({input_names}): ...@@ -373,7 +373,7 @@ def {scalar_op_fn_name}({input_names}):
signature = create_numba_signature(node, force_scalar=True) signature = create_numba_signature(node, force_scalar=True)
return numba.njit(signature)(scalar_op_fn) return numba.njit(signature, inline="always")(scalar_op_fn)
@numba_funcify.register(Switch) @numba_funcify.register(Switch)
...@@ -424,9 +424,13 @@ def numba_funcify_Mul(op, node, **kwargs): ...@@ -424,9 +424,13 @@ def numba_funcify_Mul(op, node, **kwargs):
return numba.njit(signature)(nary_mul_fn) return numba.njit(signature)(nary_mul_fn)
@numba_funcify.register(Elemwise) def create_vectorize_func(op, node, use_signature=False, identity=None, **kwargs):
def numba_funcify_Elemwise(op, node, use_signature=False, identity=None, **kwargs): scalar_op_fn = numba_funcify(op.scalar_op, node, inline="always", **kwargs)
scalar_op_fn = numba_funcify(op.scalar_op, node, **kwargs)
if len(node.outputs) > 1:
raise NotImplementedError(
"Multi-output Elemwise Ops are not supported by the Numba backend"
)
if use_signature: if use_signature:
signature = [create_numba_signature(node, force_scalar=True)] signature = [create_numba_signature(node, force_scalar=True)]
...@@ -441,15 +445,44 @@ def numba_funcify_Elemwise(op, node, use_signature=False, identity=None, **kwarg ...@@ -441,15 +445,44 @@ def numba_funcify_Elemwise(op, node, use_signature=False, identity=None, **kwarg
unique_names = unique_name_generator( unique_names = unique_name_generator(
[elemwise_fn_name, "scalar_op", "scalar_op", "numba_vectorize"], suffix_sep="_" [elemwise_fn_name, "scalar_op", "scalar_op", "numba_vectorize"], suffix_sep="_"
) )
input_names = ", ".join([unique_names(v, force_unique=True) for v in node.inputs]) input_names = [unique_names(v, force_unique=True) for v in node.inputs]
input_signature_str = ", ".join(input_names)
elemwise_src = f""" elemwise_src = f"""
@numba_vectorize @numba_vectorize
def {elemwise_fn_name}({input_names}): def {elemwise_fn_name}({input_signature_str}):
return scalar_op({input_names}) return scalar_op({input_signature_str})
""" """
elemwise_fn = compile_function_src(elemwise_src, elemwise_fn_name, global_env) elemwise_fn = compile_function_src(elemwise_src, elemwise_fn_name, global_env)
return elemwise_fn, input_names
@numba_funcify.register(Elemwise)
def numba_funcify_Elemwise(op, node, **kwargs):
    """Generate a Numba implementation for an `Elemwise` ``Op``.

    Builds a vectorized function for ``op.scalar_op`` and, when the ``Op``
    carries an ``inplace_pattern``, wraps it so that the designated input
    array is reused as the output buffer.
    """
    # Fix: forward **kwargs so callers' options reach `create_vectorize_func`
    # (the CAReduce path already forwards them; previously they were dropped).
    elemwise_fn, input_names = create_vectorize_func(
        op, node, use_signature=False, **kwargs
    )
    elemwise_fn_name = elemwise_fn.__name__

    if op.inplace_pattern:
        # Multi-output `Elemwise` is rejected in `create_vectorize_func`, so
        # the pattern maps output 0 to the index of the input it overwrites.
        input_idx = op.inplace_pattern[0]
        updated_input_name = input_names[input_idx]

        inplace_global_env = {elemwise_fn_name: elemwise_fn}
        inplace_elemwise_fn_name = f"{elemwise_fn_name}_inplace"
        input_signature_str = ", ".join(input_names)

        # The updated input is passed a second time as the trailing argument;
        # presumably the vectorized function treats it as the explicit output
        # buffer (NumPy ufunc ``out`` convention) — TODO confirm against the
        # Numba ``@vectorize`` calling semantics.
        inplace_elemwise_src = f"""
def {inplace_elemwise_fn_name}({input_signature_str}):
    return {elemwise_fn_name}({input_signature_str + ", " + updated_input_name})
"""
        inplace_elemwise_fn = compile_function_src(
            inplace_elemwise_src, inplace_elemwise_fn_name, inplace_global_env
        )

        return numba.njit(inline="always")(inplace_elemwise_fn)

    return elemwise_fn
...@@ -591,7 +624,11 @@ def numba_funcify_CAReduce(op, node, **kwargs): ...@@ -591,7 +624,11 @@ def numba_funcify_CAReduce(op, node, **kwargs):
[tensor(np_acc_dtype, [False]) for i in range(scalar_nfunc_spec[1])], [tensor(np_acc_dtype, [False]) for i in range(scalar_nfunc_spec[1])],
[tensor(np_acc_dtype, [False]) for o in range(scalar_nfunc_spec[2])], [tensor(np_acc_dtype, [False]) for o in range(scalar_nfunc_spec[2])],
) )
elemwise_fn = numba_funcify_Elemwise(op, dummy_node, use_signature=True, **kwargs)
# TODO: Use `scalar_op_identity`?
elemwise_fn, *_ = create_vectorize_func(
op, dummy_node, use_signature=True, **kwargs
)
input_name = get_name_for_object(node.inputs[0]) input_name = get_name_for_object(node.inputs[0])
ndim = node.inputs[0].ndim ndim = node.inputs[0].ndim
...@@ -965,7 +1002,7 @@ def numba_funcify_DimShuffle(op, **kwargs): ...@@ -965,7 +1002,7 @@ def numba_funcify_DimShuffle(op, **kwargs):
# E No match. # E No match.
# ...(on this line)... # ...(on this line)...
# E shuffle_shape = res.shape[: len(shuffle)] # E shuffle_shape = res.shape[: len(shuffle)]
@numba.njit @numba.njit(inline="always")
def dimshuffle(x): def dimshuffle(x):
return dimshuffle_inner(np.asarray(x), shuffle) return dimshuffle_inner(np.asarray(x), shuffle)
......
...@@ -792,8 +792,10 @@ second dimension ...@@ -792,8 +792,10 @@ second dimension
if nout == 1: if nout == 1:
variables = [variables] variables = [variables]
i = 0
for variable, storage, nout in zip(variables, output_storage, node.outputs): for i, (variable, storage, nout) in enumerate(
zip(variables, output_storage, node.outputs)
):
if getattr(variable, "dtype", "") == "object": if getattr(variable, "dtype", "") == "object":
# Since numpy 1.6, function created with numpy.frompyfunc # Since numpy 1.6, function created with numpy.frompyfunc
# always return an ndarray with dtype object # always return an ndarray with dtype object
...@@ -803,6 +805,7 @@ second dimension ...@@ -803,6 +805,7 @@ second dimension
odat = inputs[self.inplace_pattern[i]] odat = inputs[self.inplace_pattern[i]]
odat[...] = variable odat[...] = variable
storage[0] = odat storage[0] = odat
# Sometimes NumPy return a Python type. # Sometimes NumPy return a Python type.
# Some Aesara op return a different dtype like floor, ceil, # Some Aesara op return a different dtype like floor, ceil,
# trunc, eq, ... # trunc, eq, ...
...@@ -821,7 +824,6 @@ second dimension ...@@ -821,7 +824,6 @@ second dimension
storage[0] = variable.copy() storage[0] = variable.copy()
else: else:
storage[0] = variable storage[0] = variable
i += 1
def infer_shape(self, fgraph, node, i_shapes): def infer_shape(self, fgraph, node, i_shapes):
rval = [] rval = []
......
...@@ -11,6 +11,7 @@ import aesara.scalar.basic as aesb ...@@ -11,6 +11,7 @@ import aesara.scalar.basic as aesb
import aesara.scalar.math as aesm import aesara.scalar.math as aesm
import aesara.tensor as aet import aesara.tensor as aet
import aesara.tensor.basic as aetb import aesara.tensor.basic as aetb
import aesara.tensor.inplace as ati
import aesara.tensor.math as aem import aesara.tensor.math as aem
import aesara.tensor.nnet.basic as nnetb import aesara.tensor.nnet.basic as nnetb
from aesara import config, shared from aesara import config, shared
...@@ -60,12 +61,19 @@ class MySingleOut(Op): ...@@ -60,12 +61,19 @@ class MySingleOut(Op):
class MyMultiOut(Op):
    """A dummy two-input/two-output ``Op`` used to exercise multi-output paths."""

    nin = 2
    nout = 2

    def make_node(self, a, b):
        # Each output mirrors the type of the corresponding input.
        return Apply(self, [a, b], [a.type(), b.type()])

    def impl(self, a, b):
        # Reference implementation: double each input independently.
        return [2 * a, 2 * b]

    def perform(self, node, inputs, outputs):
        first_res, second_res = self.impl(*inputs)
        outputs[0][0] = first_res
        outputs[1][0] = second_res
...@@ -273,29 +281,58 @@ def test_create_numba_signature(v, expected, force_scalar): ...@@ -273,29 +281,58 @@ def test_create_numba_signature(v, expected, force_scalar):
@pytest.mark.parametrize(
    "inputs, input_vals, output_fn, exc",
    [
        (
            [aet.vector()],
            [rng.randn(100).astype(config.floatX)],
            lambda x: aet.sigmoid(x),
            None,
        ),
        (
            [aet.vector() for i in range(4)],
            [rng.randn(100).astype(config.floatX) for i in range(4)],
            lambda x, y, x1, y1: (x + y) * (x1 + y1) * y,
            None,
        ),
        (
            # This also tests the use of repeated arguments
            [aet.matrix(), aet.scalar()],
            [rng.normal(size=(2, 2)).astype(config.floatX), 0.0],
            lambda a, b: aet.switch(a, b, a),
            None,
        ),
        (
            [aet.vector(), aet.vector()],
            [
                rng.randn(100).astype(config.floatX),
                rng.randn(100).astype(config.floatX),
            ],
            lambda x, y: ati.add_inplace(x, y),
            None,
        ),
        (
            [aet.vector(), aet.vector()],
            [
                rng.randn(100).astype(config.floatX),
                rng.randn(100).astype(config.floatX),
            ],
            lambda x, y: Elemwise(MyMultiOut())(x, y),
            NotImplementedError,
        ),
    ],
)
def test_Elemwise(inputs, input_vals, output_fn, exc):
    """Compare the Numba and Python implementations of `Elemwise` graphs.

    When ``exc`` is an exception type, the comparison is expected to raise it.
    """
    outputs = output_fn(*inputs)
    if not isinstance(outputs, list):
        outputs = [outputs]
    out_fg = FunctionGraph(outputs=outputs)

    # A no-op context when no exception is expected, otherwise assert it.
    if exc is None:
        check_ctx = contextlib.suppress()
    else:
        check_ctx = pytest.raises(exc)

    with check_ctx:
        compare_numba_and_py(out_fg, input_vals)
@pytest.mark.parametrize( @pytest.mark.parametrize(
......
Markdown format
0%
You are adding 0 people to this discussion. Please proceed with caution.
Please finish editing this comment first!
Register or sign in to comment