Unverified 提交 b3da2a4b authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: GitHub

Lazy import of scipy.stats (#1268)

上级 00fea0e3
......@@ -3,7 +3,6 @@ import warnings
from typing import Literal
import numpy as np
import scipy.stats as stats
from numpy import broadcast_shapes as np_broadcast_shapes
from numpy import einsum as np_einsum
from numpy import sqrt as np_sqrt
......@@ -21,6 +20,11 @@ from pytensor.tensor.random.utils import (
)
# Scipy.stats is considerably slow to import
# We import scipy.stats lazily inside `ScipyRandomVariable`
stats = None
try:
broadcast_shapes = np.broadcast_shapes
except AttributeError:
......@@ -57,6 +61,9 @@ class ScipyRandomVariable(RandomVariable):
@classmethod
def rng_fn(cls, *args, **kwargs):
global stats
if stats is None:
import scipy.stats as stats
size = args[-1]
res = cls.rng_fn_scipy(*args, **kwargs)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论