提交 41868191 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Numba OpFromGraph: Prevent alias of outputs

上级 57b344e7
......@@ -5,8 +5,8 @@ import numba
import numpy as np
from pytensor.compile.builders import OpFromGraph
from pytensor.compile.function.types import add_supervisor_to_fgraph
from pytensor.compile.io import In
from pytensor.compile.function.types import add_supervisor_to_fgraph, insert_deepcopy
from pytensor.compile.io import In, Out
from pytensor.compile.mode import NUMBA
from pytensor.compile.ops import DeepCopyOp, TypeCastingOp
from pytensor.ifelse import IfElse
......@@ -56,14 +56,17 @@ def numba_funcify_OpFromGraph(op, node=None, **kwargs):
# explicitly triggers the optimization of the inner graphs of OpFromGraph?
# The C-code defers it to the make_thunk phase
fgraph = op.fgraph
input_specs = [In(x, borrow=True, mutable=False) for x in fgraph.inputs]
add_supervisor_to_fgraph(
fgraph=fgraph,
input_specs=[In(x, borrow=True, mutable=False) for x in fgraph.inputs],
input_specs=input_specs,
accept_inplace=True,
)
NUMBA.optimizer(fgraph)
output_specs = [Out(o, borrow=False) for o in fgraph.outputs]
insert_deepcopy(fgraph, wrapped_inputs=input_specs, wrapped_outputs=output_specs)
fgraph_fn, fgraph_cache_key = numba_funcify_and_cache_key(
op.fgraph, squeeze_output=True, **kwargs
fgraph, squeeze_output=True, **kwargs
)
if fgraph_cache_key is None:
......
......@@ -5,6 +5,9 @@ from pytensor import OpFromGraph, config, function, ifelse
from pytensor import tensor as pt
from pytensor.compile import ViewOp
from pytensor.raise_op import assert_op
from pytensor.scalar import Add
from pytensor.tensor import matrix
from pytensor.tensor.elemwise import Elemwise
from tests.link.numba.test_basic import compare_numba_and_py
......@@ -146,6 +149,28 @@ def test_ofg_inner_inplace():
np.testing.assert_allclose(res1, [1, np.e, np.e])
def test_ofg_aliased_outputs():
    """Outputs of an OpFromGraph that are views of the same input must not
    alias each other after compilation.

    The inner graph returns `x` and two views of it (`x.T`, `x[::-1]`).
    If the Numba backend returned aliased arrays, the inplace `+ 1` rewrites
    applied to each output would corrupt the sibling outputs. The fix inserts
    deep copies on the OpFromGraph outputs, so every output can be safely
    destroyed independently.
    """
    x = matrix("x")
    # Create multiple views of x
    outs = OpFromGraph([x], [x, x.T, x[::-1]])(x)
    # Add one to each output; when done inplace this must not propagate
    # across the (formerly aliased) outputs
    bumped_outs = [o + 1 for o in outs]
    fn = function([x], bumped_outs, mode="NUMBA")
    # Check our outputs are indeed inplace adds (i.e. the rewrite we are
    # guarding against actually fired); otherwise the test proves nothing
    assert all(
        (
            isinstance(o.owner.op, Elemwise)
            and isinstance(o.owner.op.scalar_op, Add)
            and o.owner.op.destroy_map
        )
        for o in fn.maker.fgraph.outputs
    )
    # With zeros input, every output must be exactly ones — any cross-output
    # aliasing would produce values > 1 in some entries
    x_test = np.zeros((2, 2))
    for res in fn(x_test):
        np.testing.assert_allclose(res, np.ones((2, 2)))
def test_check_and_raise():
x = pt.vector()
x_test_value = np.array([1.0, 2.0], dtype=config.floatX)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论