提交 61b8d0b6 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Fix in-place updates performed on scalars in Numba

上级 00c9e1f7
...@@ -319,7 +319,7 @@ def numba_funcify(op, node=None, storage_map=None, **kwargs): ...@@ -319,7 +319,7 @@ def numba_funcify(op, node=None, storage_map=None, **kwargs):
"""Create a Numba compatible function from an Aesara `Op`.""" """Create a Numba compatible function from an Aesara `Op`."""
warnings.warn( warnings.warn(
(f"Numba will use object mode to run {op}'s perform method"), f"Numba will use object mode to run {op}'s perform method",
UserWarning, UserWarning,
) )
...@@ -474,20 +474,37 @@ def numba_funcify_Elemwise(op, node, **kwargs): ...@@ -474,20 +474,37 @@ def numba_funcify_Elemwise(op, node, **kwargs):
elemwise_fn_name = elemwise_fn.__name__ elemwise_fn_name = elemwise_fn.__name__
if op.inplace_pattern: if op.inplace_pattern:
input_idx = op.inplace_pattern[0]
sign_obj = inspect.signature(elemwise_fn.py_scalar_func) sign_obj = inspect.signature(elemwise_fn.py_scalar_func)
input_names = list(sign_obj.parameters.keys()) input_names = list(sign_obj.parameters.keys())
input_idx = op.inplace_pattern[0] unique_names = unique_name_generator([elemwise_fn_name, "np"], suffix_sep="_")
input_names = [unique_names(i, force_unique=True) for i in input_names]
updated_input_name = input_names[input_idx] updated_input_name = input_names[input_idx]
inplace_global_env = {elemwise_fn_name: elemwise_fn} inplace_global_env = {elemwise_fn_name: elemwise_fn, "np": np}
inplace_elemwise_fn_name = f"{elemwise_fn_name}_inplace" inplace_elemwise_fn_name = f"{elemwise_fn_name}_inplace"
input_signature_str = ", ".join(input_names) input_signature_str = ", ".join(input_names)
if node.inputs[input_idx].ndim > 0:
inplace_elemwise_src = f""" inplace_elemwise_src = f"""
def {inplace_elemwise_fn_name}({input_signature_str}): def {inplace_elemwise_fn_name}({input_signature_str}):
return {elemwise_fn_name}({input_signature_str + ", " + updated_input_name}) return {elemwise_fn_name}({input_signature_str + ", " + updated_input_name})
""" """
else:
# We can't perform in-place updates on Numba scalars, so we need to
# convert them to NumPy scalars.
# TODO: We should really prevent the rewrites from creating
# in-place updates on scalars when the Numba mode is selected (or
# in general?).
inplace_elemwise_src = f"""
def {inplace_elemwise_fn_name}({input_signature_str}):
{updated_input_name}_scalar = np.asarray({updated_input_name})
return {elemwise_fn_name}({input_signature_str + ", " + updated_input_name}_scalar).item()
"""
inplace_elemwise_fn = compile_function_src( inplace_elemwise_fn = compile_function_src(
inplace_elemwise_src, inplace_elemwise_fn_name, inplace_global_env inplace_elemwise_src, inplace_elemwise_fn_name, inplace_global_env
......
...@@ -18,7 +18,7 @@ import aesara.tensor.random.basic as aer ...@@ -18,7 +18,7 @@ import aesara.tensor.random.basic as aer
from aesara import config, shared from aesara import config, shared
from aesara.compile.function import function from aesara.compile.function import function
from aesara.compile.mode import Mode from aesara.compile.mode import Mode
from aesara.compile.ops import ViewOp from aesara.compile.ops import ViewOp, deep_copy_op
from aesara.compile.sharedvalue import SharedVariable from aesara.compile.sharedvalue import SharedVariable
from aesara.graph.basic import Apply, Constant from aesara.graph.basic import Apply, Constant
from aesara.graph.fg import FunctionGraph from aesara.graph.fg import FunctionGraph
...@@ -132,8 +132,11 @@ def eval_python_only(fn_inputs, fgraph, inputs): ...@@ -132,8 +132,11 @@ def eval_python_only(fn_inputs, fgraph, inputs):
def inner_vec(*args): def inner_vec(*args):
if len(args) > nparams: if len(args) > nparams:
# An `out` argument has been specified for an in-place
# operation
out = args[-1] out = args[-1]
out[:] = np.vectorize(fn)(*args[:nparams]) out[...] = np.vectorize(fn)(*args[:nparams])
return out
else: else:
return np.vectorize(fn)(*args) return np.vectorize(fn)(*args)
...@@ -312,13 +315,22 @@ def test_create_numba_signature(v, expected, force_scalar): ...@@ -312,13 +315,22 @@ def test_create_numba_signature(v, expected, force_scalar):
lambda a, b: aet.switch(a, b, a), lambda a, b: aet.switch(a, b, a),
None, None,
), ),
(
[aet.scalar(), aet.scalar()],
[
np.array(1.0, dtype=config.floatX),
np.array(1.0, dtype=config.floatX),
],
lambda x, y: ati.add_inplace(deep_copy_op(x), deep_copy_op(y)),
None,
),
( (
[aet.vector(), aet.vector()], [aet.vector(), aet.vector()],
[ [
rng.standard_normal(100).astype(config.floatX), rng.standard_normal(100).astype(config.floatX),
rng.standard_normal(100).astype(config.floatX), rng.standard_normal(100).astype(config.floatX),
], ],
lambda x, y: ati.add_inplace(x, y), lambda x, y: ati.add_inplace(deep_copy_op(x), deep_copy_op(y)),
None, None,
), ),
( (
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论