提交 5008fab7 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Simplify and speedup numba DimShuffle implementation

上级 28fa7b76
......@@ -4,6 +4,7 @@ from textwrap import dedent, indent
import numba
import numpy as np
from numba.core.extending import overload
from numpy.lib.stride_tricks import as_strided
from pytensor.graph.op import Op
from pytensor.link.numba.dispatch import basic as numba_basic
......@@ -411,91 +412,38 @@ def numba_funcify_CAReduce(op, node, **kwargs):
@numba_funcify.register(DimShuffle)
def numba_funcify_DimShuffle(op, node, **kwargs):
shuffle = tuple(op.shuffle)
transposition = tuple(op.transposition)
augment = tuple(op.augment)
# We use `as_strided` to achieve the DimShuffle behavior of transposing and expanding/squeezing dimensions in one call
# Numba doesn't currently support multiple expand/squeeze, and reshape is limited to contiguous arrays.
new_order = tuple(op._new_order)
shape_template = (1,) * node.outputs[0].ndim
strides_template = (0,) * node.outputs[0].ndim
ndim_new_shape = len(shuffle) + len(augment)
no_transpose = all(i == j for i, j in enumerate(transposition))
if no_transpose:
@numba_basic.numba_njit
def transpose(x):
return x
else:
@numba_basic.numba_njit
def transpose(x):
return np.transpose(x, transposition)
shape_template = (1,) * ndim_new_shape
# When `len(shuffle) == 0`, the `shuffle_shape[j]` expression below
# is typed as `getitem(Tuple(), int)`, which has no implementation
# (since getting an item from an empty sequence doesn't make sense).
# To avoid this compile-time error, we omit the expression altogether.
if len(shuffle) > 0:
# Use the statically known shape if available
if all(length is not None for length in node.outputs[0].type.shape):
shape = node.outputs[0].type.shape
if new_order == ():
# Special case needed because of https://github.com/numba/numba/issues/9933
@numba_basic.numba_njit
def find_shape(array_shape):
return shape
def squeeze_to_0d(x):
return as_strided(x, shape=(), strides=())
else:
@numba_basic.numba_njit
def find_shape(array_shape):
shape = shape_template
j = 0
for i in range(ndim_new_shape):
if i not in augment:
length = array_shape[j]
shape = numba_basic.tuple_setitem(shape, i, length)
j = j + 1
return shape
return squeeze_to_0d
else:
@numba_basic.numba_njit
def find_shape(array_shape):
return shape_template
if ndim_new_shape > 0:
@numba_basic.numba_njit
def dimshuffle_inner(x, shuffle):
x = transpose(x)
shuffle_shape = x.shape[: len(shuffle)]
new_shape = find_shape(shuffle_shape)
# FIXME: Numba's `array.reshape` only accepts C arrays.
return np.reshape(np.ascontiguousarray(x), new_shape)
else:
@numba_basic.numba_njit
def dimshuffle_inner(x, shuffle):
return np.reshape(np.ascontiguousarray(x), ())
# Without the following wrapper function we would see this error:
# E No implementation of function Function(<built-in function getitem>) found for signature:
# E
# E >>> getitem(UniTuple(int64 x 2), slice<a:b>)
# E
# E There are 22 candidate implementations:
# E - Of which 22 did not match due to:
# E Overload of function 'getitem': File: <numerous>: Line N/A.
# E With argument(s): '(UniTuple(int64 x 2), slice<a:b>)':
# E No match.
# ...(on this line)...
# E shuffle_shape = res.shape[: len(shuffle)]
@numba_basic.numba_njit(inline="always")
def dimshuffle(x):
return dimshuffle_inner(np.asarray(x), shuffle)
old_shape = x.shape
old_strides = x.strides
new_shape = shape_template
new_strides = strides_template
for i, o in enumerate(new_order):
if o != -1:
new_shape = numba_basic.tuple_setitem(new_shape, i, old_shape[o])
new_strides = numba_basic.tuple_setitem(
new_strides, i, old_strides[o]
)
return as_strided(x, shape=new_shape, strides=new_strides)
return dimshuffle
......
......@@ -23,6 +23,7 @@ from tests.link.numba.test_basic import (
from tests.tensor.test_elemwise import (
careduce_benchmark_tester,
check_elemwise_runtime_broadcast,
dimshuffle_benchmark,
)
......@@ -201,7 +202,7 @@ def test_Dimshuffle_returns_array():
def test_Dimshuffle_non_contiguous():
"""The numba impl of reshape doesn't work with
non-contiguous arrays, make sure we work around thpt."""
non-contiguous arrays, make sure we work around that."""
x = pt.dvector()
idx = pt.vector(dtype="int64")
op = DimShuffle(input_ndim=1, new_order=[])
......@@ -643,3 +644,7 @@ class TestsBenchmark:
return careduce_benchmark_tester(
axis, c_contiguous, mode="NUMBA", benchmark=benchmark
)
@pytest.mark.parametrize("c_contiguous", (True, False))
def test_dimshuffle(self, c_contiguous, benchmark):
    """Run the shared DimShuffle benchmark in NUMBA mode.

    Parametrized over `c_contiguous` so the benchmark is run once with
    the C-contiguous input and once with the transposed input.
    """
    dimshuffle_benchmark(
        mode="NUMBA",
        c_contiguous=c_contiguous,
        benchmark=benchmark,
    )
......@@ -66,6 +66,30 @@ def reduce_bitwise_and(x, axis=-1, dtype="int8"):
return np.apply_along_axis(custom_reduce, axis, x)
def dimshuffle_benchmark(mode, c_contiguous, benchmark):
    """Benchmark one compiled function that returns many dimshuffled views.

    The function takes a single 3D tensor and outputs all six axis
    permutations of it plus the four ways of inserting one new axis.

    Parameters
    ----------
    mode
        Compilation mode forwarded to `pytensor.function`.
    c_contiguous : bool
        If true, the test value is a small C-contiguous (2, 3, 4) array.
        Otherwise it is a transposed view of a (200, 300, 400) array.
    benchmark
        Callable invoked as ``benchmark(fn, x_val)``; presumably the
        pytest-benchmark fixture (confirm against the callers).
    """
    x = tensor3("x")
    if c_contiguous:
        x_val = np.random.random((2, 3, 4)).astype(config.floatX)
    else:
        # Cast to floatX like the contiguous branch does: `trust_input` is
        # enabled below, so the value must already have the dtype that the
        # compiled function expects. The cast is done *before* the transpose
        # so the final value is still a non-contiguous view.
        x_val = (
            np.random.random((200, 300, 400))
            .astype(config.floatX)
            .transpose(1, 2, 0)
        )

    ys = [x.transpose(t) for t in itertools.permutations((0, 1, 2))]
    ys += [
        x[None],
        x[:, None],
        x[:, :, None],
        x[:, :, :, None],
    ]

    # Borrow to avoid deepcopy overhead
    fn = pytensor.function(
        [In(x, borrow=True)],
        [Out(y, borrow=True) for y in ys],
        mode=mode,
    )
    fn.trust_input = True
    fn(x_val)  # JIT compile for JIT backends
    benchmark(fn, x_val)
class TestDimShuffle(unittest_tools.InferShapeTester):
op = DimShuffle
type = TensorType
......@@ -218,23 +242,9 @@ class TestDimShuffle(unittest_tools.InferShapeTester):
with pytest.raises(TypeError, match="input_ndim must be an integer"):
DimShuffle(input_ndim=(True, False), new_order=(1, 0))
def test_benchmark(self, benchmark):
x = tensor3("x")
x_val = np.random.random((2, 3, 4)).astype(config.floatX)
ys = [x.transpose(t) for t in itertools.permutations((0, 1, 2))]
ys += [
x[None],
x[:, None],
x[:, :, None],
x[:, :, :, None],
]
# Borrow to avoid deepcopy overhead
fn = pytensor.function(
[In(x, borrow=True)],
[Out(y, borrow=True) for y in ys],
)
fn.trust_input = True
benchmark(fn, x_val)
@pytest.mark.parametrize("c_contiguous", [True, False])
def test_benchmark(self, c_contiguous, benchmark):
    """Run the shared DimShuffle benchmark in FAST_RUN mode.

    Parametrized over `c_contiguous` so the benchmark is run once with
    the C-contiguous input and once with the transposed input.
    """
    dimshuffle_benchmark(
        mode="FAST_RUN",
        c_contiguous=c_contiguous,
        benchmark=benchmark,
    )
class TestBroadcast:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论