提交 51cda52b authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Avoid numpy broadcast_to and ndindex in hot loops

上级 10105bea
......@@ -18,6 +18,7 @@ from pytensor.tensor.random.utils import (
broadcast_params,
normalize_size_param,
)
from pytensor.tensor.utils import faster_broadcast_to, faster_ndindex
# Scipy.stats is considerably slow to import
......@@ -976,19 +977,13 @@ class DirichletRV(RandomVariable):
@classmethod
def rng_fn(cls, rng, alphas, size):
if alphas.ndim > 1:
if size is None:
size = ()
size = tuple(np.atleast_1d(size))
if size:
alphas = np.broadcast_to(alphas, size + alphas.shape[-1:])
if size is not None:
alphas = faster_broadcast_to(alphas, size + alphas.shape[-1:])
samples_shape = alphas.shape
samples = np.empty(samples_shape)
for index in np.ndindex(*samples_shape[:-1]):
for index in faster_ndindex(samples_shape[:-1]):
samples[index] = rng.dirichlet(alphas[index])
return samples
else:
return rng.dirichlet(alphas, size=size)
......@@ -1800,11 +1795,11 @@ class MultinomialRV(RandomVariable):
if size is None:
n, p = broadcast_params([n, p], [0, 1])
else:
n = np.broadcast_to(n, size)
p = np.broadcast_to(p, size + p.shape[-1:])
n = faster_broadcast_to(n, size)
p = faster_broadcast_to(p, size + p.shape[-1:])
res = np.empty(p.shape, dtype=cls.dtype)
for idx in np.ndindex(p.shape[:-1]):
for idx in faster_ndindex(p.shape[:-1]):
res[idx] = rng.multinomial(n[idx], p[idx])
return res
else:
......@@ -1978,13 +1973,13 @@ class ChoiceWithoutReplacement(RandomVariable):
p.shape[:batch_ndim],
)
a = np.broadcast_to(a, size + a.shape[batch_ndim:])
a = faster_broadcast_to(a, size + a.shape[batch_ndim:])
if p is not None:
p = np.broadcast_to(p, size + p.shape[batch_ndim:])
p = faster_broadcast_to(p, size + p.shape[batch_ndim:])
a_indexed_shape = a.shape[len(size) + 1 :]
out = np.empty(size + core_shape + a_indexed_shape, dtype=a.dtype)
for idx in np.ndindex(size):
for idx in faster_ndindex(size):
out[idx] = rng.choice(
a[idx], p=None if p is None else p[idx], size=core_shape, replace=False
)
......@@ -2097,10 +2092,10 @@ class PermutationRV(RandomVariable):
if size is None:
size = x.shape[:batch_ndim]
else:
x = np.broadcast_to(x, size + x.shape[batch_ndim:])
x = faster_broadcast_to(x, size + x.shape[batch_ndim:])
out = np.empty(size + x.shape[batch_ndim:], dtype=x.dtype)
for idx in np.ndindex(size):
for idx in faster_ndindex(size):
out[idx] = rng.permutation(x[idx])
return out
......
......@@ -15,6 +15,7 @@ from pytensor.tensor.extra_ops import broadcast_arrays, broadcast_to
from pytensor.tensor.math import maximum
from pytensor.tensor.shape import shape_padleft, specify_shape
from pytensor.tensor.type import int_dtypes
from pytensor.tensor.utils import faster_broadcast_to
from pytensor.tensor.variable import TensorVariable
......@@ -125,7 +126,7 @@ def broadcast_params(
shapes = params_broadcast_shapes(
param_shapes, ndims_params, use_pytensor=use_pytensor
)
broadcast_to_fn = broadcast_to if use_pytensor else np.broadcast_to
broadcast_to_fn = broadcast_to if use_pytensor else faster_broadcast_to
# zip strict not specified because we are in a hot loop
bcast_params = [
......
import re
from collections.abc import Sequence
from itertools import product
from typing import cast
import numpy as np
from numpy import nditer
import pytensor
from pytensor.graph import FunctionGraph, Variable
......@@ -233,3 +235,24 @@ def normalize_reduce_axis(axis, ndim: int) -> tuple[int, ...] | None:
# TODO: If axis tuple is equivalent to None, return None for more canonicalization?
return cast(tuple, axis)
def faster_broadcast_to(x, shape):
    """Broadcast ``x`` to ``shape``, skipping `np.broadcast_to`'s overhead.

    Goes straight to the C-level iterator machinery that produces the
    broadcast view; the result is read-only, matching `np.broadcast_to`.
    """
    broadcast_iter = nditer(
        (x,),
        itershape=shape,
        flags=["multi_index", "zerosize_ok"],
        op_flags=["readonly"],
        order="C",
    )
    # itviews holds one broadcast view per operand; we passed a single one.
    (view,) = broadcast_iter.itviews
    return view
def faster_ndindex(shape: Sequence[int]):
    """Iterate over all index tuples of an array of the given ``shape``.

    Equivalent to `np.ndindex` but usually 10x faster.

    Unlike `np.ndindex`, this function expects a single sequence of integers
    rather than separate integer arguments.

    https://github.com/numpy/numpy/issues/28921
    """
    # C-order iteration: last axis varies fastest, same as np.ndindex.
    return product(*map(range, shape))
......@@ -746,9 +746,8 @@ test_mvnormal_cov_decomposition_method = create_mvnormal_cov_decomposition_metho
],
)
def test_dirichlet_samples(alphas, size):
def dirichlet_test_fn(mean=None, cov=None, size=None, random_state=None):
if size is None:
size = ()
# FIXME: Is this just testing itself against itself?
def dirichlet_test_fn(alphas, size, random_state):
return dirichlet.rng_fn(random_state, alphas, size)
compare_sample_values(dirichlet, alphas, size=size, test_fn=dirichlet_test_fn)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论