提交 c90ee4ba authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Move shape Ops dispatchers to their own file

上级 9da643e8
......@@ -9,6 +9,7 @@ import pytensor.link.numba.dispatch.nlinalg
import pytensor.link.numba.dispatch.random
import pytensor.link.numba.dispatch.scan
import pytensor.link.numba.dispatch.scalar
import pytensor.link.numba.dispatch.shape
import pytensor.link.numba.dispatch.signal
import pytensor.link.numba.dispatch.slinalg
import pytensor.link.numba.dispatch.sparse
......
import warnings
from copy import copy
from functools import singledispatch
from textwrap import dedent
import numba
import numba.np.unsafe.ndarray as numba_ndarray
import numpy as np
from numba import types
from numba.core.errors import NumbaWarning, TypingError
......@@ -22,7 +20,6 @@ from pytensor.graph.type import Type
from pytensor.ifelse import IfElse
from pytensor.link.numba.dispatch.sparse import CSCMatrixType, CSRMatrixType
from pytensor.link.utils import (
compile_function_src,
fgraph_to_python,
)
from pytensor.scalar.basic import ScalarType
......@@ -30,10 +27,8 @@ from pytensor.sparse import SparseTensorType
from pytensor.tensor.basic import Nonzero
from pytensor.tensor.blas import BatchedDot
from pytensor.tensor.math import Dot
from pytensor.tensor.shape import Reshape, Shape, Shape_i, SpecifyShape
from pytensor.tensor.sort import ArgSortOp, SortOp
from pytensor.tensor.type import TensorType
from pytensor.tensor.type_other import NoneConst
def numba_njit(*args, fastmath=None, **kwargs):
......@@ -322,26 +317,6 @@ def numba_funcify_DeepCopyOp(op, node, **kwargs):
return deepcopyop
@numba_funcify.register(Shape)
def numba_funcify_Shape(op, **kwargs):
    """Generate a Numba implementation of the ``Shape`` Op."""

    @numba_njit
    def shape_fn(x):
        # np.shape returns a Python tuple; wrap it so the output is an ndarray.
        return np.asarray(np.shape(x))

    return shape_fn
@numba_funcify.register(Shape_i)
def numba_funcify_Shape_i(op, **kwargs):
    """Generate a Numba implementation of ``Shape_i`` (size of one axis)."""
    axis = op.i

    @numba_njit
    def shape_i_fn(x):
        # Pick out a single dimension and return it as a 0-d ndarray.
        return np.asarray(np.shape(x)[axis])

    return shape_i_fn
@numba_funcify.register(SortOp)
def numba_funcify_SortOp(op, node, **kwargs):
@numba_njit
......@@ -423,54 +398,6 @@ def direct_cast(typingctx, val, typ):
return sig, codegen
@numba_funcify.register(Reshape)
def numba_funcify_Reshape(op, **kwargs):
    """Generate a Numba implementation of the ``Reshape`` Op."""
    ndim = op.ndim

    if ndim != 0:

        @numba_njit
        def reshape(x, shape):
            # TODO: Use this until https://github.com/numba/numba/issues/7353 is closed.
            return np.reshape(
                np.ascontiguousarray(np.asarray(x)),
                numba_ndarray.to_fixed_tuple(shape, ndim),
            )

    else:

        @numba_njit
        def reshape(x, shape):
            # Reshaping to 0 dimensions is just extracting the scalar value.
            return np.asarray(x.item())

    return reshape
@numba_funcify.register(SpecifyShape)
def numba_funcify_SpecifyShape(op, node, **kwargs):
    """Generate a Numba implementation of the ``SpecifyShape`` Op.

    Builds (via source generation) a function that asserts, for every
    non-``None`` shape input, that the corresponding dimension of ``x``
    matches it, and then returns ``x`` unchanged.
    """
    shape_inputs = node.inputs[1:]
    shape_input_names = ["shape_" + str(i) for i in range(len(shape_inputs))]

    # Only check dimensions whose expected size is actually specified;
    # ``NoneConst`` marks an unconstrained dimension.
    # NOTE: the loop variable is singular to avoid shadowing the
    # ``shape_input_names`` list used in ``zip`` above.
    func_conditions = [
        f"assert x.shape[{i}] == {shape_input_name}"
        for i, (shape_input, shape_input_name) in enumerate(
            zip(shape_inputs, shape_input_names, strict=True)
        )
        if shape_input is not NoneConst
    ]

    func = dedent(
        f"""
        def specify_shape(x, {create_arg_string(shape_input_names)}):
            {"; ".join(func_conditions)}
            return x
        """
    )

    specify_shape = compile_function_src(func, "specify_shape", globals())
    return numba_njit(specify_shape)
def int_to_float_fn(inputs, out_dtype):
"""Create a Numba function that converts integer and boolean ``ndarray``s to floats."""
......
from textwrap import dedent
import numpy as np
from numba.np.unsafe import ndarray as numba_ndarray
from pytensor.link.numba.dispatch import numba_funcify
from pytensor.link.numba.dispatch.basic import create_arg_string, numba_njit
from pytensor.link.utils import compile_function_src
from pytensor.tensor import NoneConst
from pytensor.tensor.shape import Reshape, Shape, Shape_i, SpecifyShape
@numba_funcify.register(Shape)
def numba_funcify_Shape(op, **kwargs):
    """Generate a Numba implementation of the ``Shape`` Op."""

    @numba_njit
    def shape_fn(x):
        # np.shape returns a Python tuple; wrap it so the output is an ndarray.
        return np.asarray(np.shape(x))

    return shape_fn
@numba_funcify.register(Shape_i)
def numba_funcify_Shape_i(op, **kwargs):
    """Generate a Numba implementation of ``Shape_i`` (size of one axis)."""
    axis = op.i

    @numba_njit
    def shape_i_fn(x):
        # Pick out a single dimension and return it as a 0-d ndarray.
        return np.asarray(np.shape(x)[axis])

    return shape_i_fn
@numba_funcify.register(SpecifyShape)
def numba_funcify_SpecifyShape(op, node, **kwargs):
    """Generate a Numba implementation of the ``SpecifyShape`` Op.

    Builds (via source generation) a function that asserts, for every
    non-``None`` shape input, that the corresponding dimension of ``x``
    matches it, and then returns ``x`` unchanged.
    """
    shape_inputs = node.inputs[1:]
    shape_input_names = ["shape_" + str(i) for i in range(len(shape_inputs))]

    # Only check dimensions whose expected size is actually specified;
    # ``NoneConst`` marks an unconstrained dimension.
    # NOTE: the loop variable is singular to avoid shadowing the
    # ``shape_input_names`` list used in ``zip`` above.
    func_conditions = [
        f"assert x.shape[{i}] == {shape_input_name}"
        for i, (shape_input, shape_input_name) in enumerate(
            zip(shape_inputs, shape_input_names, strict=True)
        )
        if shape_input is not NoneConst
    ]

    func = dedent(
        f"""
        def specify_shape(x, {create_arg_string(shape_input_names)}):
            {"; ".join(func_conditions)}
            return x
        """
    )

    specify_shape = compile_function_src(func, "specify_shape", globals())
    return numba_njit(specify_shape)
@numba_funcify.register(Reshape)
def numba_funcify_Reshape(op, **kwargs):
    """Generate a Numba implementation of the ``Reshape`` Op."""
    ndim = op.ndim

    if ndim != 0:

        @numba_njit
        def reshape(x, shape):
            # TODO: Use this until https://github.com/numba/numba/issues/7353 is closed.
            return np.reshape(
                np.ascontiguousarray(np.asarray(x)),
                numba_ndarray.to_fixed_tuple(shape, ndim),
            )

    else:

        @numba_njit
        def reshape(x, shape):
            # Reshaping to 0 dimensions is just extracting the scalar value.
            return np.asarray(x.item())

    return reshape
......@@ -31,7 +31,6 @@ from pytensor.raise_op import assert_op
from pytensor.scalar.basic import ScalarOp, as_scalar
from pytensor.tensor import blas, tensor
from pytensor.tensor.elemwise import Elemwise
from pytensor.tensor.shape import Reshape, Shape, Shape_i, SpecifyShape
from pytensor.tensor.sort import ArgSortOp, SortOp
......@@ -332,22 +331,6 @@ def test_create_numba_signature(v, expected, force_scalar):
assert res == expected
@pytest.mark.parametrize(
    "x, i",
    [
        (np.zeros((20, 3)), 1),
    ],
)
def test_Shape(x, i):
    # Full shape of a constant input.
    full_shape_graph = Shape()(pt.as_tensor_variable(x))
    compare_numba_and_py([], [full_shape_graph], [])

    # Size of a single dimension of the same input.
    single_dim_graph = Shape_i(i)(pt.as_tensor_variable(x))
    compare_numba_and_py([], [single_dim_graph], [])
@pytest.mark.parametrize(
"x",
[
......@@ -412,81 +395,6 @@ def test_ArgSort(x, axis, kind, exc):
compare_numba_and_py([], [g], [])
@pytest.mark.parametrize(
    "v, shape, ndim",
    [
        ((pt.vector(), np.array([4], dtype=config.floatX)), ((), None), 0),
        ((pt.vector(), np.arange(4, dtype=config.floatX)), ((2, 2), None), 2),
        (
            (pt.vector(), np.arange(4, dtype=config.floatX)),
            (pt.lvector(), np.array([2, 2], dtype="int64")),
            2,
        ),
    ],
)
def test_Reshape(v, shape, ndim):
    v, v_test_value = v
    shape, shape_test_value = shape

    g = Reshape(ndim)(v, shape)

    # A symbolic shape is an extra graph input; a literal shape is baked in.
    if isinstance(shape, Variable):
        inputs = [v, shape]
        test_values = [v_test_value, shape_test_value]
    else:
        inputs = [v]
        test_values = [v_test_value]

    compare_numba_and_py(inputs, [g], test_values)
def test_Reshape_scalar():
    """Reshape a 0-d (scalar) input into a length-1 vector."""
    v = pt.vector()
    v_test_value = np.array([1.0], dtype=config.floatX)
    g = Reshape(1)(v[0], (1,))
    # Pass the output as a list for consistency with every other test's
    # use of ``compare_numba_and_py``.
    compare_numba_and_py(
        [v],
        [g],
        [v_test_value],
    )
@pytest.mark.parametrize(
    "v, shape, fails",
    [
        (
            (pt.matrix(), np.array([[1.0]], dtype=config.floatX)),
            (1, 1),
            False,
        ),
        (
            (pt.matrix(), np.array([[1.0, 2.0]], dtype=config.floatX)),
            (1, 1),
            True,
        ),
        (
            (pt.matrix(), np.array([[1.0, 2.0]], dtype=config.floatX)),
            (1, None),
            False,
        ),
    ],
)
def test_SpecifyShape(v, shape, fails):
    v, v_test_value = v
    g = SpecifyShape()(v, *shape)
    # An AssertionError is expected only when the specified shape conflicts
    # with the test value's actual shape.
    if fails:
        cm = pytest.raises(AssertionError)
    else:
        cm = contextlib.suppress()
    with cm:
        compare_numba_and_py(
            [v],
            [g],
            [v_test_value],
        )
def test_ViewOp():
v = pt.vector()
v_test_value = np.arange(4, dtype=config.floatX)
......
import contextlib
import numpy as np
import pytest
from pytensor import Variable, config
from pytensor import tensor as pt
from pytensor.tensor.shape import Reshape, Shape, Shape_i, SpecifyShape
from tests.link.numba.test_basic import compare_numba_and_py
@pytest.mark.parametrize(
    "x, i",
    [
        (np.zeros((20, 3)), 1),
    ],
)
def test_Shape(x, i):
    # Full shape of a constant input.
    full_shape_graph = Shape()(pt.as_tensor_variable(x))
    compare_numba_and_py([], [full_shape_graph], [])

    # Size of a single dimension of the same input.
    single_dim_graph = Shape_i(i)(pt.as_tensor_variable(x))
    compare_numba_and_py([], [single_dim_graph], [])
@pytest.mark.parametrize(
    "v, shape, ndim",
    [
        ((pt.vector(), np.array([4], dtype=config.floatX)), ((), None), 0),
        ((pt.vector(), np.arange(4, dtype=config.floatX)), ((2, 2), None), 2),
        (
            (pt.vector(), np.arange(4, dtype=config.floatX)),
            (pt.lvector(), np.array([2, 2], dtype="int64")),
            2,
        ),
    ],
)
def test_Reshape(v, shape, ndim):
    v, v_test_value = v
    shape, shape_test_value = shape

    g = Reshape(ndim)(v, shape)

    # A symbolic shape is an extra graph input; a literal shape is baked in.
    if isinstance(shape, Variable):
        inputs = [v, shape]
        test_values = [v_test_value, shape_test_value]
    else:
        inputs = [v]
        test_values = [v_test_value]

    compare_numba_and_py(inputs, [g], test_values)
def test_Reshape_scalar():
    """Reshape a 0-d (scalar) input into a length-1 vector."""
    v = pt.vector()
    v_test_value = np.array([1.0], dtype=config.floatX)
    g = Reshape(1)(v[0], (1,))
    # Pass the output as a list for consistency with every other test's
    # use of ``compare_numba_and_py``.
    compare_numba_and_py(
        [v],
        [g],
        [v_test_value],
    )
@pytest.mark.parametrize(
    "v, shape, fails",
    [
        (
            (pt.matrix(), np.array([[1.0]], dtype=config.floatX)),
            (1, 1),
            False,
        ),
        (
            (pt.matrix(), np.array([[1.0, 2.0]], dtype=config.floatX)),
            (1, 1),
            True,
        ),
        (
            (pt.matrix(), np.array([[1.0, 2.0]], dtype=config.floatX)),
            (1, None),
            False,
        ),
    ],
)
def test_SpecifyShape(v, shape, fails):
    v, v_test_value = v
    g = SpecifyShape()(v, *shape)
    # An AssertionError is expected only when the specified shape conflicts
    # with the test value's actual shape.
    if fails:
        cm = pytest.raises(AssertionError)
    else:
        cm = contextlib.suppress()
    with cm:
        compare_numba_and_py(
            [v],
            [g],
            [v_test_value],
        )
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论