Commit 51cda52b authored by Ricardo Vieira, committed by Ricardo Vieira

Avoid numpy broadcast_to and ndindex in hot loops

Parent 10105bea
...@@ -18,6 +18,7 @@ from pytensor.tensor.random.utils import ( ...@@ -18,6 +18,7 @@ from pytensor.tensor.random.utils import (
broadcast_params, broadcast_params,
normalize_size_param, normalize_size_param,
) )
from pytensor.tensor.utils import faster_broadcast_to, faster_ndindex
# Scipy.stats is considerably slow to import # Scipy.stats is considerably slow to import
...@@ -976,19 +977,13 @@ class DirichletRV(RandomVariable): ...@@ -976,19 +977,13 @@ class DirichletRV(RandomVariable):
@classmethod @classmethod
def rng_fn(cls, rng, alphas, size): def rng_fn(cls, rng, alphas, size):
if alphas.ndim > 1: if alphas.ndim > 1:
if size is None: if size is not None:
size = () alphas = faster_broadcast_to(alphas, size + alphas.shape[-1:])
size = tuple(np.atleast_1d(size))
if size:
alphas = np.broadcast_to(alphas, size + alphas.shape[-1:])
samples_shape = alphas.shape samples_shape = alphas.shape
samples = np.empty(samples_shape) samples = np.empty(samples_shape)
for index in np.ndindex(*samples_shape[:-1]): for index in faster_ndindex(samples_shape[:-1]):
samples[index] = rng.dirichlet(alphas[index]) samples[index] = rng.dirichlet(alphas[index])
return samples return samples
else: else:
return rng.dirichlet(alphas, size=size) return rng.dirichlet(alphas, size=size)
...@@ -1800,11 +1795,11 @@ class MultinomialRV(RandomVariable): ...@@ -1800,11 +1795,11 @@ class MultinomialRV(RandomVariable):
if size is None: if size is None:
n, p = broadcast_params([n, p], [0, 1]) n, p = broadcast_params([n, p], [0, 1])
else: else:
n = np.broadcast_to(n, size) n = faster_broadcast_to(n, size)
p = np.broadcast_to(p, size + p.shape[-1:]) p = faster_broadcast_to(p, size + p.shape[-1:])
res = np.empty(p.shape, dtype=cls.dtype) res = np.empty(p.shape, dtype=cls.dtype)
for idx in np.ndindex(p.shape[:-1]): for idx in faster_ndindex(p.shape[:-1]):
res[idx] = rng.multinomial(n[idx], p[idx]) res[idx] = rng.multinomial(n[idx], p[idx])
return res return res
else: else:
...@@ -1978,13 +1973,13 @@ class ChoiceWithoutReplacement(RandomVariable): ...@@ -1978,13 +1973,13 @@ class ChoiceWithoutReplacement(RandomVariable):
p.shape[:batch_ndim], p.shape[:batch_ndim],
) )
a = np.broadcast_to(a, size + a.shape[batch_ndim:]) a = faster_broadcast_to(a, size + a.shape[batch_ndim:])
if p is not None: if p is not None:
p = np.broadcast_to(p, size + p.shape[batch_ndim:]) p = faster_broadcast_to(p, size + p.shape[batch_ndim:])
a_indexed_shape = a.shape[len(size) + 1 :] a_indexed_shape = a.shape[len(size) + 1 :]
out = np.empty(size + core_shape + a_indexed_shape, dtype=a.dtype) out = np.empty(size + core_shape + a_indexed_shape, dtype=a.dtype)
for idx in np.ndindex(size): for idx in faster_ndindex(size):
out[idx] = rng.choice( out[idx] = rng.choice(
a[idx], p=None if p is None else p[idx], size=core_shape, replace=False a[idx], p=None if p is None else p[idx], size=core_shape, replace=False
) )
...@@ -2097,10 +2092,10 @@ class PermutationRV(RandomVariable): ...@@ -2097,10 +2092,10 @@ class PermutationRV(RandomVariable):
if size is None: if size is None:
size = x.shape[:batch_ndim] size = x.shape[:batch_ndim]
else: else:
x = np.broadcast_to(x, size + x.shape[batch_ndim:]) x = faster_broadcast_to(x, size + x.shape[batch_ndim:])
out = np.empty(size + x.shape[batch_ndim:], dtype=x.dtype) out = np.empty(size + x.shape[batch_ndim:], dtype=x.dtype)
for idx in np.ndindex(size): for idx in faster_ndindex(size):
out[idx] = rng.permutation(x[idx]) out[idx] = rng.permutation(x[idx])
return out return out
......
...@@ -15,6 +15,7 @@ from pytensor.tensor.extra_ops import broadcast_arrays, broadcast_to ...@@ -15,6 +15,7 @@ from pytensor.tensor.extra_ops import broadcast_arrays, broadcast_to
from pytensor.tensor.math import maximum from pytensor.tensor.math import maximum
from pytensor.tensor.shape import shape_padleft, specify_shape from pytensor.tensor.shape import shape_padleft, specify_shape
from pytensor.tensor.type import int_dtypes from pytensor.tensor.type import int_dtypes
from pytensor.tensor.utils import faster_broadcast_to
from pytensor.tensor.variable import TensorVariable from pytensor.tensor.variable import TensorVariable
...@@ -125,7 +126,7 @@ def broadcast_params( ...@@ -125,7 +126,7 @@ def broadcast_params(
shapes = params_broadcast_shapes( shapes = params_broadcast_shapes(
param_shapes, ndims_params, use_pytensor=use_pytensor param_shapes, ndims_params, use_pytensor=use_pytensor
) )
broadcast_to_fn = broadcast_to if use_pytensor else np.broadcast_to broadcast_to_fn = broadcast_to if use_pytensor else faster_broadcast_to
# zip strict not specified because we are in a hot loop # zip strict not specified because we are in a hot loop
bcast_params = [ bcast_params = [
......
import re import re
from collections.abc import Sequence from collections.abc import Sequence
from itertools import product
from typing import cast from typing import cast
import numpy as np import numpy as np
from numpy import nditer
import pytensor import pytensor
from pytensor.graph import FunctionGraph, Variable from pytensor.graph import FunctionGraph, Variable
...@@ -233,3 +235,24 @@ def normalize_reduce_axis(axis, ndim: int) -> tuple[int, ...] | None: ...@@ -233,3 +235,24 @@ def normalize_reduce_axis(axis, ndim: int) -> tuple[int, ...] | None:
# TODO: If axis tuple is equivalent to None, return None for more canonicalization? # TODO: If axis tuple is equivalent to None, return None for more canonicalization?
return cast(tuple, axis) return cast(tuple, axis)
def faster_broadcast_to(x, shape):
    """Broadcast ``x`` to ``shape``, skipping `np.broadcast_to` overhead.

    Stripped-down core logic of `np.broadcast_to`: set up an ``nditer``
    over the single operand with the requested iteration shape and hand
    back its broadcast (read-only) view.  Intended for hot loops where
    the argument validation done by `np.broadcast_to` is measurable.
    """
    # NOTE(review): unlike np.broadcast_to this omits "refs_ok", so
    # object-dtype inputs are presumably not supported — confirm callers
    # never pass them.
    broadcast_iter = nditer(
        (x,),
        flags=["multi_index", "zerosize_ok"],
        itershape=shape,
        op_flags=["readonly"],
        order="C",
    )
    return broadcast_iter.itviews[0]
def faster_ndindex(shape: Sequence[int]):
    """Equivalent to `np.ndindex` but usually 10x faster.

    Unlike `np.ndindex`, this function expects a single sequence of integers
    https://github.com/numpy/numpy/issues/28921
    """
    # Cartesian product of per-axis ranges, yielded as index tuples in
    # C (row-major) order — the same order np.ndindex produces.
    return product(*map(range, shape))
...@@ -746,9 +746,8 @@ test_mvnormal_cov_decomposition_method = create_mvnormal_cov_decomposition_metho ...@@ -746,9 +746,8 @@ test_mvnormal_cov_decomposition_method = create_mvnormal_cov_decomposition_metho
], ],
) )
def test_dirichlet_samples(alphas, size): def test_dirichlet_samples(alphas, size):
def dirichlet_test_fn(mean=None, cov=None, size=None, random_state=None): # FIXME: Is this just testing itself against itself?
if size is None: def dirichlet_test_fn(alphas, size, random_state):
size = ()
return dirichlet.rng_fn(random_state, alphas, size) return dirichlet.rng_fn(random_state, alphas, size)
compare_sample_values(dirichlet, alphas, size=size, test_fn=dirichlet_test_fn) compare_sample_values(dirichlet, alphas, size=size, test_fn=dirichlet_test_fn)
......
Markdown format
0%
You are about to add 0 people to this discussion. Please proceed with caution.
Please finish editing this comment first!
Register or sign in to comment