提交 61b8d0b6 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Fix in-place updates performed on scalars in Numba

上级 00c9e1f7
......@@ -319,7 +319,7 @@ def numba_funcify(op, node=None, storage_map=None, **kwargs):
"""Create a Numba compatible function from an Aesara `Op`."""
warnings.warn(
(f"Numba will use object mode to run {op}'s perform method"),
f"Numba will use object mode to run {op}'s perform method",
UserWarning,
)
......@@ -474,20 +474,37 @@ def numba_funcify_Elemwise(op, node, **kwargs):
elemwise_fn_name = elemwise_fn.__name__
if op.inplace_pattern:
input_idx = op.inplace_pattern[0]
sign_obj = inspect.signature(elemwise_fn.py_scalar_func)
input_names = list(sign_obj.parameters.keys())
input_idx = op.inplace_pattern[0]
unique_names = unique_name_generator([elemwise_fn_name, "np"], suffix_sep="_")
input_names = [unique_names(i, force_unique=True) for i in input_names]
updated_input_name = input_names[input_idx]
inplace_global_env = {elemwise_fn_name: elemwise_fn}
inplace_global_env = {elemwise_fn_name: elemwise_fn, "np": np}
inplace_elemwise_fn_name = f"{elemwise_fn_name}_inplace"
input_signature_str = ", ".join(input_names)
inplace_elemwise_src = f"""
if node.inputs[input_idx].ndim > 0:
inplace_elemwise_src = f"""
def {inplace_elemwise_fn_name}({input_signature_str}):
return {elemwise_fn_name}({input_signature_str + ", " + updated_input_name})
"""
"""
else:
# We can't perform in-place updates on Numba scalars, so we need to
# convert them to NumPy scalars.
# TODO: We should really prevent the rewrites from creating
# in-place updates on scalars when the Numba mode is selected (or
# in general?).
inplace_elemwise_src = f"""
def {inplace_elemwise_fn_name}({input_signature_str}):
{updated_input_name}_scalar = np.asarray({updated_input_name})
return {elemwise_fn_name}({input_signature_str + ", " + updated_input_name}_scalar).item()
"""
inplace_elemwise_fn = compile_function_src(
inplace_elemwise_src, inplace_elemwise_fn_name, inplace_global_env
......
......@@ -18,7 +18,7 @@ import aesara.tensor.random.basic as aer
from aesara import config, shared
from aesara.compile.function import function
from aesara.compile.mode import Mode
from aesara.compile.ops import ViewOp
from aesara.compile.ops import ViewOp, deep_copy_op
from aesara.compile.sharedvalue import SharedVariable
from aesara.graph.basic import Apply, Constant
from aesara.graph.fg import FunctionGraph
......@@ -132,8 +132,11 @@ def eval_python_only(fn_inputs, fgraph, inputs):
def inner_vec(*args):
if len(args) > nparams:
# An `out` argument has been specified for an in-place
# operation
out = args[-1]
out[:] = np.vectorize(fn)(*args[:nparams])
out[...] = np.vectorize(fn)(*args[:nparams])
return out
else:
return np.vectorize(fn)(*args)
......@@ -312,13 +315,22 @@ def test_create_numba_signature(v, expected, force_scalar):
lambda a, b: aet.switch(a, b, a),
None,
),
(
[aet.scalar(), aet.scalar()],
[
np.array(1.0, dtype=config.floatX),
np.array(1.0, dtype=config.floatX),
],
lambda x, y: ati.add_inplace(deep_copy_op(x), deep_copy_op(y)),
None,
),
(
[aet.vector(), aet.vector()],
[
rng.standard_normal(100).astype(config.floatX),
rng.standard_normal(100).astype(config.floatX),
],
lambda x, y: ati.add_inplace(x, y),
lambda x, y: ati.add_inplace(deep_copy_op(x), deep_copy_op(y)),
None,
),
(
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论