Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
38c04c96
提交
38c04c96
authored
5月 24, 2024
作者:
Ricardo Vieira
提交者:
Ricardo Vieira
5月 29, 2024
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Add explicit expand_dims when building RandomVariable nodes
上级
591c47e6
显示空白字符变更
内嵌
并排
正在显示
7 个修改的文件
包含
65 行增加
和
62 行删除
+65
-62
random.py
pytensor/link/jax/dispatch/random.py
+8
-17
random.py
pytensor/link/numba/dispatch/random.py
+0
-6
basic.py
pytensor/tensor/random/basic.py
+21
-20
op.py
pytensor/tensor/random/op.py
+12
-14
jax.py
pytensor/tensor/random/rewriting/jax.py
+8
-2
test_random.py
tests/link/numba/test_random.py
+6
-1
test_basic.py
tests/tensor/random/rewriting/test_basic.py
+10
-2
没有找到文件。
pytensor/link/jax/dispatch/random.py
浏览文件 @
38c04c96
...
@@ -304,7 +304,6 @@ def jax_funcify_choice(op: ptr.ChoiceWithoutReplacement, node):
...
@@ -304,7 +304,6 @@ def jax_funcify_choice(op: ptr.ChoiceWithoutReplacement, node):
"""JAX implementation of `ChoiceRV`."""
"""JAX implementation of `ChoiceRV`."""
batch_ndim
=
op
.
batch_ndim
(
node
)
batch_ndim
=
op
.
batch_ndim
(
node
)
a
,
*
p
,
core_shape
=
op
.
dist_params
(
node
)
a_core_ndim
,
*
p_core_ndim
,
_
=
op
.
ndims_params
a_core_ndim
,
*
p_core_ndim
,
_
=
op
.
ndims_params
if
batch_ndim
and
a_core_ndim
==
0
:
if
batch_ndim
and
a_core_ndim
==
0
:
...
@@ -313,12 +312,6 @@ def jax_funcify_choice(op: ptr.ChoiceWithoutReplacement, node):
...
@@ -313,12 +312,6 @@ def jax_funcify_choice(op: ptr.ChoiceWithoutReplacement, node):
"A default JAX rewrite should have materialized the implicit arange"
"A default JAX rewrite should have materialized the implicit arange"
)
)
a_batch_ndim
=
a
.
type
.
ndim
-
a_core_ndim
if
op
.
has_p_param
:
[
p
]
=
p
[
p_core_ndim
]
=
p_core_ndim
p_batch_ndim
=
p
.
type
.
ndim
-
p_core_ndim
def
sample_fn
(
rng
,
size
,
dtype
,
*
parameters
):
def
sample_fn
(
rng
,
size
,
dtype
,
*
parameters
):
rng_key
=
rng
[
"jax_state"
]
rng_key
=
rng
[
"jax_state"
]
rng_key
,
sampling_key
=
jax
.
random
.
split
(
rng_key
,
2
)
rng_key
,
sampling_key
=
jax
.
random
.
split
(
rng_key
,
2
)
...
@@ -328,7 +321,7 @@ def jax_funcify_choice(op: ptr.ChoiceWithoutReplacement, node):
...
@@ -328,7 +321,7 @@ def jax_funcify_choice(op: ptr.ChoiceWithoutReplacement, node):
else
:
else
:
a
,
core_shape
=
parameters
a
,
core_shape
=
parameters
p
=
None
p
=
None
core_shape
=
tuple
(
np
.
asarray
(
core_shape
))
core_shape
=
tuple
(
np
.
asarray
(
core_shape
)
[(
0
,)
*
batch_ndim
]
)
if
batch_ndim
==
0
:
if
batch_ndim
==
0
:
sample
=
jax
.
random
.
choice
(
sample
=
jax
.
random
.
choice
(
...
@@ -338,16 +331,16 @@ def jax_funcify_choice(op: ptr.ChoiceWithoutReplacement, node):
...
@@ -338,16 +331,16 @@ def jax_funcify_choice(op: ptr.ChoiceWithoutReplacement, node):
else
:
else
:
if
size
is
None
:
if
size
is
None
:
if
p
is
None
:
if
p
is
None
:
size
=
a
.
shape
[:
a_
batch_ndim
]
size
=
a
.
shape
[:
batch_ndim
]
else
:
else
:
size
=
jax
.
numpy
.
broadcast_shapes
(
size
=
jax
.
numpy
.
broadcast_shapes
(
a
.
shape
[:
a_
batch_ndim
],
a
.
shape
[:
batch_ndim
],
p
.
shape
[:
p_
batch_ndim
],
p
.
shape
[:
batch_ndim
],
)
)
a
=
jax
.
numpy
.
broadcast_to
(
a
,
size
+
a
.
shape
[
a_
batch_ndim
:])
a
=
jax
.
numpy
.
broadcast_to
(
a
,
size
+
a
.
shape
[
batch_ndim
:])
if
p
is
not
None
:
if
p
is
not
None
:
p
=
jax
.
numpy
.
broadcast_to
(
p
,
size
+
p
.
shape
[
p_
batch_ndim
:])
p
=
jax
.
numpy
.
broadcast_to
(
p
,
size
+
p
.
shape
[
batch_ndim
:])
batch_sampling_keys
=
jax
.
random
.
split
(
sampling_key
,
np
.
prod
(
size
))
batch_sampling_keys
=
jax
.
random
.
split
(
sampling_key
,
np
.
prod
(
size
))
...
@@ -381,7 +374,6 @@ def jax_sample_fn_permutation(op, node):
...
@@ -381,7 +374,6 @@ def jax_sample_fn_permutation(op, node):
"""JAX implementation of `PermutationRV`."""
"""JAX implementation of `PermutationRV`."""
batch_ndim
=
op
.
batch_ndim
(
node
)
batch_ndim
=
op
.
batch_ndim
(
node
)
x_batch_ndim
=
node
.
inputs
[
-
1
]
.
type
.
ndim
-
op
.
ndims_params
[
0
]
def
sample_fn
(
rng
,
size
,
dtype
,
*
parameters
):
def
sample_fn
(
rng
,
size
,
dtype
,
*
parameters
):
rng_key
=
rng
[
"jax_state"
]
rng_key
=
rng
[
"jax_state"
]
...
@@ -389,11 +381,10 @@ def jax_sample_fn_permutation(op, node):
...
@@ -389,11 +381,10 @@ def jax_sample_fn_permutation(op, node):
(
x
,)
=
parameters
(
x
,)
=
parameters
if
batch_ndim
:
if
batch_ndim
:
# jax.random.permutation has no concept of batch dims
# jax.random.permutation has no concept of batch dims
x_core_shape
=
x
.
shape
[
x_batch_ndim
:]
if
size
is
None
:
if
size
is
None
:
size
=
x
.
shape
[:
x_
batch_ndim
]
size
=
x
.
shape
[:
batch_ndim
]
else
:
else
:
x
=
jax
.
numpy
.
broadcast_to
(
x
,
size
+
x
_core_shape
)
x
=
jax
.
numpy
.
broadcast_to
(
x
,
size
+
x
.
shape
[
batch_ndim
:]
)
batch_sampling_keys
=
jax
.
random
.
split
(
sampling_key
,
np
.
prod
(
size
))
batch_sampling_keys
=
jax
.
random
.
split
(
sampling_key
,
np
.
prod
(
size
))
raveled_batch_x
=
x
.
reshape
((
-
1
,)
+
x
.
shape
[
batch_ndim
:])
raveled_batch_x
=
x
.
reshape
((
-
1
,)
+
x
.
shape
[
batch_ndim
:])
...
...
pytensor/link/numba/dispatch/random.py
浏览文件 @
38c04c96
...
@@ -347,7 +347,6 @@ def numba_funcify_CategoricalRV(op: ptr.CategoricalRV, node, **kwargs):
...
@@ -347,7 +347,6 @@ def numba_funcify_CategoricalRV(op: ptr.CategoricalRV, node, **kwargs):
def
numba_funcify_DirichletRV
(
op
,
node
,
**
kwargs
):
def
numba_funcify_DirichletRV
(
op
,
node
,
**
kwargs
):
out_dtype
=
node
.
outputs
[
1
]
.
type
.
numpy_dtype
out_dtype
=
node
.
outputs
[
1
]
.
type
.
numpy_dtype
alphas_ndim
=
op
.
dist_params
(
node
)[
0
]
.
type
.
ndim
alphas_ndim
=
op
.
dist_params
(
node
)[
0
]
.
type
.
ndim
neg_ind_shape_len
=
-
alphas_ndim
+
1
size_param
=
op
.
size_param
(
node
)
size_param
=
op
.
size_param
(
node
)
size_len
=
(
size_len
=
(
None
None
...
@@ -363,11 +362,6 @@ def numba_funcify_DirichletRV(op, node, **kwargs):
...
@@ -363,11 +362,6 @@ def numba_funcify_DirichletRV(op, node, **kwargs):
samples_shape
=
alphas
.
shape
samples_shape
=
alphas
.
shape
else
:
else
:
size_tpl
=
numba_ndarray
.
to_fixed_tuple
(
size
,
size_len
)
size_tpl
=
numba_ndarray
.
to_fixed_tuple
(
size
,
size_len
)
if
(
0
<
alphas
.
ndim
-
1
<=
len
(
size_tpl
)
and
size_tpl
[
neg_ind_shape_len
:]
!=
alphas
.
shape
[:
-
1
]
):
raise
ValueError
(
"Parameters shape and size do not match."
)
samples_shape
=
size_tpl
+
alphas
.
shape
[
-
1
:]
samples_shape
=
size_tpl
+
alphas
.
shape
[
-
1
:]
res
=
np
.
empty
(
samples_shape
,
dtype
=
out_dtype
)
res
=
np
.
empty
(
samples_shape
,
dtype
=
out_dtype
)
...
...
pytensor/tensor/random/basic.py
浏览文件 @
38c04c96
...
@@ -2002,6 +2002,11 @@ class ChoiceWithoutReplacement(RandomVariable):
...
@@ -2002,6 +2002,11 @@ class ChoiceWithoutReplacement(RandomVariable):
a_shape
=
tuple
(
a
.
shape
)
if
param_shapes
is
None
else
tuple
(
param_shapes
[
0
])
a_shape
=
tuple
(
a
.
shape
)
if
param_shapes
is
None
else
tuple
(
param_shapes
[
0
])
a_batch_ndim
=
len
(
a_shape
)
-
self
.
ndims_params
[
0
]
a_batch_ndim
=
len
(
a_shape
)
-
self
.
ndims_params
[
0
]
a_core_shape
=
a_shape
[
a_batch_ndim
:]
a_core_shape
=
a_shape
[
a_batch_ndim
:]
core_shape_ndim
=
core_shape
.
type
.
ndim
if
core_shape_ndim
>
1
:
# Batch core shapes are only valid if homogeneous or broadcasted,
# as otherwise they would imply ragged choice arrays
core_shape
=
core_shape
[(
0
,)
*
(
core_shape_ndim
-
1
)]
return
tuple
(
core_shape
)
+
a_core_shape
[
1
:]
return
tuple
(
core_shape
)
+
a_core_shape
[
1
:]
def
rng_fn
(
self
,
*
params
):
def
rng_fn
(
self
,
*
params
):
...
@@ -2011,15 +2016,11 @@ class ChoiceWithoutReplacement(RandomVariable):
...
@@ -2011,15 +2016,11 @@ class ChoiceWithoutReplacement(RandomVariable):
rng
,
a
,
core_shape
,
size
=
params
rng
,
a
,
core_shape
,
size
=
params
p
=
None
p
=
None
if
core_shape
.
ndim
>
1
:
core_shape
=
core_shape
[(
0
,)
*
(
core_shape
.
ndim
-
1
)]
core_shape
=
tuple
(
core_shape
)
core_shape
=
tuple
(
core_shape
)
# We don't have access to the node in rng_fn for easy computation of batch_ndim :(
batch_ndim
=
a
.
ndim
-
self
.
ndims_params
[
0
]
a_batch_ndim
=
batch_ndim
=
a
.
ndim
-
self
.
ndims_params
[
0
]
if
p
is
not
None
:
p_batch_ndim
=
p
.
ndim
-
self
.
ndims_params
[
1
]
batch_ndim
=
max
(
batch_ndim
,
p_batch_ndim
)
size_ndim
=
0
if
size
is
None
else
len
(
size
)
batch_ndim
=
max
(
batch_ndim
,
size_ndim
)
if
batch_ndim
==
0
:
if
batch_ndim
==
0
:
# Numpy choice fails with size=() if a.ndim > 1 is batched
# Numpy choice fails with size=() if a.ndim > 1 is batched
...
@@ -2031,16 +2032,16 @@ class ChoiceWithoutReplacement(RandomVariable):
...
@@ -2031,16 +2032,16 @@ class ChoiceWithoutReplacement(RandomVariable):
# Numpy choice doesn't have a concept of batch dims
# Numpy choice doesn't have a concept of batch dims
if
size
is
None
:
if
size
is
None
:
if
p
is
None
:
if
p
is
None
:
size
=
a
.
shape
[:
a_
batch_ndim
]
size
=
a
.
shape
[:
batch_ndim
]
else
:
else
:
size
=
np
.
broadcast_shapes
(
size
=
np
.
broadcast_shapes
(
a
.
shape
[:
a_
batch_ndim
],
a
.
shape
[:
batch_ndim
],
p
.
shape
[:
p_
batch_ndim
],
p
.
shape
[:
batch_ndim
],
)
)
a
=
np
.
broadcast_to
(
a
,
size
+
a
.
shape
[
a_
batch_ndim
:])
a
=
np
.
broadcast_to
(
a
,
size
+
a
.
shape
[
batch_ndim
:])
if
p
is
not
None
:
if
p
is
not
None
:
p
=
np
.
broadcast_to
(
p
,
size
+
p
.
shape
[
p_
batch_ndim
:])
p
=
np
.
broadcast_to
(
p
,
size
+
p
.
shape
[
batch_ndim
:])
a_indexed_shape
=
a
.
shape
[
len
(
size
)
+
1
:]
a_indexed_shape
=
a
.
shape
[
len
(
size
)
+
1
:]
out
=
np
.
empty
(
size
+
core_shape
+
a_indexed_shape
,
dtype
=
a
.
dtype
)
out
=
np
.
empty
(
size
+
core_shape
+
a_indexed_shape
,
dtype
=
a
.
dtype
)
...
@@ -2143,26 +2144,26 @@ class PermutationRV(RandomVariable):
...
@@ -2143,26 +2144,26 @@ class PermutationRV(RandomVariable):
def
_supp_shape_from_params
(
self
,
dist_params
,
param_shapes
=
None
):
def
_supp_shape_from_params
(
self
,
dist_params
,
param_shapes
=
None
):
[
x
]
=
dist_params
[
x
]
=
dist_params
x_shape
=
tuple
(
x
.
shape
if
param_shapes
is
None
else
param_shapes
[
0
])
x_shape
=
tuple
(
x
.
shape
if
param_shapes
is
None
else
param_shapes
[
0
])
if
x
.
type
.
ndim
==
0
:
if
self
.
ndims_params
[
0
]
==
0
:
return
(
x
,)
# Implicit arange, this is only valid for homogeneous arrays
# Otherwise it would imply a ragged permutation array.
return
(
x
.
ravel
()[
0
],)
else
:
else
:
batch_x_ndim
=
x
.
type
.
ndim
-
self
.
ndims_params
[
0
]
batch_x_ndim
=
x
.
type
.
ndim
-
self
.
ndims_params
[
0
]
return
x_shape
[
batch_x_ndim
:]
return
x_shape
[
batch_x_ndim
:]
def
rng_fn
(
self
,
rng
,
x
,
size
):
def
rng_fn
(
self
,
rng
,
x
,
size
):
# We don't have access to the node in rng_fn :(
# We don't have access to the node in rng_fn :(
x_batch_ndim
=
x
.
ndim
-
self
.
ndims_params
[
0
]
batch_ndim
=
x
.
ndim
-
self
.
ndims_params
[
0
]
batch_ndim
=
max
(
x_batch_ndim
,
0
if
size
is
None
else
len
(
size
))
if
batch_ndim
:
if
batch_ndim
:
# rng.permutation has no concept of batch dims
# rng.permutation has no concept of batch dims
x_core_shape
=
x
.
shape
[
x_batch_ndim
:]
if
size
is
None
:
if
size
is
None
:
size
=
x
.
shape
[:
x_
batch_ndim
]
size
=
x
.
shape
[:
batch_ndim
]
else
:
else
:
x
=
np
.
broadcast_to
(
x
,
size
+
x
_core_shape
)
x
=
np
.
broadcast_to
(
x
,
size
+
x
.
shape
[
batch_ndim
:]
)
out
=
np
.
empty
(
size
+
x
_core_shape
,
dtype
=
x
.
dtype
)
out
=
np
.
empty
(
size
+
x
.
shape
[
batch_ndim
:]
,
dtype
=
x
.
dtype
)
for
idx
in
np
.
ndindex
(
size
):
for
idx
in
np
.
ndindex
(
size
):
out
[
idx
]
=
rng
.
permutation
(
x
[
idx
])
out
[
idx
]
=
rng
.
permutation
(
x
[
idx
])
return
out
return
out
...
...
pytensor/tensor/random/op.py
浏览文件 @
38c04c96
...
@@ -9,7 +9,7 @@ import pytensor
...
@@ -9,7 +9,7 @@ import pytensor
from
pytensor.configdefaults
import
config
from
pytensor.configdefaults
import
config
from
pytensor.graph.basic
import
Apply
,
Variable
,
equal_computations
from
pytensor.graph.basic
import
Apply
,
Variable
,
equal_computations
from
pytensor.graph.op
import
Op
from
pytensor.graph.op
import
Op
from
pytensor.graph.replace
import
_vectorize_node
,
vectorize_graph
from
pytensor.graph.replace
import
_vectorize_node
from
pytensor.misc.safe_asarray
import
_asarray
from
pytensor.misc.safe_asarray
import
_asarray
from
pytensor.scalar
import
ScalarVariable
from
pytensor.scalar
import
ScalarVariable
from
pytensor.tensor.basic
import
(
from
pytensor.tensor.basic
import
(
...
@@ -359,6 +359,12 @@ class RandomVariable(Op):
...
@@ -359,6 +359,12 @@ class RandomVariable(Op):
inferred_shape
=
self
.
_infer_shape
(
size
,
dist_params
)
inferred_shape
=
self
.
_infer_shape
(
size
,
dist_params
)
_
,
static_shape
=
infer_static_shape
(
inferred_shape
)
_
,
static_shape
=
infer_static_shape
(
inferred_shape
)
dist_params
=
explicit_expand_dims
(
dist_params
,
self
.
ndims_params
,
size_length
=
None
if
NoneConst
.
equals
(
size
)
else
get_vector_length
(
size
),
)
inputs
=
(
rng
,
size
,
*
dist_params
)
inputs
=
(
rng
,
size
,
*
dist_params
)
out_type
=
TensorType
(
dtype
=
self
.
dtype
,
shape
=
static_shape
)
out_type
=
TensorType
(
dtype
=
self
.
dtype
,
shape
=
static_shape
)
outputs
=
(
rng
.
type
(),
out_type
())
outputs
=
(
rng
.
type
(),
out_type
())
...
@@ -459,22 +465,14 @@ def vectorize_random_variable(
...
@@ -459,22 +465,14 @@ def vectorize_random_variable(
None
if
isinstance
(
old_size
.
type
,
NoneTypeT
)
else
get_vector_length
(
old_size
)
None
if
isinstance
(
old_size
.
type
,
NoneTypeT
)
else
get_vector_length
(
old_size
)
)
)
original_expanded_dist_params
=
explicit_expand_dims
(
if
len_old_size
and
equal_computations
([
old_size
],
[
size
]):
original_dist_params
,
op
.
ndims_params
,
len_old_size
)
# We call vectorize_graph to automatically handle any new explicit expand_dims
dist_params
=
vectorize_graph
(
original_expanded_dist_params
,
dict
(
zip
(
original_dist_params
,
dist_params
))
)
new_ndim
=
dist_params
[
0
]
.
type
.
ndim
-
original_expanded_dist_params
[
0
]
.
type
.
ndim
if
new_ndim
and
len_old_size
and
equal_computations
([
old_size
],
[
size
]):
# If the original RV had a size variable and a new one has not been provided,
# If the original RV had a size variable and a new one has not been provided,
# we need to define a new size as the concatenation of the original size dimensions
# we need to define a new size as the concatenation of the original size dimensions
# and the novel ones implied by new broadcasted batched parameters dimensions.
# and the novel ones implied by new broadcasted batched parameters dimensions.
broadcasted_batch_shape
=
compute_batch_shape
(
dist_params
,
op
.
ndims_params
)
new_ndim
=
dist_params
[
0
]
.
type
.
ndim
-
original_dist_params
[
0
]
.
type
.
ndim
new_size_dims
=
broadcasted_batch_shape
[:
new_ndim
]
if
new_ndim
>=
0
:
new_size
=
compute_batch_shape
(
dist_params
,
ndims_params
=
op
.
ndims_params
)
new_size_dims
=
new_size
[:
new_ndim
]
size
=
concatenate
([
new_size_dims
,
size
])
size
=
concatenate
([
new_size_dims
,
size
])
return
op
.
make_node
(
rng
,
size
,
*
dist_params
)
return
op
.
make_node
(
rng
,
size
,
*
dist_params
)
pytensor/tensor/random/rewriting/jax.py
浏览文件 @
38c04c96
import
re
import
re
from
pytensor.compile
import
optdb
from
pytensor.compile
import
optdb
from
pytensor.graph
import
Constant
from
pytensor.graph.rewriting.basic
import
in2out
,
node_rewriter
from
pytensor.graph.rewriting.basic
import
in2out
,
node_rewriter
from
pytensor.graph.rewriting.db
import
SequenceDB
from
pytensor.graph.rewriting.db
import
SequenceDB
from
pytensor.tensor
import
abs
as
abs_t
from
pytensor.tensor
import
abs
as
abs_t
...
@@ -159,12 +160,17 @@ def materialize_implicit_arange_choice_without_replacement(fgraph, node):
...
@@ -159,12 +160,17 @@ def materialize_implicit_arange_choice_without_replacement(fgraph, node):
return
None
return
None
rng
,
size
,
a_scalar_param
,
*
other_params
=
node
.
inputs
rng
,
size
,
a_scalar_param
,
*
other_params
=
node
.
inputs
if
a_scalar_param
.
type
.
ndim
>
0
:
if
not
all
(
a_scalar_param
.
type
.
broadcastable
)
:
# Automatic vectorization could have made this parameter batched,
# Automatic vectorization could have made this parameter batched,
# there is no nice way to materialize a batched arange
# there is no nice way to materialize a batched arange
return
None
return
None
a_vector_param
=
arange
(
a_scalar_param
)
# We need to try and do an eager squeeze here because arange will fail in jax
# if there is an array leading to it, even if it's constant
if
isinstance
(
a_scalar_param
,
Constant
):
a_scalar_param
=
a_scalar_param
.
data
a_vector_param
=
arange
(
a_scalar_param
.
squeeze
())
new_props_dict
=
op
.
_props_dict
()
.
copy
()
new_props_dict
=
op
.
_props_dict
()
.
copy
()
# Signature changes from something like "(),(a),(2)->(s0, s1)" to "(a),(a),(2)->(s0, s1)"
# Signature changes from something like "(),(a),(2)->(s0, s1)" to "(a),(a),(2)->(s0, s1)"
# I.e., we substitute the first `()` by `(a)`
# I.e., we substitute the first `()` by `(a)`
...
...
tests/link/numba/test_random.py
浏览文件 @
38c04c96
...
@@ -28,6 +28,9 @@ from tests.tensor.random.test_basic import (
...
@@ -28,6 +28,9 @@ from tests.tensor.random.test_basic import (
rng
=
np
.
random
.
default_rng
(
42849
)
rng
=
np
.
random
.
default_rng
(
42849
)
@pytest.mark.xfail
(
reason
=
"Most RVs are not working correctly with explicit expand_dims"
)
@pytest.mark.parametrize
(
@pytest.mark.parametrize
(
"rv_op, dist_args, size"
,
"rv_op, dist_args, size"
,
[
[
...
@@ -388,6 +391,7 @@ def test_aligned_RandomVariable(rv_op, dist_args, size):
...
@@ -388,6 +391,7 @@ def test_aligned_RandomVariable(rv_op, dist_args, size):
)
)
@pytest.mark.xfail
(
reason
=
"Test is not working correctly with explicit expand_dims"
)
@pytest.mark.parametrize
(
@pytest.mark.parametrize
(
"rv_op, dist_args, base_size, cdf_name, params_conv"
,
"rv_op, dist_args, base_size, cdf_name, params_conv"
,
[
[
...
@@ -633,7 +637,7 @@ def test_CategoricalRV(dist_args, size, cm):
...
@@ -633,7 +637,7 @@ def test_CategoricalRV(dist_args, size, cm):
),
),
),
),
(
10
,
4
),
(
10
,
4
),
pytest
.
raises
(
ValueError
,
match
=
"
Parameters shape.*
"
),
pytest
.
raises
(
ValueError
,
match
=
"
operands could not be broadcast together
"
),
),
),
],
],
)
)
...
@@ -658,6 +662,7 @@ def test_DirichletRV(a, size, cm):
...
@@ -658,6 +662,7 @@ def test_DirichletRV(a, size, cm):
assert
np
.
allclose
(
res
,
exp_res
,
atol
=
1e-4
)
assert
np
.
allclose
(
res
,
exp_res
,
atol
=
1e-4
)
@pytest.mark.xfail
(
reason
=
"RandomState is not aligned with explicit expand_dims"
)
def
test_RandomState_updates
():
def
test_RandomState_updates
():
rng
=
shared
(
np
.
random
.
RandomState
(
1
))
rng
=
shared
(
np
.
random
.
RandomState
(
1
))
rng_new
=
shared
(
np
.
random
.
RandomState
(
2
))
rng_new
=
shared
(
np
.
random
.
RandomState
(
2
))
...
...
tests/tensor/random/rewriting/test_basic.py
浏览文件 @
38c04c96
...
@@ -796,13 +796,21 @@ def test_Subtensor_lift(indices, lifted, dist_op, dist_params, size):
...
@@ -796,13 +796,21 @@ def test_Subtensor_lift(indices, lifted, dist_op, dist_params, size):
rng
,
rng
,
)
)
def
is_subtensor_or_dimshuffle_subtensor
(
inp
)
->
bool
:
subtensor_ops
=
Subtensor
|
AdvancedSubtensor
|
AdvancedSubtensor1
if
isinstance
(
inp
.
owner
.
op
,
subtensor_ops
):
return
True
if
isinstance
(
inp
.
owner
.
op
,
DimShuffle
):
return
isinstance
(
inp
.
owner
.
inputs
[
0
]
.
owner
.
op
,
subtensor_ops
)
return
False
if
lifted
:
if
lifted
:
assert
isinstance
(
new_out
.
owner
.
op
,
RandomVariable
)
assert
isinstance
(
new_out
.
owner
.
op
,
RandomVariable
)
assert
all
(
assert
all
(
is
instance
(
i
.
owner
.
op
,
AdvancedSubtensor
|
AdvancedSubtensor1
|
Subtensor
)
is
_subtensor_or_dimshuffle_subtensor
(
i
)
for
i
in
new_out
.
owner
.
op
.
dist_params
(
new_out
.
owner
)
for
i
in
new_out
.
owner
.
op
.
dist_params
(
new_out
.
owner
)
if
i
.
owner
if
i
.
owner
)
)
,
new_out
.
dprint
(
depth
=
3
,
print_type
=
True
)
else
:
else
:
assert
isinstance
(
assert
isinstance
(
new_out
.
owner
.
op
,
AdvancedSubtensor
|
AdvancedSubtensor1
|
Subtensor
new_out
.
owner
.
op
,
AdvancedSubtensor
|
AdvancedSubtensor1
|
Subtensor
...
...
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论