提交 d39ad599 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Move compile ops to their own dispatch file

上级 0f2afc48
...@@ -3,6 +3,7 @@ from pytensor.link.numba.dispatch.basic import numba_funcify, numba_typify ...@@ -3,6 +3,7 @@ from pytensor.link.numba.dispatch.basic import numba_funcify, numba_typify
# Load dispatch specializations # Load dispatch specializations
import pytensor.link.numba.dispatch.blockwise import pytensor.link.numba.dispatch.blockwise
import pytensor.link.numba.dispatch.compile_ops
import pytensor.link.numba.dispatch.elemwise import pytensor.link.numba.dispatch.elemwise
import pytensor.link.numba.dispatch.extra_ops import pytensor.link.numba.dispatch.extra_ops
import pytensor.link.numba.dispatch.nlinalg import pytensor.link.numba.dispatch.nlinalg
......
...@@ -6,15 +6,10 @@ import numpy as np ...@@ -6,15 +6,10 @@ import numpy as np
from numba.core.errors import NumbaWarning from numba.core.errors import NumbaWarning
from numba.cpython.unsafe.tuple import tuple_setitem # noqa: F401 from numba.cpython.unsafe.tuple import tuple_setitem # noqa: F401
from pytensor import In, config from pytensor import config
from pytensor.compile import NUMBA
from pytensor.compile.builders import OpFromGraph
from pytensor.compile.function.types import add_supervisor_to_fgraph
from pytensor.compile.ops import DeepCopyOp, TypeCastingOp
from pytensor.graph.basic import Apply from pytensor.graph.basic import Apply
from pytensor.graph.fg import FunctionGraph from pytensor.graph.fg import FunctionGraph
from pytensor.graph.type import Type from pytensor.graph.type import Type
from pytensor.ifelse import IfElse
from pytensor.link.numba.dispatch.sparse import CSCMatrixType, CSRMatrixType from pytensor.link.numba.dispatch.sparse import CSCMatrixType, CSRMatrixType
from pytensor.link.utils import ( from pytensor.link.utils import (
fgraph_to_python, fgraph_to_python,
...@@ -280,90 +275,3 @@ def numba_funcify_FunctionGraph( ...@@ -280,90 +275,3 @@ def numba_funcify_FunctionGraph(
fgraph_name=fgraph_name, fgraph_name=fgraph_name,
**kwargs, **kwargs,
) )
@numba_funcify.register(OpFromGraph)
def numba_funcify_OpFromGraph(op, node=None, **kwargs):
    """Compile an `OpFromGraph` by funcifying its (rewritten) inner graph."""
    # The outer graph's storage map does not apply to the inner graph.
    _ = kwargs.pop("storage_map", None)
    # Apply inner rewrites
    # TODO: Not sure this is the right place to do this, should we have a rewrite that
    # explicitly triggers the optimization of the inner graphs of OpFromGraph?
    # The C-code defers it to the make_thunk phase
    fgraph = op.fgraph
    # Protect the inner inputs from being destroyed by inplace rewrites.
    add_supervisor_to_fgraph(
        fgraph=fgraph,
        input_specs=[In(x, borrow=True, mutable=False) for x in fgraph.inputs],
        accept_inplace=True,
    )
    NUMBA.optimizer(fgraph)

    fgraph_fn = numba_njit(numba_funcify(op.fgraph, **kwargs))

    if len(op.fgraph.outputs) == 1:

        @numba_njit
        def opfromgraph(*inputs):
            # Unwrap the singleton tuple returned by the inner function.
            return fgraph_fn(*inputs)[0]

    else:

        @numba_njit
        def opfromgraph(*inputs):
            return fgraph_fn(*inputs)

    return opfromgraph
@numba_funcify.register(TypeCastingOp)
def numba_funcify_type_casting(op, **kwargs):
    # Type-casting ops have no runtime effect here; compile to identity.
    @numba_njit
    def identity(x):
        return x

    return identity
@numba_funcify.register(DeepCopyOp)
def numba_funcify_DeepCopyOp(op, node, **kwargs):
    # Tensors get a real copy; other input types are returned as-is.
    if isinstance(node.inputs[0].type, TensorType):

        @numba_njit
        def deepcopy(x):
            return np.copy(x)

    else:
        # NOTE(review): assumes non-tensor inputs need no actual copy — confirm.
        @numba_njit
        def deepcopy(x):
            return x

    return deepcopy
@numba_funcify.register(IfElse)
def numba_funcify_IfElse(op, **kwargs):
    # The jitted signature is (cond, *args) where the first `n_outs` entries of
    # `args` are the "then" values and the remaining ones the "else" values.
    n_outs = op.n_outs

    if n_outs > 1:

        @numba_njit
        def ifelse(cond, *args):
            if cond:
                res = args[:n_outs]
            else:
                res = args[n_outs:]
            return res

    else:

        @numba_njit
        def ifelse(cond, *args):
            if cond:
                res = args[:n_outs]
            else:
                res = args[n_outs:]
            # Single output: return the value itself, not a 1-tuple.
            return res[0]

    return ifelse
import numpy as np
from pytensor.compile.builders import OpFromGraph
from pytensor.compile.function.types import add_supervisor_to_fgraph
from pytensor.compile.io import In
from pytensor.compile.mode import NUMBA
from pytensor.compile.ops import DeepCopyOp, TypeCastingOp
from pytensor.ifelse import IfElse
from pytensor.link.numba.dispatch import basic as numba_basic
from pytensor.link.numba.dispatch.basic import (
numba_funcify,
numba_njit,
)
from pytensor.raise_op import CheckAndRaise
from pytensor.tensor.type import TensorType
@numba_funcify.register(OpFromGraph)
def numba_funcify_OpFromGraph(op, node=None, **kwargs):
    """Convert an `OpFromGraph` into a numba-jitted callable.

    The inner graph is rewritten and optimized here before being funcified.
    """
    # The outer graph's storage map does not apply to the inner graph.
    _ = kwargs.pop("storage_map", None)

    # Apply inner rewrites
    # TODO: Not sure this is the right place to do this, should we have a rewrite that
    # explicitly triggers the optimization of the inner graphs of OpFromGraph?
    # The C-code defers it to the make_thunk phase
    inner_fgraph = op.fgraph
    # Guard the inner inputs against destructive (inplace) rewrites.
    add_supervisor_to_fgraph(
        fgraph=inner_fgraph,
        input_specs=[
            In(inp, borrow=True, mutable=False) for inp in inner_fgraph.inputs
        ],
        accept_inplace=True,
    )
    NUMBA.optimizer(inner_fgraph)

    inner_fn = numba_njit(numba_funcify(op.fgraph, **kwargs))
    single_output = len(op.fgraph.outputs) == 1

    if single_output:

        @numba_basic.numba_njit
        def opfromgraph(*inputs):
            # Unwrap the 1-tuple the inner function returns.
            return inner_fn(*inputs)[0]

    else:

        @numba_basic.numba_njit
        def opfromgraph(*inputs):
            return inner_fn(*inputs)

    return opfromgraph
@numba_funcify.register(TypeCastingOp)
def numba_funcify_type_casting(op, **kwargs):
    """Type-casting ops carry no runtime work; lower them to an identity."""

    @numba_basic.numba_njit
    def identity(x):
        return x

    return identity
@numba_funcify.register(DeepCopyOp)
def numba_funcify_DeepCopyOp(op, node, **kwargs):
    """Deep copy: `np.copy` for tensor inputs, pass-through otherwise."""
    input_is_tensor = isinstance(node.inputs[0].type, TensorType)

    if input_is_tensor:

        @numba_basic.numba_njit
        def deepcopy(x):
            return np.copy(x)

    else:
        # NOTE(review): assumes non-tensor inputs need no actual copy — confirm.
        @numba_basic.numba_njit
        def deepcopy(x):
            return x

    return deepcopy
@numba_funcify.register(IfElse)
def numba_funcify_IfElse(op, **kwargs):
    """Lazy if-else: select the first or second half of ``args`` by ``cond``.

    The jitted signature is ``(cond, *args)``, where the first ``n_outs``
    entries of ``args`` are the "then" values and the rest the "else" values.
    """
    n_outs = op.n_outs

    if n_outs > 1:

        @numba_basic.numba_njit
        def ifelse(cond, *args):
            chosen = args[:n_outs] if cond else args[n_outs:]
            return chosen

    else:

        @numba_basic.numba_njit
        def ifelse(cond, *args):
            chosen = args[:n_outs] if cond else args[n_outs:]
            # Single output: return the value itself, not a 1-tuple.
            return chosen[0]

    return ifelse
@numba_funcify.register(CheckAndRaise)
def numba_funcify_CheckAndRaise(op, node, **kwargs):
    """Raise ``op.exc_type(op.msg)`` unless every condition holds, else pass ``x`` through."""
    # Bind exception type and message as closure constants for the jitted fn.
    exc_type = op.exc_type
    exc_msg = op.msg

    @numba_basic.numba_njit
    def check_and_raise(x, *conditions):
        for condition in conditions:
            if not condition:
                raise exc_type(exc_msg)
        return x

    return check_and_raise
...@@ -11,7 +11,6 @@ from pytensor.link.numba.dispatch.basic import ( ...@@ -11,7 +11,6 @@ from pytensor.link.numba.dispatch.basic import (
get_numba_type, get_numba_type,
numba_funcify, numba_funcify,
) )
from pytensor.raise_op import CheckAndRaise
from pytensor.tensor import TensorVariable from pytensor.tensor import TensorVariable
from pytensor.tensor.extra_ops import ( from pytensor.tensor.extra_ops import (
Bartlett, Bartlett,
...@@ -325,18 +324,3 @@ def numba_funcify_Searchsorted(op, node, **kwargs): ...@@ -325,18 +324,3 @@ def numba_funcify_Searchsorted(op, node, **kwargs):
return np.searchsorted(a, v, side) return np.searchsorted(a, v, side)
return searchsorted return searchsorted
@numba_funcify.register(CheckAndRaise)
def numba_funcify_CheckAndRaise(op, node, **kwargs):
    # Capture exception type/message as closure constants for the jitted fn.
    error = op.exc_type
    msg = op.msg

    @numba_basic.numba_njit
    def check_and_raise(x, *conditions):
        for cond in conditions:
            if not cond:
                raise error(msg)
        return x

    return check_and_raise
...@@ -15,15 +15,12 @@ numba = pytest.importorskip("numba") ...@@ -15,15 +15,12 @@ numba = pytest.importorskip("numba")
import pytensor.scalar as ps import pytensor.scalar as ps
import pytensor.tensor as pt import pytensor.tensor as pt
from pytensor import config, shared from pytensor import config, shared
from pytensor.compile.builders import OpFromGraph
from pytensor.compile.function import function from pytensor.compile.function import function
from pytensor.compile.mode import Mode from pytensor.compile.mode import Mode
from pytensor.compile.ops import ViewOp
from pytensor.graph.basic import Apply, Variable from pytensor.graph.basic import Apply, Variable
from pytensor.graph.op import Op from pytensor.graph.op import Op
from pytensor.graph.rewriting.db import RewriteDatabaseQuery from pytensor.graph.rewriting.db import RewriteDatabaseQuery
from pytensor.graph.type import Type from pytensor.graph.type import Type
from pytensor.ifelse import ifelse
from pytensor.link.numba.dispatch import basic as numba_basic from pytensor.link.numba.dispatch import basic as numba_basic
from pytensor.link.numba.linker import NumbaLinker from pytensor.link.numba.linker import NumbaLinker
from pytensor.scalar.basic import ScalarOp, as_scalar from pytensor.scalar.basic import ScalarOp, as_scalar
...@@ -320,18 +317,6 @@ def test_create_numba_signature(v, expected, force_scalar): ...@@ -320,18 +317,6 @@ def test_create_numba_signature(v, expected, force_scalar):
assert res == expected assert res == expected
def test_ViewOp():
    # ViewOp should lower to a no-op whose result matches the Python backend.
    v = pt.vector()
    v_test_value = np.arange(4, dtype=config.floatX)
    g = ViewOp()(v)
    compare_numba_and_py(
        [v],
        [g],
        [v_test_value],
    )
@pytest.mark.parametrize( @pytest.mark.parametrize(
"inputs, op", "inputs, op",
[ [
...@@ -406,84 +391,6 @@ def test_shared_updates(): ...@@ -406,84 +391,6 @@ def test_shared_updates():
assert a.get_value() == 7 assert a.get_value() == 7
# Shared fixtures for the IfElse tests below.
# We were seeing some weird results in CI where the following two almost
# sign-swapped results were being returned from Numba and Python, respectively.
# The issue might be related to https://github.com/numba/numba/issues/4519.
# Regardless, I was not able to reproduce anything like it locally after
# extensive testing.
x = np.array(
    [
        [-0.60407637, -0.71177603, -0.35842241],
        [-0.07735968, 0.50000561, -0.86256007],
        [-0.7931628, 0.49332471, 0.35710434],
    ],
    dtype=np.float64,
)
y = np.array(
    [
        [0.60407637, 0.71177603, -0.35842241],
        [0.07735968, -0.50000561, -0.86256007],
        [0.7931628, -0.49332471, 0.35710434],
    ],
    dtype=np.float64,
)
@pytest.mark.parametrize(
    "inputs, cond_fn, true_vals, false_vals",
    [
        ([], lambda: np.array(True), np.r_[1, 2, 3], np.r_[-1, -2, -3]),
        (
            [(pt.dscalar(), np.array(0.2, dtype=np.float64))],
            lambda x: x < 0.5,
            np.r_[1, 2, 3],
            np.r_[-1, -2, -3],
        ),
        (
            [
                (pt.dscalar(), np.array(0.3, dtype=np.float64)),
                (pt.dscalar(), np.array(0.5, dtype=np.float64)),
            ],
            lambda x, y: x > y,
            x,
            y,
        ),
        (
            [
                (pt.dvector(), np.array([0.3, 0.1], dtype=np.float64)),
                (pt.dvector(), np.array([0.5, 0.9], dtype=np.float64)),
            ],
            lambda x, y: pt.all(x > y),
            x,
            y,
        ),
        (
            [
                (pt.dvector(), np.array([0.3, 0.1], dtype=np.float64)),
                (pt.dvector(), np.array([0.5, 0.9], dtype=np.float64)),
            ],
            lambda x, y: pt.all(x > y),
            [x, 2 * x],
            [y, 3 * y],
        ),
        (
            [
                (pt.dvector(), np.array([0.5, 0.9], dtype=np.float64)),
                (pt.dvector(), np.array([0.3, 0.1], dtype=np.float64)),
            ],
            lambda x, y: pt.all(x > y),
            [x, 2 * x],
            [y, 3 * y],
        ),
    ],
)
def test_IfElse(inputs, cond_fn, true_vals, false_vals):
    # Unzip (variable, test value) pairs; the condition-only case has no inputs.
    inputs, test_values = zip(*inputs, strict=True) if inputs else ([], [])
    out = ifelse(cond_fn(*inputs), true_vals, false_vals)
    compare_numba_and_py(inputs, out, test_values)
def test_config_options_fastmath(): def test_config_options_fastmath():
x = pt.dvector() x = pt.dvector()
...@@ -524,54 +431,6 @@ def test_scalar_return_value_conversion(): ...@@ -524,54 +431,6 @@ def test_scalar_return_value_conversion():
assert isinstance(x_fn(1.0), np.ndarray) assert isinstance(x_fn(1.0), np.ndarray)
def test_OpFromGraph():
    # Two nested, non-inlined OpFromGraphs; compare numba vs Python backends.
    x, y, z = pt.matrices("xyz")
    ofg_1 = OpFromGraph([x, y], [x + y], inline=False)
    ofg_2 = OpFromGraph([x, y], [x * y, x - y], inline=False)
    o1, o2 = ofg_2(y, z)
    out = ofg_1(x, o1) + o2
    xv = np.ones((2, 2), dtype=config.floatX)
    yv = np.ones((2, 2), dtype=config.floatX) * 3
    zv = np.ones((2, 2), dtype=config.floatX) * 5
    compare_numba_and_py([x, y, z], [out], [xv, yv, zv])
@pytest.mark.filterwarnings("error")
def test_ofg_inner_inplace():
x = pt.vector("x")
set0 = x[0].set(1) # SetSubtensor should not inplace on x
exp_x = pt.exp(x)
set1 = exp_x[0].set(1) # SetSubtensor should inplace on exp_x
ofg0 = OpFromGraph([x], [set0])
ofg1 = OpFromGraph([x], [set1])
y, z = pt.vectors("y", "z")
fn = function([y, z], [ofg0(y), ofg1(z)], mode="NUMBA")
fn_ofg0 = fn.maker.fgraph.outputs[0].owner.op
assert isinstance(fn_ofg0, OpFromGraph)
fn_set0 = fn_ofg0.fgraph.outputs[0]
assert fn_set0.owner.op.destroy_map == {}
fn_ofg1 = fn.maker.fgraph.outputs[1].owner.op
assert isinstance(fn_ofg1, OpFromGraph)
fn_set1 = fn_ofg1.fgraph.outputs[0]
assert fn_set1.owner.op.destroy_map == {0: [0]}
x_test = np.array([0, 1, 1], dtype=config.floatX)
y_test = np.array([0, 1, 1], dtype=config.floatX)
res0, res1 = fn(x_test, y_test)
# Check inputs were not mutated
np.testing.assert_allclose(x_test, [0, 1, 1])
np.testing.assert_allclose(y_test, [0, 1, 1])
# Check outputs are correct
np.testing.assert_allclose(res0, [1, 1, 1])
np.testing.assert_allclose(res1, [1, np.e, np.e])
@pytest.mark.filterwarnings("error") @pytest.mark.filterwarnings("error")
def test_cache_warning_suppressed(): def test_cache_warning_suppressed():
x = pt.vector("x", shape=(5,), dtype="float64") x = pt.vector("x", shape=(5,), dtype="float64")
......
import numpy as np
import pytest
from pytensor import OpFromGraph, config, function, ifelse
from pytensor import tensor as pt
from pytensor.compile import ViewOp
from pytensor.raise_op import assert_op
from tests.link.numba.test_basic import compare_numba_and_py
def test_ViewOp():
    """`ViewOp` should compile to an identity and agree with the Python backend."""
    vec = pt.vector()
    test_val = np.arange(4, dtype=config.floatX)
    out = ViewOp()(vec)
    compare_numba_and_py([vec], [out], [test_val])
# Shared fixtures for the IfElse tests below.
# We were seeing some weird results in CI where the following two almost
# sign-swapped results were being returned from Numba and Python, respectively.
# The issue might be related to https://github.com/numba/numba/issues/4519.
# Regardless, I was not able to reproduce anything like it locally after
# extensive testing.
x = np.array(
    [
        [-0.60407637, -0.71177603, -0.35842241],
        [-0.07735968, 0.50000561, -0.86256007],
        [-0.7931628, 0.49332471, 0.35710434],
    ],
    dtype=np.float64,
)
y = np.array(
    [
        [0.60407637, 0.71177603, -0.35842241],
        [0.07735968, -0.50000561, -0.86256007],
        [0.7931628, -0.49332471, 0.35710434],
    ],
    dtype=np.float64,
)
@pytest.mark.parametrize(
    "inputs, cond_fn, true_vals, false_vals",
    [
        ([], lambda: np.array(True), np.r_[1, 2, 3], np.r_[-1, -2, -3]),
        (
            [(pt.dscalar(), np.array(0.2, dtype=np.float64))],
            lambda x: x < 0.5,
            np.r_[1, 2, 3],
            np.r_[-1, -2, -3],
        ),
        (
            [
                (pt.dscalar(), np.array(0.3, dtype=np.float64)),
                (pt.dscalar(), np.array(0.5, dtype=np.float64)),
            ],
            lambda x, y: x > y,
            x,
            y,
        ),
        (
            [
                (pt.dvector(), np.array([0.3, 0.1], dtype=np.float64)),
                (pt.dvector(), np.array([0.5, 0.9], dtype=np.float64)),
            ],
            lambda x, y: pt.all(x > y),
            x,
            y,
        ),
        (
            [
                (pt.dvector(), np.array([0.3, 0.1], dtype=np.float64)),
                (pt.dvector(), np.array([0.5, 0.9], dtype=np.float64)),
            ],
            lambda x, y: pt.all(x > y),
            [x, 2 * x],
            [y, 3 * y],
        ),
        (
            [
                (pt.dvector(), np.array([0.5, 0.9], dtype=np.float64)),
                (pt.dvector(), np.array([0.3, 0.1], dtype=np.float64)),
            ],
            lambda x, y: pt.all(x > y),
            [x, 2 * x],
            [y, 3 * y],
        ),
    ],
)
def test_IfElse(inputs, cond_fn, true_vals, false_vals):
    # Unzip (variable, test value) pairs; the condition-only case has no inputs.
    inputs, test_values = zip(*inputs, strict=True) if inputs else ([], [])
    out = ifelse(cond_fn(*inputs), true_vals, false_vals)
    compare_numba_and_py(inputs, out, test_values)
def test_OpFromGraph():
    """Nested, non-inlined `OpFromGraph`s compile and match the Python backend."""
    x, y, z = pt.matrices("xyz")
    add_ofg = OpFromGraph([x, y], [x + y], inline=False)
    mul_sub_ofg = OpFromGraph([x, y], [x * y, x - y], inline=False)

    o1, o2 = mul_sub_ofg(y, z)
    out = add_ofg(x, o1) + o2

    xv = np.ones((2, 2), dtype=config.floatX)
    yv = 3 * np.ones((2, 2), dtype=config.floatX)
    zv = 5 * np.ones((2, 2), dtype=config.floatX)
    compare_numba_and_py([x, y, z], [out], [xv, yv, zv])
@pytest.mark.filterwarnings("error")
def test_ofg_inner_inplace():
x = pt.vector("x")
set0 = x[0].set(1) # SetSubtensor should not inplace on x
exp_x = pt.exp(x)
set1 = exp_x[0].set(1) # SetSubtensor should inplace on exp_x
ofg0 = OpFromGraph([x], [set0])
ofg1 = OpFromGraph([x], [set1])
y, z = pt.vectors("y", "z")
fn = function([y, z], [ofg0(y), ofg1(z)], mode="NUMBA")
fn_ofg0 = fn.maker.fgraph.outputs[0].owner.op
assert isinstance(fn_ofg0, OpFromGraph)
fn_set0 = fn_ofg0.fgraph.outputs[0]
assert fn_set0.owner.op.destroy_map == {}
fn_ofg1 = fn.maker.fgraph.outputs[1].owner.op
assert isinstance(fn_ofg1, OpFromGraph)
fn_set1 = fn_ofg1.fgraph.outputs[0]
assert fn_set1.owner.op.destroy_map == {0: [0]}
x_test = np.array([0, 1, 1], dtype=config.floatX)
y_test = np.array([0, 1, 1], dtype=config.floatX)
res0, res1 = fn(x_test, y_test)
# Check inputs were not mutated
np.testing.assert_allclose(x_test, [0, 1, 1])
np.testing.assert_allclose(y_test, [0, 1, 1])
# Check outputs are correct
np.testing.assert_allclose(res0, [1, 1, 1])
np.testing.assert_allclose(res1, [1, np.e, np.e])
def test_check_and_raise():
    """A passing `assert_op` condition leaves the wrapped value untouched."""
    vec = pt.vector()
    test_val = np.array([1.0, 2.0], dtype=config.floatX)
    checked = assert_op(vec.sum(), np.array(True))
    compare_numba_and_py([vec], checked, [test_val])
...@@ -5,7 +5,6 @@ import pytest ...@@ -5,7 +5,6 @@ import pytest
import pytensor.tensor as pt import pytensor.tensor as pt
from pytensor import config from pytensor import config
from pytensor.raise_op import assert_op
from pytensor.tensor import extra_ops from pytensor.tensor import extra_ops
from tests.link.numba.test_basic import compare_numba_and_py from tests.link.numba.test_basic import compare_numba_and_py
...@@ -384,12 +383,3 @@ def test_Searchsorted(a, v, side, sorter, exc): ...@@ -384,12 +383,3 @@ def test_Searchsorted(a, v, side, sorter, exc):
g, g,
[test_a, test_v] if sorter is None else [test_a, test_v, test_sorter], [test_a, test_v] if sorter is None else [test_a, test_v, test_sorter],
) )
def test_check_and_raise():
    # A passing assertion should leave the value untouched in both backends.
    x = pt.vector()
    x_test_value = np.array([1.0, 2.0], dtype=config.floatX)
    out = assert_op(x.sum(), np.array(True))
    compare_numba_and_py([x], out, [x_test_value])
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论