Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
7886cf83
提交
7886cf83
authored
6月 20, 2025
作者:
Ricardo Vieira
提交者:
Ricardo Vieira
6月 21, 2025
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Implement XTensorVariable version of RandomVariables
上级
33d04c36
全部展开
显示空白字符变更
内嵌
并排
正在显示
9 个修改的文件
包含
415 行增加
和
15 行删除
+415
-15
basic.py
pytensor/tensor/random/basic.py
+2
-2
op.py
pytensor/tensor/random/op.py
+18
-1
basic.py
pytensor/tensor/rewriting/basic.py
+1
-3
utils.py
pytensor/tensor/utils.py
+29
-0
random.py
pytensor/xtensor/random.py
+168
-0
vectorization.py
pytensor/xtensor/rewriting/vectorization.py
+48
-1
shape.py
pytensor/xtensor/shape.py
+4
-8
vectorization.py
pytensor/xtensor/vectorization.py
+145
-0
test_random.py
tests/xtensor/test_random.py
+0
-0
没有找到文件。
pytensor/tensor/random/basic.py
浏览文件 @
7886cf83
...
@@ -1625,8 +1625,7 @@ class NegBinomialRV(ScipyRandomVariable):
...
@@ -1625,8 +1625,7 @@ class NegBinomialRV(ScipyRandomVariable):
return
stats
.
nbinom
.
rvs
(
n
,
p
,
size
=
size
,
random_state
=
rng
)
return
stats
.
nbinom
.
rvs
(
n
,
p
,
size
=
size
,
random_state
=
rng
)
nbinom
=
NegBinomialRV
()
nbinom
=
negative_binomial
=
NegBinomialRV
()
negative_binomial
=
NegBinomialRV
()
class
BetaBinomialRV
(
ScipyRandomVariable
):
class
BetaBinomialRV
(
ScipyRandomVariable
):
...
@@ -1808,6 +1807,7 @@ class MultinomialRV(RandomVariable):
...
@@ -1808,6 +1807,7 @@ class MultinomialRV(RandomVariable):
multinomial
=
MultinomialRV
()
multinomial
=
MultinomialRV
()
vsearchsorted
=
np
.
vectorize
(
np
.
searchsorted
,
otypes
=
[
int
],
signature
=
"(n),()->()"
)
vsearchsorted
=
np
.
vectorize
(
np
.
searchsorted
,
otypes
=
[
int
],
signature
=
"(n),()->()"
)
...
...
pytensor/tensor/random/op.py
浏览文件 @
7886cf83
import
abc
import
warnings
import
warnings
from
collections.abc
import
Sequence
from
collections.abc
import
Sequence
from
copy
import
deepcopy
from
copy
import
deepcopy
...
@@ -32,7 +33,20 @@ from pytensor.tensor.utils import _parse_gufunc_signature, safe_signature
...
@@ -32,7 +33,20 @@ from pytensor.tensor.utils import _parse_gufunc_signature, safe_signature
from
pytensor.tensor.variable
import
TensorVariable
from
pytensor.tensor.variable
import
TensorVariable
class
RandomVariable
(
Op
):
class
RNGConsumerOp
(
Op
):
"""Baseclass for Ops that consume RNGs."""
@abc.abstractmethod
def
update
(
self
,
node
:
Apply
)
->
dict
[
Variable
,
Variable
]:
"""Symbolic update expression for input RNG variables.
Returns a dictionary with the symbolic expressions required for correct updating
of RNG variables in repeated function evaluations.
"""
pass
class
RandomVariable
(
RNGConsumerOp
):
"""An `Op` that produces a sample from a random variable.
"""An `Op` that produces a sample from a random variable.
This is essentially `RandomFunction`, except that it removes the
This is essentially `RandomFunction`, except that it removes the
...
@@ -123,6 +137,9 @@ class RandomVariable(Op):
...
@@ -123,6 +137,9 @@ class RandomVariable(Op):
if
self
.
inplace
:
if
self
.
inplace
:
self
.
destroy_map
=
{
0
:
[
0
]}
self
.
destroy_map
=
{
0
:
[
0
]}
def
update
(
self
,
node
:
Apply
)
->
dict
[
Variable
,
Variable
]:
return
{
node
.
inputs
[
0
]:
node
.
outputs
[
0
]}
def
_supp_shape_from_params
(
self
,
dist_params
,
param_shapes
=
None
):
def
_supp_shape_from_params
(
self
,
dist_params
,
param_shapes
=
None
):
"""Determine the support shape of a multivariate `RandomVariable`'s output given its parameters.
"""Determine the support shape of a multivariate `RandomVariable`'s output given its parameters.
...
...
pytensor/tensor/rewriting/basic.py
浏览文件 @
7886cf83
...
@@ -759,6 +759,7 @@ def local_remove_useless_assert(fgraph, node):
...
@@ -759,6 +759,7 @@ def local_remove_useless_assert(fgraph, node):
return
[
new_var
]
return
[
new_var
]
@register_infer_shape
@node_rewriter
([
Assert
])
@node_rewriter
([
Assert
])
def
local_remove_all_assert
(
fgraph
,
node
):
def
local_remove_all_assert
(
fgraph
,
node
):
r"""A rewrite that removes all `Assert`\s from a graph.
r"""A rewrite that removes all `Assert`\s from a graph.
...
@@ -768,9 +769,6 @@ def local_remove_all_assert(fgraph, node):
...
@@ -768,9 +769,6 @@ def local_remove_all_assert(fgraph, node):
See the :ref:`unsafe` section.
See the :ref:`unsafe` section.
"""
"""
if
not
isinstance
(
node
.
op
,
Assert
):
return
return
[
node
.
inputs
[
0
]]
return
[
node
.
inputs
[
0
]]
...
...
pytensor/tensor/utils.py
浏览文件 @
7886cf83
...
@@ -9,6 +9,7 @@ from numpy import nditer
...
@@ -9,6 +9,7 @@ from numpy import nditer
import
pytensor
import
pytensor
from
pytensor.graph
import
FunctionGraph
,
Variable
from
pytensor.graph
import
FunctionGraph
,
Variable
from
pytensor.npy_2_compat
import
normalize_axis_tuple
from
pytensor.npy_2_compat
import
normalize_axis_tuple
from
pytensor.tensor.exceptions
import
NotScalarConstantError
from
pytensor.utils
import
hash_from_code
from
pytensor.utils
import
hash_from_code
...
@@ -256,3 +257,31 @@ def faster_ndindex(shape: Sequence[int]):
...
@@ -256,3 +257,31 @@ def faster_ndindex(shape: Sequence[int]):
https://github.com/numpy/numpy/issues/28921
https://github.com/numpy/numpy/issues/28921
"""
"""
return
product
(
*
(
range
(
s
)
for
s
in
shape
))
return
product
(
*
(
range
(
s
)
for
s
in
shape
))
def
get_static_shape_from_size_variables
(
size_vars
:
Sequence
[
Variable
],
)
->
tuple
[
int
|
None
,
...
]:
"""Get static shape from size variables.
Parameters
----------
size_vars : Sequence[Variable]
A sequence of variables representing the size of each dimension.
Returns
-------
tuple[int | None, ...]
A tuple containing the static lengths of each dimension, or None if
the length is not statically known.
"""
from
pytensor.tensor.basic
import
get_scalar_constant_value
static_lengths
:
list
[
None
|
int
]
=
[
None
]
*
len
(
size_vars
)
for
i
,
length
in
enumerate
(
size_vars
):
try
:
static_length
=
get_scalar_constant_value
(
length
)
except
NotScalarConstantError
:
pass
else
:
static_lengths
[
i
]
=
int
(
static_length
)
return
tuple
(
static_lengths
)
pytensor/xtensor/random.py
0 → 100644
浏览文件 @
7886cf83
from
collections.abc
import
Sequence
from
functools
import
wraps
from
typing
import
Literal
import
pytensor.tensor.random.basic
as
ptr
from
pytensor.graph.basic
import
Variable
from
pytensor.tensor.random.op
import
RandomVariable
from
pytensor.xtensor
import
as_xtensor
from
pytensor.xtensor.math
import
sqrt
from
pytensor.xtensor.vectorization
import
XRV
def
_as_xrv
(
core_op
:
RandomVariable
,
core_inps_dims_map
:
Sequence
[
Sequence
[
int
]]
|
None
=
None
,
core_out_dims_map
:
Sequence
[
int
]
|
None
=
None
,
):
"""Helper function to define an XRV constructor.
Parameters
----------
core_op : RandomVariable
The core random variable operation to wrap.
core_inps_dims_map : Sequence[Sequence[int]] | None, optional
A sequence of sequences mapping the core dimensions (specified by the user)
for each input parameter. This is used when lowering to a RandomVariable operation,
to decide the ordering of the core dimensions for each input.
If None, it assumes the core dimensions are positional from left to right.
core_out_dims_map : Sequence[int] | None, optional
A sequence mapping the core dimensions (specified by the user) for the output variable.
This is used when lowering to a RandomVariable operation,
to decide the ordering of the core dimensions for the output.
If None, it assumes the core dimensions are positional from left to right.
"""
if
core_inps_dims_map
is
None
:
# Assume core_dims map positionally from left to right
core_inps_dims_map
=
[
tuple
(
range
(
ndim
))
for
ndim
in
core_op
.
ndims_params
]
if
core_out_dims_map
is
None
:
# Assume core_dims map positionally from left to right
core_out_dims_map
=
tuple
(
range
(
core_op
.
ndim_supp
))
core_dims_needed
=
max
(
(
*
(
len
(
i
)
for
i
in
core_inps_dims_map
),
len
(
core_out_dims_map
)),
default
=
0
)
@wraps
(
core_op
)
def
xrv_constructor
(
*
params
,
core_dims
:
Sequence
[
str
]
|
str
|
None
=
None
,
extra_dims
:
dict
[
str
,
Variable
]
|
None
=
None
,
rng
:
Variable
|
None
=
None
,
):
if
core_dims
is
None
:
core_dims
=
()
if
core_dims_needed
:
raise
ValueError
(
f
"{core_op.name} needs {core_dims_needed} core_dims to be specified"
)
elif
isinstance
(
core_dims
,
str
):
core_dims
=
(
core_dims
,)
if
len
(
core_dims
)
!=
core_dims_needed
:
raise
ValueError
(
f
"{core_op.name} needs {core_dims_needed} core_dims, but got {len(core_dims)}"
)
full_input_core_dims
=
tuple
(
tuple
(
core_dims
[
i
]
for
i
in
inp_dims_map
)
for
inp_dims_map
in
core_inps_dims_map
)
full_output_core_dims
=
tuple
(
core_dims
[
i
]
for
i
in
core_out_dims_map
)
full_core_dims
=
(
full_input_core_dims
,
full_output_core_dims
)
if
extra_dims
is
None
:
extra_dims
=
{}
return
XRV
(
core_op
,
core_dims
=
full_core_dims
,
extra_dims
=
tuple
(
extra_dims
.
keys
())
)(
rng
,
*
extra_dims
.
values
(),
*
params
)
return
xrv_constructor
bernoulli
=
_as_xrv
(
ptr
.
bernoulli
)
beta
=
_as_xrv
(
ptr
.
beta
)
betabinom
=
_as_xrv
(
ptr
.
betabinom
)
binomial
=
_as_xrv
(
ptr
.
binomial
)
categorical
=
_as_xrv
(
ptr
.
categorical
)
cauchy
=
_as_xrv
(
ptr
.
cauchy
)
dirichlet
=
_as_xrv
(
ptr
.
dirichlet
)
exponential
=
_as_xrv
(
ptr
.
exponential
)
gamma
=
_as_xrv
(
ptr
.
_gamma
)
gengamma
=
_as_xrv
(
ptr
.
gengamma
)
geometric
=
_as_xrv
(
ptr
.
geometric
)
gumbel
=
_as_xrv
(
ptr
.
gumbel
)
halfcauchy
=
_as_xrv
(
ptr
.
halfcauchy
)
halfnormal
=
_as_xrv
(
ptr
.
halfnormal
)
hypergeometric
=
_as_xrv
(
ptr
.
hypergeometric
)
integers
=
_as_xrv
(
ptr
.
integers
)
invgamma
=
_as_xrv
(
ptr
.
invgamma
)
laplace
=
_as_xrv
(
ptr
.
laplace
)
logistic
=
_as_xrv
(
ptr
.
logistic
)
lognormal
=
_as_xrv
(
ptr
.
lognormal
)
multinomial
=
_as_xrv
(
ptr
.
multinomial
)
nbinom
=
negative_binomial
=
_as_xrv
(
ptr
.
negative_binomial
)
normal
=
_as_xrv
(
ptr
.
normal
)
pareto
=
_as_xrv
(
ptr
.
pareto
)
poisson
=
_as_xrv
(
ptr
.
poisson
)
t
=
_as_xrv
(
ptr
.
t
)
triangular
=
_as_xrv
(
ptr
.
triangular
)
truncexpon
=
_as_xrv
(
ptr
.
truncexpon
)
uniform
=
_as_xrv
(
ptr
.
uniform
)
vonmises
=
_as_xrv
(
ptr
.
vonmises
)
wald
=
_as_xrv
(
ptr
.
wald
)
weibull
=
_as_xrv
(
ptr
.
weibull
)
def
multivariate_normal
(
mean
,
cov
,
*
,
core_dims
:
Sequence
[
str
],
extra_dims
=
None
,
rng
=
None
,
method
:
Literal
[
"cholesky"
,
"svd"
,
"eigh"
]
=
"cholesky"
,
):
mean
=
as_xtensor
(
mean
)
if
len
(
core_dims
)
!=
2
:
raise
ValueError
(
f
"multivariate_normal requires 2 core_dims, got {len(core_dims)}"
)
# Align core_dims, so that the dim that exists in mean comes before the one that only exists in cov
# This will be the core dimension of the output
if
core_dims
[
0
]
not
in
mean
.
type
.
dims
:
core_dims
=
core_dims
[::
-
1
]
xop
=
_as_xrv
(
ptr
.
MvNormalRV
(
method
=
method
))
return
xop
(
mean
,
cov
,
core_dims
=
core_dims
,
extra_dims
=
extra_dims
,
rng
=
rng
)
def
standard_normal
(
extra_dims
:
dict
[
str
,
Variable
]
|
None
=
None
,
rng
:
Variable
|
None
=
None
,
):
"""Standard normal random variable."""
return
normal
(
0
,
1
,
extra_dims
=
extra_dims
,
rng
=
rng
)
def
chisquare
(
df
,
extra_dims
:
dict
[
str
,
Variable
]
|
None
=
None
,
rng
:
Variable
|
None
=
None
,
):
"""Chi-square random variable."""
return
gamma
(
df
/
2.0
,
2.0
,
extra_dims
=
extra_dims
,
rng
=
rng
)
def
rayleigh
(
scale
,
extra_dims
:
dict
[
str
,
Variable
]
|
None
=
None
,
rng
:
Variable
|
None
=
None
,
):
"""Rayleigh random variable."""
df
=
scale
*
0
+
2
# Poor man's broadcasting, to pass dimensions of scale to the RV
return
sqrt
(
chisquare
(
df
,
extra_dims
=
extra_dims
,
rng
=
rng
))
*
scale
pytensor/xtensor/rewriting/vectorization.py
浏览文件 @
7886cf83
from
pytensor.graph
import
node_rewriter
from
pytensor.graph
import
node_rewriter
from
pytensor.tensor.blockwise
import
Blockwise
from
pytensor.tensor.blockwise
import
Blockwise
from
pytensor.tensor.elemwise
import
Elemwise
from
pytensor.tensor.elemwise
import
Elemwise
from
pytensor.tensor.random.utils
import
compute_batch_shape
from
pytensor.xtensor.basic
import
tensor_from_xtensor
,
xtensor_from_tensor
from
pytensor.xtensor.basic
import
tensor_from_xtensor
,
xtensor_from_tensor
from
pytensor.xtensor.rewriting.utils
import
register_lower_xtensor
from
pytensor.xtensor.rewriting.utils
import
register_lower_xtensor
from
pytensor.xtensor.vectorization
import
XBlockwise
,
XElemwise
from
pytensor.xtensor.vectorization
import
X
RV
,
X
Blockwise
,
XElemwise
@register_lower_xtensor
@register_lower_xtensor
...
@@ -74,3 +75,49 @@ def lower_blockwise(fgraph, node):
...
@@ -74,3 +75,49 @@ def lower_blockwise(fgraph, node):
for
(
tensor_out
,
old_out
)
in
zip
(
tensor_outs
,
node
.
outputs
,
strict
=
True
)
for
(
tensor_out
,
old_out
)
in
zip
(
tensor_outs
,
node
.
outputs
,
strict
=
True
)
]
]
return
new_outs
return
new_outs
@register_lower_xtensor
@node_rewriter
(
tracks
=
[
XRV
])
def
lower_rv
(
fgraph
,
node
):
op
:
XRV
=
node
.
op
core_op
=
op
.
core_op
_
,
old_out
=
node
.
outputs
rng
,
*
extra_dim_lengths_and_params
=
node
.
inputs
extra_dim_lengths
=
extra_dim_lengths_and_params
[:
len
(
op
.
extra_dims
)]
params
=
extra_dim_lengths_and_params
[
len
(
op
.
extra_dims
)
:]
batch_ndim
=
old_out
.
type
.
ndim
-
len
(
op
.
core_dims
[
1
])
param_batch_dims
=
old_out
.
type
.
dims
[
len
(
op
.
extra_dims
)
:
batch_ndim
]
# Convert params Tensors to XTensors, align batch dimensions and place core dimension at the end
tensor_params
=
[]
for
inp
,
core_dims
in
zip
(
params
,
op
.
core_dims
[
0
]):
inp_dims
=
inp
.
type
.
dims
# Align the batch dims of the input, and place the core dims on the right
batch_order
=
[
inp_dims
.
index
(
batch_dim
)
if
batch_dim
in
inp_dims
else
"x"
for
batch_dim
in
param_batch_dims
]
core_order
=
[
inp_dims
.
index
(
core_dim
)
for
core_dim
in
core_dims
]
tensor_inp
=
tensor_from_xtensor
(
inp
)
.
dimshuffle
(
batch_order
+
core_order
)
tensor_params
.
append
(
tensor_inp
)
size
=
None
if
op
.
extra_dims
:
# RV size contains the lengths of all batch dimensions, including those coming from the parameters
if
tensor_params
:
param_batch_shape
=
tuple
(
compute_batch_shape
(
tensor_params
,
ndims_params
=
core_op
.
ndims_params
)
)
else
:
param_batch_shape
=
()
size
=
[
*
extra_dim_lengths
,
*
param_batch_shape
]
# RVs are their own core Op
new_next_rng
,
tensor_out
=
core_op
(
*
tensor_params
,
rng
=
rng
,
size
=
size
)
.
owner
.
outputs
# Convert output Tensors to XTensors
new_out
=
xtensor_from_tensor
(
tensor_out
,
dims
=
old_out
.
type
.
dims
)
return
[
new_next_rng
,
new_out
]
pytensor/xtensor/shape.py
浏览文件 @
7886cf83
...
@@ -11,6 +11,7 @@ from pytensor.scalar import discrete_dtypes, upcast
...
@@ -11,6 +11,7 @@ from pytensor.scalar import discrete_dtypes, upcast
from
pytensor.tensor
import
as_tensor
,
get_scalar_constant_value
from
pytensor.tensor
import
as_tensor
,
get_scalar_constant_value
from
pytensor.tensor.exceptions
import
NotScalarConstantError
from
pytensor.tensor.exceptions
import
NotScalarConstantError
from
pytensor.tensor.type
import
integer_dtypes
from
pytensor.tensor.type
import
integer_dtypes
from
pytensor.tensor.utils
import
get_static_shape_from_size_variables
from
pytensor.xtensor.basic
import
XOp
from
pytensor.xtensor.basic
import
XOp
from
pytensor.xtensor.type
import
as_xtensor
,
xtensor
from
pytensor.xtensor.type
import
as_xtensor
,
xtensor
...
@@ -131,14 +132,9 @@ class UnStack(XOp):
...
@@ -131,14 +132,9 @@ class UnStack(XOp):
)
)
)
)
static_unstacked_lengths
=
[
None
]
*
len
(
unstacked_lengths
)
static_unstacked_lengths
=
get_static_shape_from_size_variables
(
for
i
,
length
in
enumerate
(
unstacked_lengths
):
unstacked_lengths
try
:
)
static_length
=
get_scalar_constant_value
(
length
)
except
NotScalarConstantError
:
pass
else
:
static_unstacked_lengths
[
i
]
=
int
(
static_length
)
output
=
xtensor
(
output
=
xtensor
(
dtype
=
x
.
type
.
dtype
,
dtype
=
x
.
type
.
dtype
,
...
...
pytensor/xtensor/vectorization.py
浏览文件 @
7886cf83
from
itertools
import
chain
from
itertools
import
chain
import
numpy
as
np
from
pytensor
import
scalar
as
ps
from
pytensor
import
scalar
as
ps
from
pytensor
import
shared
from
pytensor.graph
import
Apply
,
Op
from
pytensor.graph
import
Apply
,
Op
from
pytensor.scalar
import
discrete_dtypes
from
pytensor.tensor
import
tensor
from
pytensor.tensor
import
tensor
from
pytensor.tensor.random.op
import
RNGConsumerOp
from
pytensor.tensor.random.type
import
RandomType
from
pytensor.tensor.utils
import
(
get_static_shape_from_size_variables
,
)
from
pytensor.xtensor.basic
import
XOp
from
pytensor.xtensor.basic
import
XOp
from
pytensor.xtensor.type
import
as_xtensor
,
xtensor
from
pytensor.xtensor.type
import
as_xtensor
,
xtensor
...
@@ -108,3 +117,139 @@ class XBlockwise(XOp):
...
@@ -108,3 +117,139 @@ class XBlockwise(XOp):
for
core_out
,
core_out_dims
in
zip
(
core_node
.
outputs
,
core_outputs_dims
)
for
core_out
,
core_out_dims
in
zip
(
core_node
.
outputs
,
core_outputs_dims
)
]
]
return
Apply
(
self
,
inputs
,
outputs
)
return
Apply
(
self
,
inputs
,
outputs
)
class
XRV
(
XOp
,
RNGConsumerOp
):
"""Wrapper for RandomVariable operations that follows xarray-like broadcasting semantics.
Xarray does not offer random generators, so this class implements a new API.
It mostly works like a gufunc (or XBlockwise), which specifies core dimensions for inputs and output, and
enforces dim-based broadcasting between inputs and output.
It differs from XBlockwise in a couple of ways:
1. It is restricted to one sample output
2. It takes a random generator as the first input and returns the consumed generator as the first output.
3. It has the concept of extra dimensions, which determine extra batch dimensions of the output, that are not
implied by batch dimensions of the parameters.
"""
default_output
=
1
__props__
=
(
"core_op"
,
"core_dims"
,
"extra_dims"
)
def
__init__
(
self
,
core_op
,
core_dims
:
tuple
[
tuple
[
tuple
[
str
,
...
],
...
],
tuple
[
str
,
...
]],
extra_dims
:
tuple
[
str
,
...
],
):
super
()
.
__init__
()
self
.
core_op
=
core_op
inps_core_dims
,
out_core_dims
=
core_dims
for
operand_dims
in
(
*
inps_core_dims
,
out_core_dims
):
if
len
(
set
(
operand_dims
))
!=
len
(
operand_dims
):
raise
ValueError
(
f
"Operand has repeated dims {operand_dims}"
)
self
.
core_dims
=
(
tuple
(
i
for
i
in
inps_core_dims
),
tuple
(
out_core_dims
))
if
len
(
set
(
extra_dims
))
!=
len
(
extra_dims
):
raise
ValueError
(
"size_dims must be unique"
)
self
.
extra_dims
=
tuple
(
extra_dims
)
def
update
(
self
,
node
):
# RNG input and update are the first input and output respectively
return
{
node
.
inputs
[
0
]:
node
.
outputs
[
0
]}
def
make_node
(
self
,
rng
,
*
extra_dim_lengths_and_params
):
if
rng
is
None
:
rng
=
shared
(
np
.
random
.
default_rng
())
elif
not
isinstance
(
rng
.
type
,
RandomType
):
raise
TypeError
(
"The type of rng should be an instance of RandomGeneratorType "
)
extra_dim_lengths
=
[
as_xtensor
(
dim_length
)
.
values
for
dim_length
in
extra_dim_lengths_and_params
[:
len
(
self
.
extra_dims
)]
]
if
not
all
(
(
dim_length
.
type
.
ndim
==
0
and
dim_length
.
type
.
dtype
in
discrete_dtypes
)
for
dim_length
in
extra_dim_lengths
):
raise
TypeError
(
"All dimension lengths should be scalar discrete dtype."
)
params
=
[
as_xtensor
(
param
)
for
param
in
extra_dim_lengths_and_params
[
len
(
self
.
extra_dims
)
:]
]
if
len
(
params
)
!=
len
(
self
.
core_op
.
ndims_params
):
raise
ValueError
(
f
"Expected {len(self.core_op.ndims_params)} parameters + {len(self.extra_dims)} dim_lengths, "
f
"got {len(extra_dim_lengths_and_params)}"
)
param_core_dims
,
output_core_dims
=
self
.
core_dims
input_core_dims_set
=
set
(
chain
.
from_iterable
(
param_core_dims
))
# Check parameters don't have core dimensions they shouldn't have
for
param
,
core_param_dims
in
zip
(
params
,
param_core_dims
):
if
invalid_core_dims
:
=
(
set
(
param
.
type
.
dims
)
-
set
(
core_param_dims
)
)
.
intersection
(
input_core_dims_set
):
raise
ValueError
(
f
"Parameter {param} has invalid core dimensions {sorted(invalid_core_dims)}"
)
extra_dims_and_shape
=
dict
(
zip
(
self
.
extra_dims
,
get_static_shape_from_size_variables
(
extra_dim_lengths
)
)
)
params_dims_and_shape
=
combine_dims_and_shape
(
params
)
# Check that no parameter dims conflict with size dims
if
conflict_dims
:
=
set
(
extra_dims_and_shape
)
.
intersection
(
params_dims_and_shape
):
raise
ValueError
(
f
"Size dimensions {sorted(conflict_dims)} conflict with parameter dimensions. They should be unique."
)
batch_dims_and_shape
=
[
(
dim
,
dim_length
)
for
dim
,
dim_length
in
(
extra_dims_and_shape
|
params_dims_and_shape
)
.
items
()
if
dim
not
in
input_core_dims_set
]
if
batch_dims_and_shape
:
batch_output_dims
,
batch_output_shape
=
zip
(
*
batch_dims_and_shape
)
else
:
batch_output_dims
,
batch_output_shape
=
(),
()
dummy_core_inputs
=
[]
for
param
,
core_param_dims
in
zip
(
params
,
param_core_dims
):
try
:
core_static_shape
=
[
param
.
type
.
shape
[
param
.
type
.
dims
.
index
(
d
)]
for
d
in
core_param_dims
]
except
ValueError
:
raise
ValueError
(
f
"At least one core dim={core_param_dims} missing from input {param} with dims={param.type.dims}"
)
dummy_core_inputs
.
append
(
tensor
(
dtype
=
param
.
type
.
dtype
,
shape
=
core_static_shape
)
)
core_node
=
self
.
core_op
.
make_node
(
rng
,
None
,
*
dummy_core_inputs
)
if
not
len
(
core_node
.
outputs
)
==
2
:
raise
NotImplementedError
(
"XRandomVariable only supports core ops with two outputs (rng, out)"
)
_
,
core_out
=
core_node
.
outputs
out
=
xtensor
(
dtype
=
core_out
.
type
.
dtype
,
shape
=
batch_output_shape
+
core_out
.
type
.
shape
,
dims
=
batch_output_dims
+
output_core_dims
,
)
return
Apply
(
self
,
[
rng
,
*
extra_dim_lengths
,
*
params
],
[
rng
.
type
(),
out
])
tests/xtensor/test_random.py
0 → 100644
浏览文件 @
7886cf83
差异被折叠。
点击展开。
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论