提交 41868191 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Numba OpFromGraph: Prevent alias of outputs

上级 57b344e7
...@@ -5,8 +5,8 @@ import numba ...@@ -5,8 +5,8 @@ import numba
import numpy as np import numpy as np
from pytensor.compile.builders import OpFromGraph from pytensor.compile.builders import OpFromGraph
from pytensor.compile.function.types import add_supervisor_to_fgraph from pytensor.compile.function.types import add_supervisor_to_fgraph, insert_deepcopy
from pytensor.compile.io import In from pytensor.compile.io import In, Out
from pytensor.compile.mode import NUMBA from pytensor.compile.mode import NUMBA
from pytensor.compile.ops import DeepCopyOp, TypeCastingOp from pytensor.compile.ops import DeepCopyOp, TypeCastingOp
from pytensor.ifelse import IfElse from pytensor.ifelse import IfElse
...@@ -56,14 +56,17 @@ def numba_funcify_OpFromGraph(op, node=None, **kwargs): ...@@ -56,14 +56,17 @@ def numba_funcify_OpFromGraph(op, node=None, **kwargs):
# explicitly triggers the optimization of the inner graphs of OpFromGraph? # explicitly triggers the optimization of the inner graphs of OpFromGraph?
# The C-code defers it to the make_thunk phase # The C-code defers it to the make_thunk phase
fgraph = op.fgraph fgraph = op.fgraph
input_specs = [In(x, borrow=True, mutable=False) for x in fgraph.inputs]
add_supervisor_to_fgraph( add_supervisor_to_fgraph(
fgraph=fgraph, fgraph=fgraph,
input_specs=[In(x, borrow=True, mutable=False) for x in fgraph.inputs], input_specs=input_specs,
accept_inplace=True, accept_inplace=True,
) )
NUMBA.optimizer(fgraph) NUMBA.optimizer(fgraph)
output_specs = [Out(o, borrow=False) for o in fgraph.outputs]
insert_deepcopy(fgraph, wrapped_inputs=input_specs, wrapped_outputs=output_specs)
fgraph_fn, fgraph_cache_key = numba_funcify_and_cache_key( fgraph_fn, fgraph_cache_key = numba_funcify_and_cache_key(
op.fgraph, squeeze_output=True, **kwargs fgraph, squeeze_output=True, **kwargs
) )
if fgraph_cache_key is None: if fgraph_cache_key is None:
......
...@@ -5,6 +5,9 @@ from pytensor import OpFromGraph, config, function, ifelse ...@@ -5,6 +5,9 @@ from pytensor import OpFromGraph, config, function, ifelse
from pytensor import tensor as pt from pytensor import tensor as pt
from pytensor.compile import ViewOp from pytensor.compile import ViewOp
from pytensor.raise_op import assert_op from pytensor.raise_op import assert_op
from pytensor.scalar import Add
from pytensor.tensor import matrix
from pytensor.tensor.elemwise import Elemwise
from tests.link.numba.test_basic import compare_numba_and_py from tests.link.numba.test_basic import compare_numba_and_py
...@@ -146,6 +149,28 @@ def test_ofg_inner_inplace(): ...@@ -146,6 +149,28 @@ def test_ofg_inner_inplace():
np.testing.assert_allclose(res1, [1, np.e, np.e]) np.testing.assert_allclose(res1, [1, np.e, np.e])
def test_ofg_aliased_outputs():
x = matrix("x")
# Create multiple views of x
outs = OpFromGraph([x], [x, x.T, x[::-1]])(x)
# Add one to each x, which when inplace shouldn't propagate across outputs
bumped_outs = [o + 1 for o in outs]
fn = function([x], bumped_outs, mode="NUMBA")
fn.dprint(print_destroy_map=True)
# Check our outputs are indeed inplace adds
assert all(
(
isinstance(o.owner.op, Elemwise)
and isinstance(o.owner.op.scalar_op, Add)
and o.owner.op.destroy_map
)
for o in fn.maker.fgraph.outputs
)
x_test = np.zeros((2, 2))
for res in fn(x_test):
np.testing.assert_allclose(res, np.ones((2, 2)))
def test_check_and_raise(): def test_check_and_raise():
x = pt.vector() x = pt.vector()
x_test_value = np.array([1.0, 2.0], dtype=config.floatX) x_test_value = np.array([1.0, 2.0], dtype=config.floatX)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论