Unverified 提交 b3da2a4b authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: GitHub

Lazy import of scipy.stats (#1268)

上级 00fea0e3
...@@ -3,7 +3,6 @@ import warnings ...@@ -3,7 +3,6 @@ import warnings
from typing import Literal from typing import Literal
import numpy as np import numpy as np
import scipy.stats as stats
from numpy import broadcast_shapes as np_broadcast_shapes from numpy import broadcast_shapes as np_broadcast_shapes
from numpy import einsum as np_einsum from numpy import einsum as np_einsum
from numpy import sqrt as np_sqrt from numpy import sqrt as np_sqrt
...@@ -21,6 +20,11 @@ from pytensor.tensor.random.utils import ( ...@@ -21,6 +20,11 @@ from pytensor.tensor.random.utils import (
) )
# Scipy.stats is considerably slow to import
# We import scipy.stats lazily inside `ScipyRandomVariable`
stats = None
try: try:
broadcast_shapes = np.broadcast_shapes broadcast_shapes = np.broadcast_shapes
except AttributeError: except AttributeError:
...@@ -57,6 +61,9 @@ class ScipyRandomVariable(RandomVariable): ...@@ -57,6 +61,9 @@ class ScipyRandomVariable(RandomVariable):
@classmethod @classmethod
def rng_fn(cls, *args, **kwargs): def rng_fn(cls, *args, **kwargs):
global stats
if stats is None:
import scipy.stats as stats
size = args[-1] size = args[-1]
res = cls.rng_fn_scipy(*args, **kwargs) res = cls.rng_fn_scipy(*args, **kwargs)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论