Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
3e9c6a4f
提交
3e9c6a4f
authored
4月 23, 2024
作者:
Ricardo Vieira
提交者:
Ricardo Vieira
5月 29, 2024
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Introduce signature instead of ndim_supp and ndims_params
上级
a576fa2c
隐藏空白字符变更
内嵌
并排
正在显示
9 个修改的文件
包含
198 行增加
和
222 行删除
+198
-222
basic.py
pytensor/tensor/random/basic.py
+57
-108
op.py
pytensor/tensor/random/op.py
+90
-29
jax.py
pytensor/tensor/random/rewriting/jax.py
+5
-3
utils.py
pytensor/tensor/random/utils.py
+7
-2
utils.py
pytensor/tensor/utils.py
+4
-1
test_random.py
tests/link/jax/test_random.py
+2
-4
test_basic.py
tests/tensor/random/rewriting/test_basic.py
+22
-45
test_basic.py
tests/tensor/random/test_basic.py
+5
-21
test_op.py
tests/tensor/random/test_op.py
+6
-9
没有找到文件。
pytensor/tensor/random/basic.py
浏览文件 @
3e9c6a4f
...
@@ -13,7 +13,6 @@ from pytensor.tensor.random.type import RandomGeneratorType, RandomStateType
...
@@ -13,7 +13,6 @@ from pytensor.tensor.random.type import RandomGeneratorType, RandomStateType
from
pytensor.tensor.random.utils
import
(
from
pytensor.tensor.random.utils
import
(
broadcast_params
,
broadcast_params
,
normalize_size_param
,
normalize_size_param
,
supp_shape_from_ref_param_shape
,
)
)
from
pytensor.tensor.random.var
import
(
from
pytensor.tensor.random.var
import
(
RandomGeneratorSharedVariable
,
RandomGeneratorSharedVariable
,
...
@@ -91,8 +90,7 @@ class UniformRV(RandomVariable):
...
@@ -91,8 +90,7 @@ class UniformRV(RandomVariable):
"""
"""
name
=
"uniform"
name
=
"uniform"
ndim_supp
=
0
signature
=
"(),()->()"
ndims_params
=
[
0
,
0
]
dtype
=
"floatX"
dtype
=
"floatX"
_print_name
=
(
"Uniform"
,
"
\\
operatorname{Uniform}"
)
_print_name
=
(
"Uniform"
,
"
\\
operatorname{Uniform}"
)
...
@@ -146,8 +144,7 @@ class TriangularRV(RandomVariable):
...
@@ -146,8 +144,7 @@ class TriangularRV(RandomVariable):
"""
"""
name
=
"triangular"
name
=
"triangular"
ndim_supp
=
0
signature
=
"(),(),()->()"
ndims_params
=
[
0
,
0
,
0
]
dtype
=
"floatX"
dtype
=
"floatX"
_print_name
=
(
"Triangular"
,
"
\\
operatorname{Triangular}"
)
_print_name
=
(
"Triangular"
,
"
\\
operatorname{Triangular}"
)
...
@@ -202,8 +199,7 @@ class BetaRV(RandomVariable):
...
@@ -202,8 +199,7 @@ class BetaRV(RandomVariable):
"""
"""
name
=
"beta"
name
=
"beta"
ndim_supp
=
0
signature
=
"(),()->()"
ndims_params
=
[
0
,
0
]
dtype
=
"floatX"
dtype
=
"floatX"
_print_name
=
(
"Beta"
,
"
\\
operatorname{Beta}"
)
_print_name
=
(
"Beta"
,
"
\\
operatorname{Beta}"
)
...
@@ -249,8 +245,7 @@ class NormalRV(RandomVariable):
...
@@ -249,8 +245,7 @@ class NormalRV(RandomVariable):
"""
"""
name
=
"normal"
name
=
"normal"
ndim_supp
=
0
signature
=
"(),()->()"
ndims_params
=
[
0
,
0
]
dtype
=
"floatX"
dtype
=
"floatX"
_print_name
=
(
"Normal"
,
"
\\
operatorname{Normal}"
)
_print_name
=
(
"Normal"
,
"
\\
operatorname{Normal}"
)
...
@@ -316,8 +311,7 @@ class HalfNormalRV(ScipyRandomVariable):
...
@@ -316,8 +311,7 @@ class HalfNormalRV(ScipyRandomVariable):
"""
"""
name
=
"halfnormal"
name
=
"halfnormal"
ndim_supp
=
0
signature
=
"(),()->()"
ndims_params
=
[
0
,
0
]
dtype
=
"floatX"
dtype
=
"floatX"
_print_name
=
(
"HalfNormal"
,
"
\\
operatorname{HalfNormal}"
)
_print_name
=
(
"HalfNormal"
,
"
\\
operatorname{HalfNormal}"
)
...
@@ -382,8 +376,7 @@ class LogNormalRV(RandomVariable):
...
@@ -382,8 +376,7 @@ class LogNormalRV(RandomVariable):
"""
"""
name
=
"lognormal"
name
=
"lognormal"
ndim_supp
=
0
signature
=
"(),()->()"
ndims_params
=
[
0
,
0
]
dtype
=
"floatX"
dtype
=
"floatX"
_print_name
=
(
"LogNormal"
,
"
\\
operatorname{LogNormal}"
)
_print_name
=
(
"LogNormal"
,
"
\\
operatorname{LogNormal}"
)
...
@@ -434,8 +427,7 @@ class GammaRV(RandomVariable):
...
@@ -434,8 +427,7 @@ class GammaRV(RandomVariable):
"""
"""
name
=
"gamma"
name
=
"gamma"
ndim_supp
=
0
signature
=
"(),()->()"
ndims_params
=
[
0
,
0
]
dtype
=
"floatX"
dtype
=
"floatX"
_print_name
=
(
"Gamma"
,
"
\\
operatorname{Gamma}"
)
_print_name
=
(
"Gamma"
,
"
\\
operatorname{Gamma}"
)
...
@@ -567,8 +559,7 @@ class ParetoRV(ScipyRandomVariable):
...
@@ -567,8 +559,7 @@ class ParetoRV(ScipyRandomVariable):
"""
"""
name
=
"pareto"
name
=
"pareto"
ndim_supp
=
0
signature
=
"(),()->()"
ndims_params
=
[
0
,
0
]
dtype
=
"floatX"
dtype
=
"floatX"
_print_name
=
(
"Pareto"
,
"
\\
operatorname{Pareto}"
)
_print_name
=
(
"Pareto"
,
"
\\
operatorname{Pareto}"
)
...
@@ -618,8 +609,7 @@ class GumbelRV(ScipyRandomVariable):
...
@@ -618,8 +609,7 @@ class GumbelRV(ScipyRandomVariable):
"""
"""
name
=
"gumbel"
name
=
"gumbel"
ndim_supp
=
0
signature
=
"(),()->()"
ndims_params
=
[
0
,
0
]
dtype
=
"floatX"
dtype
=
"floatX"
_print_name
=
(
"Gumbel"
,
"
\\
operatorname{Gumbel}"
)
_print_name
=
(
"Gumbel"
,
"
\\
operatorname{Gumbel}"
)
...
@@ -680,8 +670,7 @@ class ExponentialRV(RandomVariable):
...
@@ -680,8 +670,7 @@ class ExponentialRV(RandomVariable):
"""
"""
name
=
"exponential"
name
=
"exponential"
ndim_supp
=
0
signature
=
"()->()"
ndims_params
=
[
0
]
dtype
=
"floatX"
dtype
=
"floatX"
_print_name
=
(
"Exponential"
,
"
\\
operatorname{Exponential}"
)
_print_name
=
(
"Exponential"
,
"
\\
operatorname{Exponential}"
)
...
@@ -724,8 +713,7 @@ class WeibullRV(RandomVariable):
...
@@ -724,8 +713,7 @@ class WeibullRV(RandomVariable):
"""
"""
name
=
"weibull"
name
=
"weibull"
ndim_supp
=
0
signature
=
"()->()"
ndims_params
=
[
0
]
dtype
=
"floatX"
dtype
=
"floatX"
_print_name
=
(
"Weibull"
,
"
\\
operatorname{Weibull}"
)
_print_name
=
(
"Weibull"
,
"
\\
operatorname{Weibull}"
)
...
@@ -769,8 +757,7 @@ class LogisticRV(RandomVariable):
...
@@ -769,8 +757,7 @@ class LogisticRV(RandomVariable):
"""
"""
name
=
"logistic"
name
=
"logistic"
ndim_supp
=
0
signature
=
"(),()->()"
ndims_params
=
[
0
,
0
]
dtype
=
"floatX"
dtype
=
"floatX"
_print_name
=
(
"Logistic"
,
"
\\
operatorname{Logistic}"
)
_print_name
=
(
"Logistic"
,
"
\\
operatorname{Logistic}"
)
...
@@ -818,8 +805,7 @@ class VonMisesRV(RandomVariable):
...
@@ -818,8 +805,7 @@ class VonMisesRV(RandomVariable):
"""
"""
name
=
"vonmises"
name
=
"vonmises"
ndim_supp
=
0
signature
=
"(),()->()"
ndims_params
=
[
0
,
0
]
dtype
=
"floatX"
dtype
=
"floatX"
_print_name
=
(
"VonMises"
,
"
\\
operatorname{VonMises}"
)
_print_name
=
(
"VonMises"
,
"
\\
operatorname{VonMises}"
)
...
@@ -886,19 +872,10 @@ class MvNormalRV(RandomVariable):
...
@@ -886,19 +872,10 @@ class MvNormalRV(RandomVariable):
"""
"""
name
=
"multivariate_normal"
name
=
"multivariate_normal"
ndim_supp
=
1
signature
=
"(n),(n,n)->(n)"
ndims_params
=
[
1
,
2
]
dtype
=
"floatX"
dtype
=
"floatX"
_print_name
=
(
"MultivariateNormal"
,
"
\\
operatorname{MultivariateNormal}"
)
_print_name
=
(
"MultivariateNormal"
,
"
\\
operatorname{MultivariateNormal}"
)
def
_supp_shape_from_params
(
self
,
dist_params
,
param_shapes
=
None
):
return
supp_shape_from_ref_param_shape
(
ndim_supp
=
self
.
ndim_supp
,
dist_params
=
dist_params
,
param_shapes
=
param_shapes
,
ref_param_idx
=
0
,
)
def
__call__
(
self
,
mean
=
None
,
cov
=
None
,
size
=
None
,
**
kwargs
):
def
__call__
(
self
,
mean
=
None
,
cov
=
None
,
size
=
None
,
**
kwargs
):
r""" "Draw samples from a multivariate normal distribution.
r""" "Draw samples from a multivariate normal distribution.
...
@@ -942,7 +919,7 @@ class MvNormalRV(RandomVariable):
...
@@ -942,7 +919,7 @@ class MvNormalRV(RandomVariable):
mean
=
np
.
broadcast_to
(
mean
,
size
+
mean
.
shape
[
-
1
:])
mean
=
np
.
broadcast_to
(
mean
,
size
+
mean
.
shape
[
-
1
:])
cov
=
np
.
broadcast_to
(
cov
,
size
+
cov
.
shape
[
-
2
:])
cov
=
np
.
broadcast_to
(
cov
,
size
+
cov
.
shape
[
-
2
:])
else
:
else
:
mean
,
cov
=
broadcast_params
([
mean
,
cov
],
cls
.
ndims_params
)
mean
,
cov
=
broadcast_params
([
mean
,
cov
],
[
1
,
2
]
)
res
=
np
.
empty
(
mean
.
shape
)
res
=
np
.
empty
(
mean
.
shape
)
for
idx
in
np
.
ndindex
(
mean
.
shape
[:
-
1
]):
for
idx
in
np
.
ndindex
(
mean
.
shape
[:
-
1
]):
...
@@ -973,19 +950,10 @@ class DirichletRV(RandomVariable):
...
@@ -973,19 +950,10 @@ class DirichletRV(RandomVariable):
"""
"""
name
=
"dirichlet"
name
=
"dirichlet"
ndim_supp
=
1
signature
=
"(a)->(a)"
ndims_params
=
[
1
]
dtype
=
"floatX"
dtype
=
"floatX"
_print_name
=
(
"Dirichlet"
,
"
\\
operatorname{Dirichlet}"
)
_print_name
=
(
"Dirichlet"
,
"
\\
operatorname{Dirichlet}"
)
def
_supp_shape_from_params
(
self
,
dist_params
,
param_shapes
=
None
):
return
supp_shape_from_ref_param_shape
(
ndim_supp
=
self
.
ndim_supp
,
dist_params
=
dist_params
,
param_shapes
=
param_shapes
,
ref_param_idx
=
0
,
)
def
__call__
(
self
,
alphas
,
size
=
None
,
**
kwargs
):
def
__call__
(
self
,
alphas
,
size
=
None
,
**
kwargs
):
r"""Draw samples from a dirichlet distribution.
r"""Draw samples from a dirichlet distribution.
...
@@ -1047,8 +1015,7 @@ class PoissonRV(RandomVariable):
...
@@ -1047,8 +1015,7 @@ class PoissonRV(RandomVariable):
"""
"""
name
=
"poisson"
name
=
"poisson"
ndim_supp
=
0
signature
=
"()->()"
ndims_params
=
[
0
]
dtype
=
"int64"
dtype
=
"int64"
_print_name
=
(
"Poisson"
,
"
\\
operatorname{Poisson}"
)
_print_name
=
(
"Poisson"
,
"
\\
operatorname{Poisson}"
)
...
@@ -1093,8 +1060,7 @@ class GeometricRV(RandomVariable):
...
@@ -1093,8 +1060,7 @@ class GeometricRV(RandomVariable):
"""
"""
name
=
"geometric"
name
=
"geometric"
ndim_supp
=
0
signature
=
"()->()"
ndims_params
=
[
0
]
dtype
=
"int64"
dtype
=
"int64"
_print_name
=
(
"Geometric"
,
"
\\
operatorname{Geometric}"
)
_print_name
=
(
"Geometric"
,
"
\\
operatorname{Geometric}"
)
...
@@ -1136,8 +1102,7 @@ class HyperGeometricRV(RandomVariable):
...
@@ -1136,8 +1102,7 @@ class HyperGeometricRV(RandomVariable):
"""
"""
name
=
"hypergeometric"
name
=
"hypergeometric"
ndim_supp
=
0
signature
=
"(),(),()->()"
ndims_params
=
[
0
,
0
,
0
]
dtype
=
"int64"
dtype
=
"int64"
_print_name
=
(
"HyperGeometric"
,
"
\\
operatorname{HyperGeometric}"
)
_print_name
=
(
"HyperGeometric"
,
"
\\
operatorname{HyperGeometric}"
)
...
@@ -1185,8 +1150,7 @@ class CauchyRV(ScipyRandomVariable):
...
@@ -1185,8 +1150,7 @@ class CauchyRV(ScipyRandomVariable):
"""
"""
name
=
"cauchy"
name
=
"cauchy"
ndim_supp
=
0
signature
=
"(),()->()"
ndims_params
=
[
0
,
0
]
dtype
=
"floatX"
dtype
=
"floatX"
_print_name
=
(
"Cauchy"
,
"
\\
operatorname{Cauchy}"
)
_print_name
=
(
"Cauchy"
,
"
\\
operatorname{Cauchy}"
)
...
@@ -1236,8 +1200,7 @@ class HalfCauchyRV(ScipyRandomVariable):
...
@@ -1236,8 +1200,7 @@ class HalfCauchyRV(ScipyRandomVariable):
"""
"""
name
=
"halfcauchy"
name
=
"halfcauchy"
ndim_supp
=
0
signature
=
"(),()->()"
ndims_params
=
[
0
,
0
]
dtype
=
"floatX"
dtype
=
"floatX"
_print_name
=
(
"HalfCauchy"
,
"
\\
operatorname{HalfCauchy}"
)
_print_name
=
(
"HalfCauchy"
,
"
\\
operatorname{HalfCauchy}"
)
...
@@ -1291,8 +1254,7 @@ class InvGammaRV(ScipyRandomVariable):
...
@@ -1291,8 +1254,7 @@ class InvGammaRV(ScipyRandomVariable):
"""
"""
name
=
"invgamma"
name
=
"invgamma"
ndim_supp
=
0
signature
=
"(),()->()"
ndims_params
=
[
0
,
0
]
dtype
=
"floatX"
dtype
=
"floatX"
_print_name
=
(
"InverseGamma"
,
"
\\
operatorname{InverseGamma}"
)
_print_name
=
(
"InverseGamma"
,
"
\\
operatorname{InverseGamma}"
)
...
@@ -1342,8 +1304,7 @@ class WaldRV(RandomVariable):
...
@@ -1342,8 +1304,7 @@ class WaldRV(RandomVariable):
"""
"""
name
=
"wald"
name
=
"wald"
ndim_supp
=
0
signature
=
"(),()->()"
ndims_params
=
[
0
,
0
]
dtype
=
"floatX"
dtype
=
"floatX"
_print_name_
=
(
"Wald"
,
"
\\
operatorname{Wald}"
)
_print_name_
=
(
"Wald"
,
"
\\
operatorname{Wald}"
)
...
@@ -1390,8 +1351,7 @@ class TruncExponentialRV(ScipyRandomVariable):
...
@@ -1390,8 +1351,7 @@ class TruncExponentialRV(ScipyRandomVariable):
"""
"""
name
=
"truncexpon"
name
=
"truncexpon"
ndim_supp
=
0
signature
=
"(),(),()->()"
ndims_params
=
[
0
,
0
,
0
]
dtype
=
"floatX"
dtype
=
"floatX"
_print_name
=
(
"TruncatedExponential"
,
"
\\
operatorname{TruncatedExponential}"
)
_print_name
=
(
"TruncatedExponential"
,
"
\\
operatorname{TruncatedExponential}"
)
...
@@ -1446,8 +1406,7 @@ class StudentTRV(ScipyRandomVariable):
...
@@ -1446,8 +1406,7 @@ class StudentTRV(ScipyRandomVariable):
"""
"""
name
=
"t"
name
=
"t"
ndim_supp
=
0
signature
=
"(),(),()->()"
ndims_params
=
[
0
,
0
,
0
]
dtype
=
"floatX"
dtype
=
"floatX"
_print_name
=
(
"StudentT"
,
"
\\
operatorname{StudentT}"
)
_print_name
=
(
"StudentT"
,
"
\\
operatorname{StudentT}"
)
...
@@ -1506,8 +1465,7 @@ class BernoulliRV(ScipyRandomVariable):
...
@@ -1506,8 +1465,7 @@ class BernoulliRV(ScipyRandomVariable):
"""
"""
name
=
"bernoulli"
name
=
"bernoulli"
ndim_supp
=
0
signature
=
"()->()"
ndims_params
=
[
0
]
dtype
=
"int64"
dtype
=
"int64"
_print_name
=
(
"Bernoulli"
,
"
\\
operatorname{Bernoulli}"
)
_print_name
=
(
"Bernoulli"
,
"
\\
operatorname{Bernoulli}"
)
...
@@ -1554,8 +1512,7 @@ class LaplaceRV(RandomVariable):
...
@@ -1554,8 +1512,7 @@ class LaplaceRV(RandomVariable):
"""
"""
name
=
"laplace"
name
=
"laplace"
ndim_supp
=
0
signature
=
"(),()->()"
ndims_params
=
[
0
,
0
]
dtype
=
"floatX"
dtype
=
"floatX"
_print_name
=
(
"Laplace"
,
"
\\
operatorname{Laplace}"
)
_print_name
=
(
"Laplace"
,
"
\\
operatorname{Laplace}"
)
...
@@ -1601,8 +1558,7 @@ class BinomialRV(RandomVariable):
...
@@ -1601,8 +1558,7 @@ class BinomialRV(RandomVariable):
"""
"""
name
=
"binomial"
name
=
"binomial"
ndim_supp
=
0
signature
=
"(),()->()"
ndims_params
=
[
0
,
0
]
dtype
=
"int64"
dtype
=
"int64"
_print_name
=
(
"Binomial"
,
"
\\
operatorname{Binomial}"
)
_print_name
=
(
"Binomial"
,
"
\\
operatorname{Binomial}"
)
...
@@ -1645,9 +1601,8 @@ class NegBinomialRV(ScipyRandomVariable):
...
@@ -1645,9 +1601,8 @@ class NegBinomialRV(ScipyRandomVariable):
"""
"""
name
=
"nbinom"
name
=
"negative_binomial"
ndim_supp
=
0
signature
=
"(),()->()"
ndims_params
=
[
0
,
0
]
dtype
=
"int64"
dtype
=
"int64"
_print_name
=
(
"NegativeBinomial"
,
"
\\
operatorname{NegativeBinomial}"
)
_print_name
=
(
"NegativeBinomial"
,
"
\\
operatorname{NegativeBinomial}"
)
...
@@ -1702,8 +1657,7 @@ class BetaBinomialRV(ScipyRandomVariable):
...
@@ -1702,8 +1657,7 @@ class BetaBinomialRV(ScipyRandomVariable):
"""
"""
name
=
"beta_binomial"
name
=
"beta_binomial"
ndim_supp
=
0
signature
=
"(),(),()->()"
ndims_params
=
[
0
,
0
,
0
]
dtype
=
"int64"
dtype
=
"int64"
_print_name
=
(
"BetaBinomial"
,
"
\\
operatorname{BetaBinomial}"
)
_print_name
=
(
"BetaBinomial"
,
"
\\
operatorname{BetaBinomial}"
)
...
@@ -1754,8 +1708,7 @@ class GenGammaRV(ScipyRandomVariable):
...
@@ -1754,8 +1708,7 @@ class GenGammaRV(ScipyRandomVariable):
"""
"""
name
=
"gengamma"
name
=
"gengamma"
ndim_supp
=
0
signature
=
"(),(),()->()"
ndims_params
=
[
0
,
0
,
0
]
dtype
=
"floatX"
dtype
=
"floatX"
_print_name
=
(
"GeneralizedGamma"
,
"
\\
operatorname{GeneralizedGamma}"
)
_print_name
=
(
"GeneralizedGamma"
,
"
\\
operatorname{GeneralizedGamma}"
)
...
@@ -1817,8 +1770,7 @@ class MultinomialRV(RandomVariable):
...
@@ -1817,8 +1770,7 @@ class MultinomialRV(RandomVariable):
"""
"""
name
=
"multinomial"
name
=
"multinomial"
ndim_supp
=
1
signature
=
"(),(p)->(p)"
ndims_params
=
[
0
,
1
]
dtype
=
"int64"
dtype
=
"int64"
_print_name
=
(
"Multinomial"
,
"
\\
operatorname{Multinomial}"
)
_print_name
=
(
"Multinomial"
,
"
\\
operatorname{Multinomial}"
)
...
@@ -1845,14 +1797,6 @@ class MultinomialRV(RandomVariable):
...
@@ -1845,14 +1797,6 @@ class MultinomialRV(RandomVariable):
"""
"""
return
super
()
.
__call__
(
n
,
p
,
size
=
size
,
**
kwargs
)
return
super
()
.
__call__
(
n
,
p
,
size
=
size
,
**
kwargs
)
def
_supp_shape_from_params
(
self
,
dist_params
,
param_shapes
=
None
):
return
supp_shape_from_ref_param_shape
(
ndim_supp
=
self
.
ndim_supp
,
dist_params
=
dist_params
,
param_shapes
=
param_shapes
,
ref_param_idx
=
1
,
)
@classmethod
@classmethod
def
rng_fn
(
cls
,
rng
,
n
,
p
,
size
):
def
rng_fn
(
cls
,
rng
,
n
,
p
,
size
):
if
n
.
ndim
>
0
or
p
.
ndim
>
1
:
if
n
.
ndim
>
0
or
p
.
ndim
>
1
:
...
@@ -1862,7 +1806,7 @@ class MultinomialRV(RandomVariable):
...
@@ -1862,7 +1806,7 @@ class MultinomialRV(RandomVariable):
n
=
np
.
broadcast_to
(
n
,
size
)
n
=
np
.
broadcast_to
(
n
,
size
)
p
=
np
.
broadcast_to
(
p
,
size
+
p
.
shape
[
-
1
:])
p
=
np
.
broadcast_to
(
p
,
size
+
p
.
shape
[
-
1
:])
else
:
else
:
n
,
p
=
broadcast_params
([
n
,
p
],
cls
.
ndims_params
)
n
,
p
=
broadcast_params
([
n
,
p
],
[
0
,
1
]
)
res
=
np
.
empty
(
p
.
shape
,
dtype
=
cls
.
dtype
)
res
=
np
.
empty
(
p
.
shape
,
dtype
=
cls
.
dtype
)
for
idx
in
np
.
ndindex
(
p
.
shape
[:
-
1
]):
for
idx
in
np
.
ndindex
(
p
.
shape
[:
-
1
]):
...
@@ -1892,8 +1836,7 @@ class CategoricalRV(RandomVariable):
...
@@ -1892,8 +1836,7 @@ class CategoricalRV(RandomVariable):
"""
"""
name
=
"categorical"
name
=
"categorical"
ndim_supp
=
0
signature
=
"(p)->()"
ndims_params
=
[
1
]
dtype
=
"int64"
dtype
=
"int64"
_print_name
=
(
"Categorical"
,
"
\\
operatorname{Categorical}"
)
_print_name
=
(
"Categorical"
,
"
\\
operatorname{Categorical}"
)
...
@@ -1948,8 +1891,7 @@ class RandIntRV(RandomVariable):
...
@@ -1948,8 +1891,7 @@ class RandIntRV(RandomVariable):
"""
"""
name
=
"randint"
name
=
"randint"
ndim_supp
=
0
signature
=
"(),()->()"
ndims_params
=
[
0
,
0
]
dtype
=
"int64"
dtype
=
"int64"
_print_name
=
(
"randint"
,
"
\\
operatorname{randint}"
)
_print_name
=
(
"randint"
,
"
\\
operatorname{randint}"
)
...
@@ -2001,8 +1943,7 @@ class IntegersRV(RandomVariable):
...
@@ -2001,8 +1943,7 @@ class IntegersRV(RandomVariable):
"""
"""
name
=
"integers"
name
=
"integers"
ndim_supp
=
0
signature
=
"(),()->()"
ndims_params
=
[
0
,
0
]
dtype
=
"int64"
dtype
=
"int64"
_print_name
=
(
"integers"
,
"
\\
operatorname{integers}"
)
_print_name
=
(
"integers"
,
"
\\
operatorname{integers}"
)
...
@@ -2174,17 +2115,23 @@ def choice(a, size=None, replace=True, p=None, rng=None):
...
@@ -2174,17 +2115,23 @@ def choice(a, size=None, replace=True, p=None, rng=None):
a_ndim
=
a
.
type
.
ndim
a_ndim
=
a
.
type
.
ndim
dtype
=
a
.
type
.
dtype
dtype
=
a
.
type
.
dtype
a_dims
=
[
f
"a{i}"
for
i
in
range
(
a_ndim
)]
a_sig
=
","
.
join
(
a_dims
)
idx_dims
=
[
f
"s{i}"
for
i
in
range
(
core_shape_length
)]
if
a_ndim
==
0
:
p_sig
=
"a"
out_dims
=
idx_dims
else
:
p_sig
=
a_dims
[
0
]
out_dims
=
idx_dims
+
a_dims
[
1
:]
out_sig
=
","
.
join
(
out_dims
)
if
p
is
None
:
if
p
is
None
:
ndims_params
=
[
a_ndim
,
1
]
signature
=
f
"({a_sig}),({core_shape_length})->({out_sig})"
else
:
else
:
ndims_params
=
[
a_ndim
,
1
,
1
]
signature
=
f
"({a_sig}),({p_sig}),({core_shape_length})->({out_sig})"
ndim_supp
=
max
(
a_ndim
-
1
,
0
)
+
core_shape_length
op
=
ChoiceWithoutReplacement
(
op
=
ChoiceWithoutReplacement
(
signature
=
signature
,
dtype
=
dtype
)
ndim_supp
=
ndim_supp
,
ndims_params
=
ndims_params
,
dtype
=
dtype
,
)
params
=
(
a
,
core_shape
)
if
p
is
None
else
(
a
,
p
,
core_shape
)
params
=
(
a
,
core_shape
)
if
p
is
None
else
(
a
,
p
,
core_shape
)
return
op
(
*
params
,
size
=
None
,
rng
=
rng
)
return
op
(
*
params
,
size
=
None
,
rng
=
rng
)
...
@@ -2247,10 +2194,12 @@ def permutation(x, **kwargs):
...
@@ -2247,10 +2194,12 @@ def permutation(x, **kwargs):
x_dtype
=
x
.
type
.
dtype
x_dtype
=
x
.
type
.
dtype
# PermutationRV has a signature () -> (x) if x is a scalar
# PermutationRV has a signature () -> (x) if x is a scalar
# and (*x) -> (*x) otherwise, with has many entries as the dimensionsality of x
# and (*x) -> (*x) otherwise, with has many entries as the dimensionsality of x
ndim_supp
=
max
(
x_ndim
,
1
)
if
x_ndim
==
0
:
return
PermutationRV
(
ndim_supp
=
ndim_supp
,
ndims_params
=
[
x_ndim
],
dtype
=
x_dtype
)(
signature
=
"()->(x)"
x
,
**
kwargs
else
:
)
arg_sig
=
","
.
join
(
f
"x{i}"
for
i
in
range
(
x_ndim
))
signature
=
f
"({arg_sig})->({arg_sig})"
return
PermutationRV
(
signature
=
signature
,
dtype
=
x_dtype
)(
x
,
**
kwargs
)
__all__
=
[
__all__
=
[
...
...
pytensor/tensor/random/op.py
浏览文件 @
3e9c6a4f
import
warnings
from
collections.abc
import
Sequence
from
collections.abc
import
Sequence
from
copy
import
copy
from
copy
import
copy
from
typing
import
cast
from
typing
import
cast
...
@@ -28,6 +29,7 @@ from pytensor.tensor.random.utils import (
...
@@ -28,6 +29,7 @@ from pytensor.tensor.random.utils import (
from
pytensor.tensor.shape
import
shape_tuple
from
pytensor.tensor.shape
import
shape_tuple
from
pytensor.tensor.type
import
TensorType
,
all_dtypes
from
pytensor.tensor.type
import
TensorType
,
all_dtypes
from
pytensor.tensor.type_other
import
NoneConst
from
pytensor.tensor.type_other
import
NoneConst
from
pytensor.tensor.utils
import
_parse_gufunc_signature
,
safe_signature
from
pytensor.tensor.variable
import
TensorVariable
from
pytensor.tensor.variable
import
TensorVariable
...
@@ -42,7 +44,7 @@ class RandomVariable(Op):
...
@@ -42,7 +44,7 @@ class RandomVariable(Op):
_output_type_depends_on_input_value
=
True
_output_type_depends_on_input_value
=
True
__props__
=
(
"name"
,
"
ndim_supp"
,
"ndims_params
"
,
"dtype"
,
"inplace"
)
__props__
=
(
"name"
,
"
signature
"
,
"dtype"
,
"inplace"
)
default_output
=
1
default_output
=
1
def
__init__
(
def
__init__
(
...
@@ -50,8 +52,9 @@ class RandomVariable(Op):
...
@@ -50,8 +52,9 @@ class RandomVariable(Op):
name
=
None
,
name
=
None
,
ndim_supp
=
None
,
ndim_supp
=
None
,
ndims_params
=
None
,
ndims_params
=
None
,
dtype
=
None
,
dtype
:
str
|
None
=
None
,
inplace
=
None
,
inplace
=
None
,
signature
:
str
|
None
=
None
,
):
):
"""Create a random variable `Op`.
"""Create a random variable `Op`.
...
@@ -59,44 +62,63 @@ class RandomVariable(Op):
...
@@ -59,44 +62,63 @@ class RandomVariable(Op):
----------
----------
name: str
name: str
The `Op`'s display name.
The `Op`'s display name.
ndim_supp: int
signature: str
Total number of dimensions for a single draw of the random variable
Numpy-like vectorized signature of the random variable.
(e.g. a multivariate normal draw is 1D, so ``ndim_supp = 1``).
ndims_params: list of int
Number of dimensions for each distribution parameter when the
parameters only specify a single drawn of the random variable
(e.g. a multivariate normal's mean is 1D and covariance is 2D, so
``ndims_params = [1, 2]``).
dtype: str (optional)
dtype: str (optional)
The dtype of the sampled output. If the value ``"floatX"`` is
The dtype of the sampled output. If the value ``"floatX"`` is
given, then ``dtype`` is set to ``pytensor.config.floatX``. If
given, then ``dtype`` is set to ``pytensor.config.floatX``. If
``None`` (the default), the `dtype` keyword must be set when
``None`` (the default), the `dtype` keyword must be set when
`RandomVariable.make_node` is called.
`RandomVariable.make_node` is called.
inplace: boolean (optional)
inplace: boolean (optional)
Determine whether or not the underlying rng state is updated
Determine whether the underlying rng state is mutated or copied.
in-place or not (i.e. copied).
"""
"""
super
()
.
__init__
()
super
()
.
__init__
()
self
.
name
=
name
or
getattr
(
self
,
"name"
)
self
.
name
=
name
or
getattr
(
self
,
"name"
)
self
.
ndim_supp
=
(
ndim_supp
if
ndim_supp
is
not
None
else
getattr
(
self
,
"ndim_supp"
)
ndim_supp
=
(
ndim_supp
if
ndim_supp
is
not
None
else
getattr
(
self
,
"ndim_supp"
,
None
)
)
)
self
.
ndims_params
=
(
if
ndim_supp
is
not
None
:
ndims_params
if
ndims_params
is
not
None
else
getattr
(
self
,
"ndims_params"
)
warnings
.
warn
(
"ndim_supp is deprecated. Provide signature instead."
,
FutureWarning
)
self
.
ndim_supp
=
ndim_supp
ndims_params
=
(
ndims_params
if
ndims_params
is
not
None
else
getattr
(
self
,
"ndims_params"
,
None
)
)
)
if
ndims_params
is
not
None
:
warnings
.
warn
(
"ndims_params is deprecated. Provide signature instead."
,
FutureWarning
)
if
not
isinstance
(
ndims_params
,
Sequence
):
raise
TypeError
(
"Parameter ndims_params must be sequence type."
)
self
.
ndims_params
=
tuple
(
ndims_params
)
self
.
signature
=
signature
or
getattr
(
self
,
"signature"
,
None
)
if
self
.
signature
is
not
None
:
# Assume a single output. Several methods need to be updated to handle multiple outputs.
self
.
inputs_sig
,
[
self
.
output_sig
]
=
_parse_gufunc_signature
(
self
.
signature
)
self
.
ndims_params
=
[
len
(
input_sig
)
for
input_sig
in
self
.
inputs_sig
]
self
.
ndim_supp
=
len
(
self
.
output_sig
)
else
:
if
(
getattr
(
self
,
"ndim_supp"
,
None
)
is
None
or
getattr
(
self
,
"ndims_params"
,
None
)
is
None
):
raise
ValueError
(
"signature must be provided"
)
else
:
self
.
signature
=
safe_signature
(
self
.
ndims_params
,
[
self
.
ndim_supp
])
self
.
dtype
=
dtype
or
getattr
(
self
,
"dtype"
,
None
)
self
.
dtype
=
dtype
or
getattr
(
self
,
"dtype"
,
None
)
self
.
inplace
=
(
self
.
inplace
=
(
inplace
if
inplace
is
not
None
else
getattr
(
self
,
"inplace"
,
False
)
inplace
if
inplace
is
not
None
else
getattr
(
self
,
"inplace"
,
False
)
)
)
if
not
isinstance
(
self
.
ndims_params
,
Sequence
):
raise
TypeError
(
"Parameter ndims_params must be sequence type."
)
self
.
ndims_params
=
tuple
(
self
.
ndims_params
)
if
self
.
inplace
:
if
self
.
inplace
:
self
.
destroy_map
=
{
0
:
[
0
]}
self
.
destroy_map
=
{
0
:
[
0
]}
...
@@ -120,8 +142,31 @@ class RandomVariable(Op):
...
@@ -120,8 +142,31 @@ class RandomVariable(Op):
values (not shapes) of some parameters. For instance, a `gaussian_random_walk(steps, size=(2,))`,
values (not shapes) of some parameters. For instance, a `gaussian_random_walk(steps, size=(2,))`,
might have `support_shape=(steps,)`.
might have `support_shape=(steps,)`.
"""
"""
if
self
.
signature
is
not
None
:
# Signature could indicate fixed numerical shapes
# As per https://numpy.org/neps/nep-0020-gufunc-signature-enhancement.html
output_sig
=
self
.
output_sig
core_out_shape
=
{
dim
:
int
(
dim
)
if
str
.
isnumeric
(
dim
)
else
None
for
dim
in
self
.
output_sig
}
# Try to infer missing support dims from signature of params
for
param
,
param_sig
,
ndim_params
in
zip
(
dist_params
,
self
.
inputs_sig
,
self
.
ndims_params
):
if
ndim_params
==
0
:
continue
for
param_dim
,
dim
in
zip
(
param
.
shape
[
-
ndim_params
:],
param_sig
):
if
dim
in
core_out_shape
and
core_out_shape
[
dim
]
is
None
:
core_out_shape
[
dim
]
=
param_dim
if
all
(
dim
is
not
None
for
dim
in
core_out_shape
.
values
()):
# We have all we need
return
[
core_out_shape
[
dim
]
for
dim
in
output_sig
]
raise
NotImplementedError
(
raise
NotImplementedError
(
"`_supp_shape_from_params` must be implemented for multivariate RVs"
"`_supp_shape_from_params` must be implemented for multivariate RVs "
"when signature is not sufficient to infer the support shape"
)
)
def
rng_fn
(
self
,
rng
,
*
args
,
**
kwargs
)
->
int
|
float
|
np
.
ndarray
:
def
rng_fn
(
self
,
rng
,
*
args
,
**
kwargs
)
->
int
|
float
|
np
.
ndarray
:
...
@@ -129,7 +174,24 @@ class RandomVariable(Op):
...
@@ -129,7 +174,24 @@ class RandomVariable(Op):
return
getattr
(
rng
,
self
.
name
)(
*
args
,
**
kwargs
)
return
getattr
(
rng
,
self
.
name
)(
*
args
,
**
kwargs
)
def
__str__
(
self
):
def
__str__
(
self
):
props_str
=
", "
.
join
(
f
"{getattr(self, prop)}"
for
prop
in
self
.
__props__
[
1
:])
# Only show signature from core props
if
signature
:
=
self
.
signature
:
# inp, out = signature.split("->")
# extended_signature = f"[rng],[size],{inp}->[rng],{out}"
# core_props = [extended_signature]
core_props
=
[
f
'"{signature}"'
]
else
:
# Far back compat
core_props
=
[
str
(
self
.
ndim_supp
),
str
(
self
.
ndims_params
)]
# Add any extra props that the subclass may have
extra_props
=
[
str
(
getattr
(
self
,
prop
))
for
prop
in
self
.
__props__
if
prop
not
in
RandomVariable
.
__props__
]
props_str
=
", "
.
join
(
core_props
+
extra_props
)
return
f
"{self.name}_rv{{{props_str}}}"
return
f
"{self.name}_rv{{{props_str}}}"
def
_infer_shape
(
def
_infer_shape
(
...
@@ -298,11 +360,11 @@ class RandomVariable(Op):
...
@@ -298,11 +360,11 @@ class RandomVariable(Op):
dtype_idx
=
constant
(
all_dtypes
.
index
(
dtype
),
dtype
=
"int64"
)
dtype_idx
=
constant
(
all_dtypes
.
index
(
dtype
),
dtype
=
"int64"
)
else
:
else
:
dtype_idx
=
constant
(
dtype
,
dtype
=
"int64"
)
dtype_idx
=
constant
(
dtype
,
dtype
=
"int64"
)
dtype
=
all_dtypes
[
dtype_idx
.
data
]
outtype
=
TensorType
(
dtype
=
dtype
,
shape
=
static_shape
)
dtype
=
all_dtypes
[
dtype_idx
.
data
]
out_var
=
outtype
()
inputs
=
(
rng
,
size
,
dtype_idx
,
*
dist_params
)
inputs
=
(
rng
,
size
,
dtype_idx
,
*
dist_params
)
out_var
=
TensorType
(
dtype
=
dtype
,
shape
=
static_shape
)()
outputs
=
(
rng
.
type
(),
out_var
)
outputs
=
(
rng
.
type
(),
out_var
)
return
Apply
(
self
,
inputs
,
outputs
)
return
Apply
(
self
,
inputs
,
outputs
)
...
@@ -395,9 +457,8 @@ def vectorize_random_variable(
...
@@ -395,9 +457,8 @@ def vectorize_random_variable(
# We extend it to accommodate the new input batch dimensions.
# We extend it to accommodate the new input batch dimensions.
# Otherwise, we assume the new size already has the right values
# Otherwise, we assume the new size already has the right values
# Need to make parameters implicit broadcasting explicit
original_dist_params
=
op
.
dist_params
(
node
)
original_dist_params
=
node
.
inputs
[
3
:]
old_size
=
op
.
size_param
(
node
)
old_size
=
node
.
inputs
[
1
]
len_old_size
=
get_vector_length
(
old_size
)
len_old_size
=
get_vector_length
(
old_size
)
original_expanded_dist_params
=
explicit_expand_dims
(
original_expanded_dist_params
=
explicit_expand_dims
(
...
...
pytensor/tensor/random/rewriting/jax.py
浏览文件 @
3e9c6a4f
import
re
from
pytensor.compile
import
optdb
from
pytensor.compile
import
optdb
from
pytensor.graph.rewriting.basic
import
in2out
,
node_rewriter
from
pytensor.graph.rewriting.basic
import
in2out
,
node_rewriter
from
pytensor.graph.rewriting.db
import
SequenceDB
from
pytensor.graph.rewriting.db
import
SequenceDB
...
@@ -164,9 +166,9 @@ def materialize_implicit_arange_choice_without_replacement(fgraph, node):
...
@@ -164,9 +166,9 @@ def materialize_implicit_arange_choice_without_replacement(fgraph, node):
a_vector_param
=
arange
(
a_scalar_param
)
a_vector_param
=
arange
(
a_scalar_param
)
new_props_dict
=
op
.
_props_dict
()
.
copy
()
new_props_dict
=
op
.
_props_dict
()
.
copy
()
new_ndims_params
=
list
(
op
.
ndims_params
)
# Signature changes from something like "(),(a),(2)->(s0, s1)" to "(a),(a),(2)->(s0, s1)"
new_ndims_params
[
0
]
+=
1
# I.e., we substitute the first `()` by `(a)`
new_props_dict
[
"
ndims_params"
]
=
new_ndims_params
new_props_dict
[
"
signature"
]
=
re
.
sub
(
r"\(\)"
,
"(a)"
,
op
.
signature
,
1
)
new_op
=
type
(
op
)(
**
new_props_dict
)
new_op
=
type
(
op
)(
**
new_props_dict
)
return
new_op
.
make_node
(
rng
,
size
,
dtype
,
a_vector_param
,
*
other_params
)
.
outputs
return
new_op
.
make_node
(
rng
,
size
,
dtype
,
a_vector_param
,
*
other_params
)
.
outputs
...
...
pytensor/tensor/random/utils.py
浏览文件 @
3e9c6a4f
...
@@ -123,7 +123,7 @@ def broadcast_params(params, ndims_params):
...
@@ -123,7 +123,7 @@ def broadcast_params(params, ndims_params):
def
explicit_expand_dims
(
def
explicit_expand_dims
(
params
:
Sequence
[
TensorVariable
],
params
:
Sequence
[
TensorVariable
],
ndim_params
:
tupl
e
[
int
],
ndim_params
:
Sequenc
e
[
int
],
size_length
:
int
=
0
,
size_length
:
int
=
0
,
)
->
list
[
TensorVariable
]:
)
->
list
[
TensorVariable
]:
"""Introduce explicit expand_dims in RV parameters that are implicitly broadcasted together and/or by size."""
"""Introduce explicit expand_dims in RV parameters that are implicitly broadcasted together and/or by size."""
...
@@ -137,7 +137,7 @@ def explicit_expand_dims(
...
@@ -137,7 +137,7 @@ def explicit_expand_dims(
# See: https://github.com/pymc-devs/pytensor/issues/568
# See: https://github.com/pymc-devs/pytensor/issues/568
max_batch_dims
=
size_length
max_batch_dims
=
size_length
else
:
else
:
max_batch_dims
=
max
(
batch_dims
)
max_batch_dims
=
max
(
batch_dims
,
default
=
0
)
new_params
=
[]
new_params
=
[]
for
new_param
,
batch_dim
in
zip
(
params
,
batch_dims
):
for
new_param
,
batch_dim
in
zip
(
params
,
batch_dims
):
...
@@ -354,6 +354,11 @@ def supp_shape_from_ref_param_shape(
...
@@ -354,6 +354,11 @@ def supp_shape_from_ref_param_shape(
out: tuple
out: tuple
Representing the support shape for a `RandomVariable` with the given `dist_params`.
Representing the support shape for a `RandomVariable` with the given `dist_params`.
Notes
_____
This helper is no longer necessary when using signatures in `RandomVariable` subclasses.
"""
"""
if
ndim_supp
<=
0
:
if
ndim_supp
<=
0
:
raise
ValueError
(
"ndim_supp must be greater than 0"
)
raise
ValueError
(
"ndim_supp must be greater than 0"
)
...
...
pytensor/tensor/utils.py
浏览文件 @
3e9c6a4f
...
@@ -169,7 +169,8 @@ _DIMENSION_NAME = r"\w+"
...
@@ -169,7 +169,8 @@ _DIMENSION_NAME = r"\w+"
_CORE_DIMENSION_LIST
=
f
"(?:{_DIMENSION_NAME}(?:,{_DIMENSION_NAME})*)?"
_CORE_DIMENSION_LIST
=
f
"(?:{_DIMENSION_NAME}(?:,{_DIMENSION_NAME})*)?"
_ARGUMENT
=
rf
"
\
({_CORE_DIMENSION_LIST}
\
)"
_ARGUMENT
=
rf
"
\
({_CORE_DIMENSION_LIST}
\
)"
_ARGUMENT_LIST
=
f
"{_ARGUMENT}(?:,{_ARGUMENT})*"
_ARGUMENT_LIST
=
f
"{_ARGUMENT}(?:,{_ARGUMENT})*"
_SIGNATURE
=
f
"^{_ARGUMENT_LIST}->{_ARGUMENT_LIST}$"
# Allow no inputs
_SIGNATURE
=
f
"^(?:{_ARGUMENT_LIST})?->{_ARGUMENT_LIST}$"
def
_parse_gufunc_signature
(
def
_parse_gufunc_signature
(
...
@@ -200,6 +201,8 @@ def _parse_gufunc_signature(
...
@@ -200,6 +201,8 @@ def _parse_gufunc_signature(
tuple
(
re
.
findall
(
_DIMENSION_NAME
,
arg
))
tuple
(
re
.
findall
(
_DIMENSION_NAME
,
arg
))
for
arg
in
re
.
findall
(
_ARGUMENT
,
arg_list
)
for
arg
in
re
.
findall
(
_ARGUMENT
,
arg_list
)
]
]
if
arg_list
# ignore no inputs
else
[]
for
arg_list
in
signature
.
split
(
"->"
)
for
arg_list
in
signature
.
split
(
"->"
)
)
)
...
...
tests/link/jax/test_random.py
浏览文件 @
3e9c6a4f
...
@@ -771,8 +771,7 @@ def test_random_unimplemented():
...
@@ -771,8 +771,7 @@ def test_random_unimplemented():
class
NonExistentRV
(
RandomVariable
):
class
NonExistentRV
(
RandomVariable
):
name
=
"non-existent"
name
=
"non-existent"
ndim_supp
=
0
signature
=
"->()"
ndims_params
=
[]
dtype
=
"floatX"
dtype
=
"floatX"
def
__call__
(
self
,
size
=
None
,
**
kwargs
):
def
__call__
(
self
,
size
=
None
,
**
kwargs
):
...
@@ -798,8 +797,7 @@ def test_random_custom_implementation():
...
@@ -798,8 +797,7 @@ def test_random_custom_implementation():
class
CustomRV
(
RandomVariable
):
class
CustomRV
(
RandomVariable
):
name
=
"non-existent"
name
=
"non-existent"
ndim_supp
=
0
signature
=
"->()"
ndims_params
=
[]
dtype
=
"floatX"
dtype
=
"floatX"
def
__call__
(
self
,
size
=
None
,
**
kwargs
):
def
__call__
(
self
,
size
=
None
,
**
kwargs
):
...
...
tests/tensor/random/rewriting/test_basic.py
浏览文件 @
3e9c6a4f
...
@@ -74,52 +74,28 @@ def apply_local_rewrite_to_rv(
...
@@ -74,52 +74,28 @@ def apply_local_rewrite_to_rv(
return
new_out
,
f_inputs
,
dist_st
,
f_rewritten
return
new_out
,
f_inputs
,
dist_st
,
f_rewritten
def
test_inplace_rewrites
():
class
TestRVExpraProps
(
RandomVariable
):
out
=
normal
(
0
,
1
)
name
=
"test"
out
.
owner
.
inputs
[
0
]
.
default_update
=
out
.
owner
.
outputs
[
0
]
signature
=
"()->()"
__props__
=
(
"name"
,
"signature"
,
"dtype"
,
"inplace"
,
"extra"
)
dtype
=
"floatX"
_print_name
=
(
"TestExtraProps"
,
"
\\
operatorname{TestExtra_props}"
)
assert
out
.
owner
.
op
.
inplace
is
False
def
__init__
(
self
,
extra
,
*
args
,
**
kwargs
):
self
.
extra
=
extra
super
()
.
__init__
(
*
args
,
**
kwargs
)
f
=
function
(
def
rng_fn
(
self
,
rng
,
dtype
,
sigma
,
size
):
[],
return
rng
.
normal
(
scale
=
sigma
,
size
=
size
)
out
,
mode
=
"FAST_RUN"
,
)
(
new_out
,
new_rng
)
=
f
.
maker
.
fgraph
.
outputs
assert
new_out
.
type
==
out
.
type
assert
isinstance
(
new_out
.
owner
.
op
,
type
(
out
.
owner
.
op
))
assert
new_out
.
owner
.
op
.
inplace
is
True
assert
all
(
np
.
array_equal
(
a
.
data
,
b
.
data
)
for
a
,
b
in
zip
(
new_out
.
owner
.
inputs
[
2
:],
out
.
owner
.
inputs
[
2
:])
)
assert
np
.
array_equal
(
new_out
.
owner
.
inputs
[
1
]
.
data
,
[])
def
test_inplace_rewrites_extra_props
():
class
Test
(
RandomVariable
):
name
=
"test"
ndim_supp
=
0
ndims_params
=
[
0
]
__props__
=
(
"name"
,
"ndim_supp"
,
"ndims_params"
,
"dtype"
,
"inplace"
,
"extra"
)
dtype
=
"floatX"
_print_name
=
(
"Test"
,
"
\\
operatorname{Test}"
)
def
__init__
(
self
,
extra
,
*
args
,
**
kwargs
):
self
.
extra
=
extra
super
()
.
__init__
(
*
args
,
**
kwargs
)
def
make_node
(
self
,
rng
,
size
,
dtype
,
sigma
):
return
super
()
.
make_node
(
rng
,
size
,
dtype
,
sigma
)
def
rng_fn
(
self
,
rng
,
sigma
,
size
):
return
rng
.
normal
(
scale
=
sigma
,
size
=
size
)
out
=
Test
(
extra
=
"some value"
)(
1
)
out
.
owner
.
inputs
[
0
]
.
default_update
=
out
.
owner
.
outputs
[
0
]
assert
out
.
owner
.
op
.
inplace
is
False
@pytest.mark.parametrize
(
"rv_op"
,
[
normal
,
TestRVExpraProps
(
extra
=
"some value"
)])
def
test_inplace_rewrites
(
rv_op
):
out
=
rv_op
(
np
.
e
)
node
=
out
.
owner
op
=
node
.
op
node
.
inputs
[
0
]
.
default_update
=
node
.
outputs
[
0
]
assert
op
.
inplace
is
False
f
=
function
(
f
=
function
(
[],
[],
...
@@ -129,9 +105,10 @@ def test_inplace_rewrites_extra_props():
...
@@ -129,9 +105,10 @@ def test_inplace_rewrites_extra_props():
(
new_out
,
new_rng
)
=
f
.
maker
.
fgraph
.
outputs
(
new_out
,
new_rng
)
=
f
.
maker
.
fgraph
.
outputs
assert
new_out
.
type
==
out
.
type
assert
new_out
.
type
==
out
.
type
assert
isinstance
(
new_out
.
owner
.
op
,
type
(
out
.
owner
.
op
))
new_node
=
new_out
.
owner
assert
new_out
.
owner
.
op
.
inplace
is
True
new_op
=
new_node
.
op
assert
new_out
.
owner
.
op
.
extra
==
out
.
owner
.
op
.
extra
assert
isinstance
(
new_op
,
type
(
op
))
assert
new_op
.
_props_dict
()
==
(
op
.
_props_dict
()
|
{
"inplace"
:
True
})
assert
all
(
assert
all
(
np
.
array_equal
(
a
.
data
,
b
.
data
)
np
.
array_equal
(
a
.
data
,
b
.
data
)
for
a
,
b
in
zip
(
new_out
.
owner
.
inputs
[
2
:],
out
.
owner
.
inputs
[
2
:])
for
a
,
b
in
zip
(
new_out
.
owner
.
inputs
[
2
:],
out
.
owner
.
inputs
[
2
:])
...
...
tests/tensor/random/test_basic.py
浏览文件 @
3e9c6a4f
...
@@ -1463,11 +1463,8 @@ def batched_unweighted_choice_without_replacement_tester(
...
@@ -1463,11 +1463,8 @@ def batched_unweighted_choice_without_replacement_tester(
rng
=
shared
(
rng_ctor
())
rng
=
shared
(
rng_ctor
())
# Batched a implicit size
# Batched a implicit size
a_core_ndim
=
2
core_shape_len
=
1
rv_op
=
ChoiceWithoutReplacement
(
rv_op
=
ChoiceWithoutReplacement
(
ndim_supp
=
max
(
a_core_ndim
-
1
,
0
)
+
core_shape_len
,
signature
=
"(a0,a1),(1)->(s0,a1)"
,
ndims_params
=
[
a_core_ndim
,
core_shape_len
],
dtype
=
"int64"
,
dtype
=
"int64"
,
)
)
...
@@ -1483,11 +1480,8 @@ def batched_unweighted_choice_without_replacement_tester(
...
@@ -1483,11 +1480,8 @@ def batched_unweighted_choice_without_replacement_tester(
assert
np
.
all
((
draw
>=
i
*
10
)
&
(
draw
<
(
i
+
1
)
*
10
))
assert
np
.
all
((
draw
>=
i
*
10
)
&
(
draw
<
(
i
+
1
)
*
10
))
# Explicit size broadcasts beyond a
# Explicit size broadcasts beyond a
a_core_ndim
=
2
core_shape_len
=
2
rv_op
=
ChoiceWithoutReplacement
(
rv_op
=
ChoiceWithoutReplacement
(
ndim_supp
=
max
(
a_core_ndim
-
1
,
0
)
+
core_shape_len
,
signature
=
"(a0,a1),(2)->(s0,s1,a1)"
,
ndims_params
=
[
a_core_ndim
,
len
(
core_shape
)],
dtype
=
"int64"
,
dtype
=
"int64"
,
)
)
...
@@ -1515,12 +1509,8 @@ def batched_weighted_choice_without_replacement_tester(
...
@@ -1515,12 +1509,8 @@ def batched_weighted_choice_without_replacement_tester(
"""
"""
rng
=
shared
(
rng_ctor
())
rng
=
shared
(
rng_ctor
())
# 3 ndims params indicates p is passed
a_core_ndim
=
2
core_shape_len
=
1
rv_op
=
ChoiceWithoutReplacement
(
rv_op
=
ChoiceWithoutReplacement
(
ndim_supp
=
max
(
a_core_ndim
-
1
,
0
)
+
core_shape_len
,
signature
=
"(a0,a1),(a0),(1)->(s0,a1)"
,
ndims_params
=
[
a_core_ndim
,
1
,
1
],
dtype
=
"int64"
,
dtype
=
"int64"
,
)
)
...
@@ -1540,11 +1530,8 @@ def batched_weighted_choice_without_replacement_tester(
...
@@ -1540,11 +1530,8 @@ def batched_weighted_choice_without_replacement_tester(
# p and a are batched
# p and a are batched
# Test implicit arange
# Test implicit arange
a_core_ndim
=
0
core_shape_len
=
2
rv_op
=
ChoiceWithoutReplacement
(
rv_op
=
ChoiceWithoutReplacement
(
ndim_supp
=
max
(
a_core_ndim
-
1
,
0
)
+
core_shape_len
,
signature
=
"(),(a),(2)->(s0,s1)"
,
ndims_params
=
[
a_core_ndim
,
1
,
1
],
dtype
=
"int64"
,
dtype
=
"int64"
,
)
)
a
=
6
a
=
6
...
@@ -1566,11 +1553,8 @@ def batched_weighted_choice_without_replacement_tester(
...
@@ -1566,11 +1553,8 @@ def batched_weighted_choice_without_replacement_tester(
assert
set
(
draw
)
==
set
(
range
(
i
,
6
,
2
))
assert
set
(
draw
)
==
set
(
range
(
i
,
6
,
2
))
# Size broadcasts beyond a
# Size broadcasts beyond a
a_core_ndim
=
2
core_shape_len
=
1
rv_op
=
ChoiceWithoutReplacement
(
rv_op
=
ChoiceWithoutReplacement
(
ndim_supp
=
max
(
a_core_ndim
-
1
,
0
)
+
core_shape_len
,
signature
=
"(a0,a1),(a0),(1)->(s0,a1)"
,
ndims_params
=
[
a_core_ndim
,
1
,
1
],
dtype
=
"int64"
,
dtype
=
"int64"
,
)
)
a
=
np
.
arange
(
4
*
5
*
2
)
.
reshape
((
4
,
5
,
2
))
a
=
np
.
arange
(
4
*
5
*
2
)
.
reshape
((
4
,
5
,
2
))
...
...
tests/tensor/random/test_op.py
浏览文件 @
3e9c6a4f
...
@@ -23,14 +23,13 @@ def test_RandomVariable_basics(strict_test_value_flags):
...
@@ -23,14 +23,13 @@ def test_RandomVariable_basics(strict_test_value_flags):
str_res
=
str
(
str_res
=
str
(
RandomVariable
(
RandomVariable
(
"normal"
,
"normal"
,
0
,
signature
=
"(),()->()"
,
[
0
,
0
],
dtype
=
"float32"
,
"float32"
,
inplace
=
False
,
inplace
=
True
,
)
)
)
)
assert
str_res
==
"normal_rv{0, (0, 0), float32, True}"
assert
str_res
==
'normal_rv{"(),()->()"}'
# `ndims_params` should be a `Sequence` type
# `ndims_params` should be a `Sequence` type
with
pytest
.
raises
(
TypeError
,
match
=
"^Parameter ndims_params*"
):
with
pytest
.
raises
(
TypeError
,
match
=
"^Parameter ndims_params*"
):
...
@@ -64,9 +63,7 @@ def test_RandomVariable_basics(strict_test_value_flags):
...
@@ -64,9 +63,7 @@ def test_RandomVariable_basics(strict_test_value_flags):
# Confirm that `inplace` works
# Confirm that `inplace` works
rv
=
RandomVariable
(
rv
=
RandomVariable
(
"normal"
,
"normal"
,
0
,
signature
=
"(),()->()"
,
[
0
,
0
],
"normal"
,
inplace
=
True
,
inplace
=
True
,
)
)
...
@@ -74,7 +71,7 @@ def test_RandomVariable_basics(strict_test_value_flags):
...
@@ -74,7 +71,7 @@ def test_RandomVariable_basics(strict_test_value_flags):
assert
rv
.
destroy_map
==
{
0
:
[
0
]}
assert
rv
.
destroy_map
==
{
0
:
[
0
]}
# A no-params `RandomVariable`
# A no-params `RandomVariable`
rv
=
RandomVariable
(
name
=
"test_rv"
,
ndim_supp
=
0
,
ndims_params
=
()
)
rv
=
RandomVariable
(
name
=
"test_rv"
,
signature
=
"->()"
)
with
pytest
.
raises
(
TypeError
):
with
pytest
.
raises
(
TypeError
):
rv
.
make_node
(
rng
=
1
)
rv
.
make_node
(
rng
=
1
)
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论