提交 a2eae29c authored 作者: emekaokoli19's avatar emekaokoli19 提交者: Ricardo Vieira

Numba IfElse: respect view flag

上级 79a4bc1e
from copy import deepcopy from copy import deepcopy
from hashlib import sha256 from hashlib import sha256
from textwrap import dedent
import numba import numba
import numpy as np import numpy as np
...@@ -10,6 +11,7 @@ from pytensor.compile.io import In, Out ...@@ -10,6 +11,7 @@ from pytensor.compile.io import In, Out
from pytensor.compile.mode import NUMBA from pytensor.compile.mode import NUMBA
from pytensor.compile.ops import DeepCopyOp, TypeCastingOp from pytensor.compile.ops import DeepCopyOp, TypeCastingOp
from pytensor.ifelse import IfElse from pytensor.ifelse import IfElse
from pytensor.link.numba.cache import compile_numba_function_src
from pytensor.link.numba.dispatch import basic as numba_basic from pytensor.link.numba.dispatch import basic as numba_basic
from pytensor.link.numba.dispatch.basic import ( from pytensor.link.numba.dispatch.basic import (
numba_funcify_and_cache_key, numba_funcify_and_cache_key,
...@@ -106,30 +108,35 @@ def numba_funcify_DeepCopyOp(op, node, **kwargs): ...@@ -106,30 +108,35 @@ def numba_funcify_DeepCopyOp(op, node, **kwargs):
@register_funcify_default_op_cache_key(IfElse) @register_funcify_default_op_cache_key(IfElse)
def numba_funcify_IfElse(op, **kwargs): def numba_funcify_IfElse(op, **kwargs):
n_outs = op.n_outs n_outs = op.n_outs
as_view = op.as_view
if n_outs > 1: true_names = [f"t{i}" for i in range(n_outs)]
false_names = [f"f{i}" for i in range(n_outs)]
@numba_basic.numba_njit arg_list = ", ".join((*true_names, *false_names))
def ifelse(cond, *args):
if cond:
res = args[:n_outs]
else:
res = args[n_outs:]
return res
if as_view:
true_returns = ", ".join(true_names)
else: else:
true_returns = ", ".join(f"{name}.copy()" for name in true_names)
# We only ever view (alias) variables from the true branch. False branch variables must always be copied.
false_returns = ", ".join(f"{name}.copy()" for name in false_names)
func_src = dedent(
f"""
def ifelse(cond, {arg_list}):
if cond:
return {true_returns}
else:
return {false_returns}
"""
)
@numba_basic.numba_njit ifelse_func = numba_basic.numba_njit(
def ifelse(cond, *args): compile_numba_function_src(func_src, "ifelse", globals())
if cond: )
res = args[:n_outs]
else:
res = args[n_outs:]
return res[0]
return ifelse cache_version = 1
return ifelse_func, cache_version
@register_funcify_and_cache_key(CheckAndRaise) @register_funcify_and_cache_key(CheckAndRaise)
......
import numpy as np import numpy as np
import pytest import pytest
from pytensor import Mode, OpFromGraph, config, function, ifelse, scan from pytensor import In, Mode, OpFromGraph, Out, config, function, ifelse, scan
from pytensor import tensor as pt from pytensor import tensor as pt
from pytensor.compile import ViewOp from pytensor.compile import ViewOp
from pytensor.graph import vectorize_graph from pytensor.graph import vectorize_graph
from pytensor.ifelse import IfElse
from pytensor.raise_op import assert_op from pytensor.raise_op import assert_op
from pytensor.scalar import Add from pytensor.scalar import Add
from pytensor.scan.op import Scan from pytensor.scan.op import Scan
...@@ -231,3 +232,43 @@ def test_ofg_with_inner_scan_rewrite(): ...@@ -231,3 +232,43 @@ def test_ofg_with_inner_scan_rewrite():
cholesky_op = scan_op.fgraph.outputs[0].owner.op cholesky_op = scan_op.fgraph.outputs[0].owner.op
assert isinstance(cholesky_op, Blockwise) assert isinstance(cholesky_op, Blockwise)
assert isinstance(cholesky_op.core_op, Cholesky) assert isinstance(cholesky_op.core_op, Cholesky)
@pytest.mark.parametrize("as_view", [True, False])
def test_ifelse_single_output(as_view, single_out=True):
x = pt.vector("x")
y = pt.vector("y")
if single_out:
outs = [x]
else:
outs = [x, y]
op = IfElse(as_view=as_view, n_outs=len(outs))
outs = op(x.sum() > 0, *outs, *outs, return_list=True)
fn = function(
[In(x, borrow=True), In(y, borrow=True)],
[Out(out, borrow=True) for out in outs],
mode=Mode("numba", optimizer=None),
accept_inplace=True,
on_unused_input="ignore",
)
# FALSE branch
test_x = np.zeros(3)
test_y = np.ones(5)
res_false = fn(test_x, test_y)
for test_inp, res_out in zip([test_x, test_y], res_false, strict=False):
np.testing.assert_array_equal(test_inp, res_out)
# IfElse only views on the true branch variates
assert res_out is not test_inp
# TRUE branch
test_x = np.ones(3)
res_true = fn(test_x, test_y)
for test_inp, res_out in zip([test_x, test_y], res_true, strict=False):
np.testing.assert_array_equal(test_inp, res_out)
if as_view:
assert res_out is test_inp
else:
assert res_out is not test_inp
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论