提交 5008fab7 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Simplify and speedup numba DimShuffle implementation

上级 28fa7b76
...@@ -4,6 +4,7 @@ from textwrap import dedent, indent ...@@ -4,6 +4,7 @@ from textwrap import dedent, indent
import numba import numba
import numpy as np import numpy as np
from numba.core.extending import overload from numba.core.extending import overload
from numpy.lib.stride_tricks import as_strided
from pytensor.graph.op import Op from pytensor.graph.op import Op
from pytensor.link.numba.dispatch import basic as numba_basic from pytensor.link.numba.dispatch import basic as numba_basic
...@@ -411,91 +412,38 @@ def numba_funcify_CAReduce(op, node, **kwargs): ...@@ -411,91 +412,38 @@ def numba_funcify_CAReduce(op, node, **kwargs):
@numba_funcify.register(DimShuffle) @numba_funcify.register(DimShuffle)
def numba_funcify_DimShuffle(op, node, **kwargs): def numba_funcify_DimShuffle(op, node, **kwargs):
shuffle = tuple(op.shuffle) # We use `as_strided` to achieve the DimShuffle behavior of transposing and expanding/squeezing dimensions in one call
transposition = tuple(op.transposition) # Numba doesn't currently support multiple expand/squeeze, and reshape is limited to contiguous arrays.
augment = tuple(op.augment) new_order = tuple(op._new_order)
shape_template = (1,) * node.outputs[0].ndim
strides_template = (0,) * node.outputs[0].ndim
ndim_new_shape = len(shuffle) + len(augment) if new_order == ():
# Special case needed because of https://github.com/numba/numba/issues/9933
no_transpose = all(i == j for i, j in enumerate(transposition))
if no_transpose:
@numba_basic.numba_njit
def transpose(x):
return x
else:
@numba_basic.numba_njit
def transpose(x):
return np.transpose(x, transposition)
shape_template = (1,) * ndim_new_shape
# When `len(shuffle) == 0`, the `shuffle_shape[j]` expression below
# is typed as `getitem(Tuple(), int)`, which has no implementation
# (since getting an item from an empty sequence doesn't make sense).
# To avoid this compile-time error, we omit the expression altogether.
if len(shuffle) > 0:
# Use the statically known shape if available
if all(length is not None for length in node.outputs[0].type.shape):
shape = node.outputs[0].type.shape
@numba_basic.numba_njit
def find_shape(array_shape):
return shape
else:
@numba_basic.numba_njit
def find_shape(array_shape):
shape = shape_template
j = 0
for i in range(ndim_new_shape):
if i not in augment:
length = array_shape[j]
shape = numba_basic.tuple_setitem(shape, i, length)
j = j + 1
return shape
else:
@numba_basic.numba_njit
def find_shape(array_shape):
return shape_template
if ndim_new_shape > 0:
@numba_basic.numba_njit @numba_basic.numba_njit
def dimshuffle_inner(x, shuffle): def squeeze_to_0d(x):
x = transpose(x) return as_strided(x, shape=(), strides=())
shuffle_shape = x.shape[: len(shuffle)]
new_shape = find_shape(shuffle_shape)
# FIXME: Numba's `array.reshape` only accepts C arrays. return squeeze_to_0d
return np.reshape(np.ascontiguousarray(x), new_shape)
else: else:
@numba_basic.numba_njit @numba_basic.numba_njit
def dimshuffle_inner(x, shuffle): def dimshuffle(x):
return np.reshape(np.ascontiguousarray(x), ()) old_shape = x.shape
old_strides = x.strides
# Without the following wrapper function we would see this error:
# E No implementation of function Function(<built-in function getitem>) found for signature: new_shape = shape_template
# E new_strides = strides_template
# E >>> getitem(UniTuple(int64 x 2), slice<a:b>) for i, o in enumerate(new_order):
# E if o != -1:
# E There are 22 candidate implementations: new_shape = numba_basic.tuple_setitem(new_shape, i, old_shape[o])
# E - Of which 22 did not match due to: new_strides = numba_basic.tuple_setitem(
# E Overload of function 'getitem': File: <numerous>: Line N/A. new_strides, i, old_strides[o]
# E With argument(s): '(UniTuple(int64 x 2), slice<a:b>)': )
# E No match.
# ...(on this line)... return as_strided(x, shape=new_shape, strides=new_strides)
# E shuffle_shape = res.shape[: len(shuffle)]
@numba_basic.numba_njit(inline="always")
def dimshuffle(x):
return dimshuffle_inner(np.asarray(x), shuffle)
return dimshuffle return dimshuffle
......
...@@ -23,6 +23,7 @@ from tests.link.numba.test_basic import ( ...@@ -23,6 +23,7 @@ from tests.link.numba.test_basic import (
from tests.tensor.test_elemwise import ( from tests.tensor.test_elemwise import (
careduce_benchmark_tester, careduce_benchmark_tester,
check_elemwise_runtime_broadcast, check_elemwise_runtime_broadcast,
dimshuffle_benchmark,
) )
...@@ -201,7 +202,7 @@ def test_Dimshuffle_returns_array(): ...@@ -201,7 +202,7 @@ def test_Dimshuffle_returns_array():
def test_Dimshuffle_non_contiguous(): def test_Dimshuffle_non_contiguous():
"""The numba impl of reshape doesn't work with """The numba impl of reshape doesn't work with
non-contiguous arrays, make sure we work around thpt.""" non-contiguous arrays, make sure we work around that."""
x = pt.dvector() x = pt.dvector()
idx = pt.vector(dtype="int64") idx = pt.vector(dtype="int64")
op = DimShuffle(input_ndim=1, new_order=[]) op = DimShuffle(input_ndim=1, new_order=[])
...@@ -643,3 +644,7 @@ class TestsBenchmark: ...@@ -643,3 +644,7 @@ class TestsBenchmark:
return careduce_benchmark_tester( return careduce_benchmark_tester(
axis, c_contiguous, mode="NUMBA", benchmark=benchmark axis, c_contiguous, mode="NUMBA", benchmark=benchmark
) )
# Run the shared DimShuffle benchmark in "NUMBA" mode, once with
# c_contiguous=True and once with c_contiguous=False.
@pytest.mark.parametrize("c_contiguous", (True, False))
def test_dimshuffle(self, c_contiguous, benchmark):
dimshuffle_benchmark("NUMBA", c_contiguous, benchmark)
...@@ -66,6 +66,30 @@ def reduce_bitwise_and(x, axis=-1, dtype="int8"): ...@@ -66,6 +66,30 @@ def reduce_bitwise_and(x, axis=-1, dtype="int8"):
return np.apply_along_axis(custom_reduce, axis, x) return np.apply_along_axis(custom_reduce, axis, x)
def dimshuffle_benchmark(mode, c_contiguous, benchmark):
    """Benchmark one compiled function that applies many DimShuffles to a 3D input.

    The compiled graph has ten outputs: the six axis permutations of ``x``
    followed by four single-axis expansions (``x[None]`` through
    ``x[:, :, :, None]``).

    Parameters
    ----------
    mode
        Compilation mode forwarded to ``pytensor.function``.
    c_contiguous
        If true, the input is a freshly allocated ``(2, 3, 4)`` array.
        Otherwise it is a ``(200, 300, 400)`` array viewed through
        ``transpose(1, 2, 0)``, i.e. a non-C-contiguous input.
    benchmark
        Callable invoked as ``benchmark(fn, x_val)``
        (presumably the pytest-benchmark fixture -- confirm against callers).
    """
    x = tensor3("x")
    if c_contiguous:
        x_val = np.random.random((2, 3, 4)).astype(config.floatX)
    else:
        # Cast to floatX like the contiguous branch: `trust_input` below hands
        # the value to the function as-is, so a float64 array would not match
        # when floatX is float32. The cast happens *before* the transpose so
        # the final value is still a strided (non-contiguous) view.
        x_val = (
            np.random.random((200, 300, 400))
            .astype(config.floatX)
            .transpose(1, 2, 0)
        )

    ys = [x.transpose(t) for t in itertools.permutations((0, 1, 2))]
    ys += [
        x[None],
        x[:, None],
        x[:, :, None],
        x[:, :, :, None],
    ]

    # Borrow to avoid deepcopy overhead
    fn = pytensor.function(
        [In(x, borrow=True)],
        [Out(y, borrow=True) for y in ys],
        mode=mode,
    )
    fn.trust_input = True
    fn(x_val)  # JIT compile for JIT backends
    benchmark(fn, x_val)
class TestDimShuffle(unittest_tools.InferShapeTester): class TestDimShuffle(unittest_tools.InferShapeTester):
op = DimShuffle op = DimShuffle
type = TensorType type = TensorType
...@@ -218,23 +242,9 @@ class TestDimShuffle(unittest_tools.InferShapeTester): ...@@ -218,23 +242,9 @@ class TestDimShuffle(unittest_tools.InferShapeTester):
with pytest.raises(TypeError, match="input_ndim must be an integer"): with pytest.raises(TypeError, match="input_ndim must be an integer"):
DimShuffle(input_ndim=(True, False), new_order=(1, 0)) DimShuffle(input_ndim=(True, False), new_order=(1, 0))
def test_benchmark(self, benchmark): @pytest.mark.parametrize("c_contiguous", [True, False])
x = tensor3("x") def test_benchmark(self, c_contiguous, benchmark):
x_val = np.random.random((2, 3, 4)).astype(config.floatX) dimshuffle_benchmark("FAST_RUN", c_contiguous, benchmark)
ys = [x.transpose(t) for t in itertools.permutations((0, 1, 2))]
ys += [
x[None],
x[:, None],
x[:, :, None],
x[:, :, :, None],
]
# Borrow to avoid deepcopy overhead
fn = pytensor.function(
[In(x, borrow=True)],
[Out(y, borrow=True) for y in ys],
)
fn.trust_input = True
benchmark(fn, x_val)
class TestBroadcast: class TestBroadcast:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论