Skip to content
项目
群组
代码片段
帮助
当前项目
正在载入...
登录 / 注册
切换导航面板
P
pytensor
项目
项目
详情
活动
周期分析
仓库
仓库
文件
提交
分支
标签
贡献者
图表
比较
统计图
议题
0
议题
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
CI / CD
CI / CD
流水线
作业
日程
统计图
Wiki
Wiki
代码片段
代码片段
成员
成员
折叠边栏
关闭边栏
活动
图像
聊天
创建新问题
作业
提交
问题看板
Open sidebar
testgroup
pytensor
Commits
98d73d78
提交
98d73d78
authored
5月 10, 2024
作者:
Ricardo Vieira
提交者:
Ricardo Vieira
5月 29, 2024
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Remove RandomVariable dtype input
上级
df32683c
显示空白字符变更
内嵌
并排
正在显示
6 个修改的文件
包含
118 行增加
和
122 行删除
+118
-122
random.py
pytensor/link/jax/dispatch/random.py
+2
-2
random.py
pytensor/link/numba/dispatch/random.py
+8
-9
op.py
pytensor/tensor/random/op.py
+35
-33
basic.py
pytensor/tensor/random/rewriting/basic.py
+6
-6
test_basic.py
tests/tensor/random/rewriting/test_basic.py
+6
-5
test_op.py
tests/tensor/random/test_op.py
+61
-67
没有找到文件。
pytensor/link/jax/dispatch/random.py
浏览文件 @
98d73d78
...
...
@@ -114,7 +114,7 @@ def jax_funcify_RandomVariable(op: ptr.RandomVariable, node, **kwargs):
if
None
in
static_size
:
assert_size_argument_jax_compatible
(
node
)
def
sample_fn
(
rng
,
size
,
dtype
,
*
parameters
):
def
sample_fn
(
rng
,
size
,
*
parameters
):
# PyTensor uses empty size to represent size = None
if
jax
.
numpy
.
asarray
(
size
)
.
shape
==
(
0
,):
size
=
None
...
...
@@ -122,7 +122,7 @@ def jax_funcify_RandomVariable(op: ptr.RandomVariable, node, **kwargs):
else
:
def
sample_fn
(
rng
,
size
,
dtype
,
*
parameters
):
def
sample_fn
(
rng
,
size
,
*
parameters
):
return
jax_sample_fn
(
op
,
node
=
node
)(
rng
,
static_size
,
out_dtype
,
*
parameters
)
...
...
pytensor/link/numba/dispatch/random.py
浏览文件 @
98d73d78
...
...
@@ -123,7 +123,6 @@ def make_numba_random_fn(node, np_random_func):
"size_dims"
,
"rng"
,
"size"
,
"dtype"
,
],
suffix_sep
=
"_"
,
)
...
...
@@ -146,7 +145,7 @@ def {bcast_fn_name}({bcast_fn_input_names}):
)
random_fn_input_names
=
", "
.
join
(
[
"rng"
,
"size"
,
"dtype"
]
+
[
unique_names
(
i
)
for
i
in
dist_params
]
[
"rng"
,
"size"
]
+
[
unique_names
(
i
)
for
i
in
dist_params
]
)
# Now, create a Numba JITable function that implements the `size` parameter
...
...
@@ -243,7 +242,7 @@ def create_numba_random_fn(
np_global_env
[
"numba_vectorize"
]
=
numba_basic
.
numba_vectorize
unique_names
=
unique_name_generator
(
[
np_random_fn_name
,
*
np_global_env
.
keys
(),
"rng"
,
"size"
,
"dtype"
],
[
np_random_fn_name
,
*
np_global_env
.
keys
(),
"rng"
,
"size"
],
suffix_sep
=
"_"
,
)
...
...
@@ -310,7 +309,7 @@ def numba_funcify_CategoricalRV(op: ptr.CategoricalRV, node, **kwargs):
p_ndim
=
node
.
inputs
[
-
1
]
.
ndim
@numba_basic.numba_njit
def
categorical_rv
(
rng
,
size
,
dtype
,
p
):
def
categorical_rv
(
rng
,
size
,
p
):
if
not
size_len
:
size_tpl
=
p
.
shape
[:
-
1
]
else
:
...
...
@@ -342,7 +341,7 @@ def numba_funcify_DirichletRV(op, node, **kwargs):
if
alphas_ndim
>
1
:
@numba_basic.numba_njit
def
dirichlet_rv
(
rng
,
size
,
dtype
,
alphas
):
def
dirichlet_rv
(
rng
,
size
,
alphas
):
if
size_len
>
0
:
size_tpl
=
numba_ndarray
.
to_fixed_tuple
(
size
,
size_len
)
if
(
...
...
@@ -365,7 +364,7 @@ def numba_funcify_DirichletRV(op, node, **kwargs):
else
:
@numba_basic.numba_njit
def
dirichlet_rv
(
rng
,
size
,
dtype
,
alphas
):
def
dirichlet_rv
(
rng
,
size
,
alphas
):
size
=
numba_ndarray
.
to_fixed_tuple
(
size
,
size_len
)
return
(
rng
,
np
.
random
.
dirichlet
(
alphas
,
size
))
...
...
@@ -388,14 +387,14 @@ def numba_funcify_choice_without_replacement(op, node, **kwargs):
if
op
.
has_p_param
:
@numba_basic.numba_njit
def
choice_without_replacement_rv
(
rng
,
size
,
dtype
,
a
,
p
,
core_shape
):
def
choice_without_replacement_rv
(
rng
,
size
,
a
,
p
,
core_shape
):
core_shape
=
numba_ndarray
.
to_fixed_tuple
(
core_shape
,
core_shape_len
)
samples
=
np
.
random
.
choice
(
a
,
size
=
core_shape
,
replace
=
False
,
p
=
p
)
return
(
rng
,
samples
)
else
:
@numba_basic.numba_njit
def
choice_without_replacement_rv
(
rng
,
size
,
dtype
,
a
,
core_shape
):
def
choice_without_replacement_rv
(
rng
,
size
,
a
,
core_shape
):
core_shape
=
numba_ndarray
.
to_fixed_tuple
(
core_shape
,
core_shape_len
)
samples
=
np
.
random
.
choice
(
a
,
size
=
core_shape
,
replace
=
False
)
return
(
rng
,
samples
)
...
...
@@ -411,7 +410,7 @@ def numba_funcify_permutation(op: ptr.PermutationRV, node, **kwargs):
x_batch_ndim
=
node
.
inputs
[
-
1
]
.
type
.
ndim
-
op
.
ndims_params
[
0
]
@numba_basic.numba_njit
def
permutation_rv
(
rng
,
size
,
dtype
,
x
):
def
permutation_rv
(
rng
,
size
,
x
):
if
batch_ndim
:
x_core_shape
=
x
.
shape
[
x_batch_ndim
:]
if
size_is_none
:
...
...
pytensor/tensor/random/op.py
浏览文件 @
98d73d78
...
...
@@ -27,7 +27,7 @@ from pytensor.tensor.random.utils import (
normalize_size_param
,
)
from
pytensor.tensor.shape
import
shape_tuple
from
pytensor.tensor.type
import
TensorType
,
all_dtypes
from
pytensor.tensor.type
import
TensorType
from
pytensor.tensor.type_other
import
NoneConst
from
pytensor.tensor.utils
import
_parse_gufunc_signature
,
safe_signature
from
pytensor.tensor.variable
import
TensorVariable
...
...
@@ -65,7 +65,7 @@ class RandomVariable(Op):
signature: str
Numpy-like vectorized signature of the random variable.
dtype: str (optional)
The dtype of the sampled output. If the value ``"floatX"`` is
The d
efault d
type of the sampled output. If the value ``"floatX"`` is
given, then ``dtype`` is set to ``pytensor.config.floatX``. If
``None`` (the default), the `dtype` keyword must be set when
`RandomVariable.make_node` is called.
...
...
@@ -287,8 +287,8 @@ class RandomVariable(Op):
return
shape
def
infer_shape
(
self
,
fgraph
,
node
,
input_shapes
):
_
,
size
,
_
,
*
dist_params
=
node
.
inputs
_
,
size_shape
,
_
,
*
param_shapes
=
input_shapes
_
,
size
,
*
dist_params
=
node
.
inputs
_
,
size_shape
,
*
param_shapes
=
input_shapes
try
:
size_len
=
get_vector_length
(
size
)
...
...
@@ -302,14 +302,34 @@ class RandomVariable(Op):
return
[
None
,
list
(
shape
)]
def
__call__
(
self
,
*
args
,
size
=
None
,
name
=
None
,
rng
=
None
,
dtype
=
None
,
**
kwargs
):
res
=
super
()
.
__call__
(
rng
,
size
,
dtype
,
*
args
,
**
kwargs
)
if
dtype
is
None
:
dtype
=
self
.
dtype
if
dtype
==
"floatX"
:
dtype
=
config
.
floatX
# We need to recreate the Op with the right dtype
if
dtype
!=
self
.
dtype
:
# Check we are not switching from float to int
if
self
.
dtype
is
not
None
:
if
dtype
.
startswith
(
"float"
)
!=
self
.
dtype
.
startswith
(
"float"
):
raise
ValueError
(
f
"Cannot change the dtype of a {self.name} RV from {self.dtype} to {dtype}"
)
props
=
self
.
_props_dict
()
props
[
"dtype"
]
=
dtype
new_op
=
type
(
self
)(
**
props
)
return
new_op
.
__call__
(
*
args
,
size
=
size
,
name
=
name
,
rng
=
rng
,
dtype
=
dtype
,
**
kwargs
)
res
=
super
()
.
__call__
(
rng
,
size
,
*
args
,
**
kwargs
)
if
name
is
not
None
:
res
.
name
=
name
return
res
def
make_node
(
self
,
rng
,
size
,
dtype
,
*
dist_params
):
def
make_node
(
self
,
rng
,
size
,
*
dist_params
):
"""Create a random variable node.
Parameters
...
...
@@ -349,23 +369,10 @@ class RandomVariable(Op):
shape
=
self
.
_infer_shape
(
size
,
dist_params
)
_
,
static_shape
=
infer_static_shape
(
shape
)
dtype
=
self
.
dtype
or
dtype
if
dtype
==
"floatX"
:
dtype
=
config
.
floatX
elif
dtype
is
None
or
(
isinstance
(
dtype
,
str
)
and
dtype
not
in
all_dtypes
):
raise
TypeError
(
"dtype is unspecified"
)
if
isinstance
(
dtype
,
str
):
dtype_idx
=
constant
(
all_dtypes
.
index
(
dtype
),
dtype
=
"int64"
)
else
:
dtype_idx
=
constant
(
dtype
,
dtype
=
"int64"
)
dtype
=
all_dtypes
[
dtype_idx
.
data
]
inputs
=
(
rng
,
size
,
dtype_idx
,
*
dist_params
)
out_var
=
TensorType
(
dtype
=
dtype
,
shape
=
static_shape
)()
outputs
=
(
rng
.
type
(),
out_var
)
inputs
=
(
rng
,
size
,
*
dist_params
)
out_type
=
TensorType
(
dtype
=
self
.
dtype
,
shape
=
static_shape
)
outputs
=
(
rng
.
type
(),
out_type
())
return
Apply
(
self
,
inputs
,
outputs
)
...
...
@@ -382,14 +389,12 @@ class RandomVariable(Op):
def
dist_params
(
self
,
node
)
->
Sequence
[
Variable
]:
"""Return the node inpust corresponding to dist params"""
return
node
.
inputs
[
3
:]
return
node
.
inputs
[
2
:]
def
perform
(
self
,
node
,
inputs
,
outputs
):
rng_var_out
,
smpl_out
=
outputs
rng
,
size
,
dtype
,
*
args
=
inputs
out_var
=
node
.
outputs
[
1
]
rng
,
size
,
*
args
=
inputs
# If `size == []`, that means no size is enforced, and NumPy is trusted
# to draw the appropriate number of samples, NumPy uses `size=None` to
...
...
@@ -408,11 +413,8 @@ class RandomVariable(Op):
smpl_val
=
self
.
rng_fn
(
rng
,
*
([
*
args
,
size
]))
if
(
not
isinstance
(
smpl_val
,
np
.
ndarray
)
or
str
(
smpl_val
.
dtype
)
!=
out_var
.
type
.
dtype
):
smpl_val
=
_asarray
(
smpl_val
,
dtype
=
out_var
.
type
.
dtype
)
if
not
isinstance
(
smpl_val
,
np
.
ndarray
)
or
str
(
smpl_val
.
dtype
)
!=
self
.
dtype
:
smpl_val
=
_asarray
(
smpl_val
,
dtype
=
self
.
dtype
)
smpl_out
[
0
]
=
smpl_val
...
...
@@ -463,7 +465,7 @@ default_rng = DefaultGeneratorMakerOp()
@_vectorize_node.register
(
RandomVariable
)
def
vectorize_random_variable
(
op
:
RandomVariable
,
node
:
Apply
,
rng
,
size
,
dtype
,
*
dist_params
op
:
RandomVariable
,
node
:
Apply
,
rng
,
size
,
*
dist_params
)
->
Apply
:
# If size was provided originally and a new size hasn't been provided,
# We extend it to accommodate the new input batch dimensions.
...
...
@@ -491,4 +493,4 @@ def vectorize_random_variable(
new_size_dims
=
broadcasted_batch_shape
[:
new_ndim
]
size
=
concatenate
([
new_size_dims
,
size
])
return
op
.
make_node
(
rng
,
size
,
dtype
,
*
dist_params
)
return
op
.
make_node
(
rng
,
size
,
*
dist_params
)
pytensor/tensor/random/rewriting/basic.py
浏览文件 @
98d73d78
...
...
@@ -81,7 +81,7 @@ def local_rv_size_lift(fgraph, node):
if
not
isinstance
(
node
.
op
,
RandomVariable
):
return
rng
,
size
,
dtype
,
*
dist_params
=
node
.
inputs
rng
,
size
,
*
dist_params
=
node
.
inputs
dist_params
=
broadcast_params
(
dist_params
,
node
.
op
.
ndims_params
)
...
...
@@ -105,7 +105,7 @@ def local_rv_size_lift(fgraph, node):
else
:
return
new_node
=
node
.
op
.
make_node
(
rng
,
None
,
dtype
,
*
dist_params
)
new_node
=
node
.
op
.
make_node
(
rng
,
None
,
*
dist_params
)
if
config
.
compute_test_value
!=
"off"
:
compute_test_value
(
new_node
)
...
...
@@ -141,7 +141,7 @@ def local_dimshuffle_rv_lift(fgraph, node):
return
False
rv_op
=
rv_node
.
op
rng
,
size
,
dtype
,
*
dist_params
=
rv_node
.
inputs
rng
,
size
,
*
dist_params
=
rv_node
.
inputs
rv
=
rv_node
.
default_output
()
# Check that Dimshuffle does not affect support dims
...
...
@@ -185,7 +185,7 @@ def local_dimshuffle_rv_lift(fgraph, node):
)
new_dist_params
.
append
(
param
.
dimshuffle
(
param_new_order
))
new_node
=
rv_op
.
make_node
(
rng
,
new_size
,
dtype
,
*
new_dist_params
)
new_node
=
rv_op
.
make_node
(
rng
,
new_size
,
*
new_dist_params
)
if
config
.
compute_test_value
!=
"off"
:
compute_test_value
(
new_node
)
...
...
@@ -233,7 +233,7 @@ def local_subtensor_rv_lift(fgraph, node):
return
None
rv_op
=
rv_node
.
op
rng
,
size
,
dtype
,
*
dist_params
=
rv_node
.
inputs
rng
,
size
,
*
dist_params
=
rv_node
.
inputs
# Parse indices
idx_list
=
getattr
(
subtensor_op
,
"idx_list"
,
None
)
...
...
@@ -346,7 +346,7 @@ def local_subtensor_rv_lift(fgraph, node):
new_dist_params
.
append
(
batch_param
[
tuple
(
batch_indices
)])
# Create new RV
new_node
=
rv_op
.
make_node
(
rng
,
new_size
,
dtype
,
*
new_dist_params
)
new_node
=
rv_op
.
make_node
(
rng
,
new_size
,
*
new_dist_params
)
new_rv
=
new_node
.
default_output
()
copy_stack_trace
(
rv
,
new_rv
)
...
...
tests/tensor/random/rewriting/test_basic.py
浏览文件 @
98d73d78
...
...
@@ -12,6 +12,7 @@ from pytensor.graph.rewriting.db import RewriteDatabaseQuery
from
pytensor.tensor
import
constant
from
pytensor.tensor.elemwise
import
DimShuffle
from
pytensor.tensor.random.basic
import
(
NormalRV
,
categorical
,
dirichlet
,
multinomial
,
...
...
@@ -397,7 +398,7 @@ def test_DimShuffle_lift(ds_order, lifted, dist_op, dist_params, size, rtol):
)
if
lifted
:
assert
new_out
.
owner
.
op
==
dist_op
assert
isinstance
(
new_out
.
owner
.
op
,
type
(
dist_op
))
assert
all
(
isinstance
(
i
.
owner
.
op
,
DimShuffle
)
for
i
in
new_out
.
owner
.
op
.
dist_params
(
new_out
.
owner
)
...
...
@@ -832,7 +833,7 @@ def test_Subtensor_lift_restrictions():
subtensor_node
=
fg
.
outputs
[
0
]
.
owner
.
inputs
[
1
]
.
owner
.
inputs
[
0
]
.
owner
assert
subtensor_node
==
y
.
owner
assert
isinstance
(
subtensor_node
.
op
,
Subtensor
)
assert
subtensor_node
.
inputs
[
0
]
.
owner
.
op
==
normal
assert
isinstance
(
subtensor_node
.
inputs
[
0
]
.
owner
.
op
,
NormalRV
)
z
=
pt
.
ones
(
x
.
shape
)
-
x
[
1
]
...
...
@@ -850,7 +851,7 @@ def test_Subtensor_lift_restrictions():
EquilibriumGraphRewriter
([
local_subtensor_rv_lift
],
max_use_ratio
=
100
)
.
apply
(
fg
)
rv_node
=
fg
.
outputs
[
0
]
.
owner
.
inputs
[
1
]
.
owner
.
inputs
[
0
]
.
owner
assert
rv_node
.
op
==
normal
assert
isinstance
(
rv_node
.
op
,
NormalRV
)
assert
isinstance
(
rv_node
.
inputs
[
-
1
]
.
owner
.
op
,
Subtensor
)
assert
isinstance
(
rv_node
.
inputs
[
-
2
]
.
owner
.
op
,
Subtensor
)
...
...
@@ -872,7 +873,7 @@ def test_Dimshuffle_lift_restrictions():
dimshuffle_node
=
fg
.
outputs
[
0
]
.
owner
.
inputs
[
1
]
.
owner
assert
dimshuffle_node
==
y
.
owner
assert
isinstance
(
dimshuffle_node
.
op
,
DimShuffle
)
assert
dimshuffle_node
.
inputs
[
0
]
.
owner
.
op
==
normal
assert
isinstance
(
dimshuffle_node
.
inputs
[
0
]
.
owner
.
op
,
NormalRV
)
z
=
pt
.
ones
(
x
.
shape
)
-
y
...
...
@@ -890,7 +891,7 @@ def test_Dimshuffle_lift_restrictions():
EquilibriumGraphRewriter
([
local_dimshuffle_rv_lift
],
max_use_ratio
=
100
)
.
apply
(
fg
)
rv_node
=
fg
.
outputs
[
0
]
.
owner
.
inputs
[
1
]
.
owner
assert
rv_node
.
op
==
normal
assert
isinstance
(
rv_node
.
op
,
NormalRV
)
assert
isinstance
(
rv_node
.
inputs
[
-
1
]
.
owner
.
op
,
DimShuffle
)
assert
isinstance
(
rv_node
.
inputs
[
-
2
]
.
owner
.
op
,
DimShuffle
)
...
...
tests/tensor/random/test_op.py
浏览文件 @
98d73d78
...
...
@@ -3,14 +3,14 @@ import pytest
import
pytensor.tensor
as
pt
from
pytensor
import
config
,
function
from
pytensor.gradient
import
NullTypeGradError
,
grad
from
pytensor.graph.replace
import
vectorize_node
from
pytensor.graph.replace
import
vectorize_graph
from
pytensor.raise_op
import
Assert
from
pytensor.tensor.math
import
eq
from
pytensor.tensor.random
import
normal
from
pytensor.tensor.random.basic
import
NormalRV
from
pytensor.tensor.random.op
import
RandomState
,
RandomVariable
,
default_rng
from
pytensor.tensor.shape
import
specify_shape
from
pytensor.tensor.type
import
all_dtypes
,
iscalar
,
tensor
from
pytensor.tensor.type
import
iscalar
,
tensor
@pytest.fixture
(
scope
=
"function"
,
autouse
=
False
)
...
...
@@ -51,15 +51,6 @@ def test_RandomVariable_basics(strict_test_value_flags):
inplace
=
True
,
)(
0
,
1
,
size
=
{
1
,
2
})
# No dtype
with
pytest
.
raises
(
TypeError
,
match
=
"^dtype*"
):
RandomVariable
(
"normal"
,
0
,
[
0
,
0
],
inplace
=
True
,
)(
0
,
1
)
# Confirm that `inplace` works
rv
=
RandomVariable
(
"normal"
,
...
...
@@ -80,16 +71,19 @@ def test_RandomVariable_basics(strict_test_value_flags):
rv_shape
=
rv
.
_infer_shape
(
pt
.
constant
([]),
(),
[])
assert
rv_shape
.
equals
(
pt
.
constant
([],
dtype
=
"int64"
))
#
Integer-specified `dtype`
dtype_1
=
all_dtypes
[
1
]
rv_node
=
rv
.
make_node
(
None
,
None
,
1
)
rv_out
=
rv_node
.
outputs
[
1
]
rv_out
.
tag
.
test_value
=
1
assert
rv_out
.
dtype
==
dtype_1
#
`dtype` is respected
rv
=
RandomVariable
(
"normal"
,
signature
=
"(),()->()"
,
dtype
=
"int32"
)
with
config
.
change_flags
(
compute_test_value
=
"off"
):
rv_out
=
rv
()
assert
rv_out
.
dtype
==
"int32"
rv_out
=
rv
(
dtype
=
"int64"
)
assert
rv_out
.
dtype
==
"int64"
with
pytest
.
raises
(
NullTypeGradError
):
grad
(
rv_out
,
[
rv_node
.
inputs
[
0
]])
with
pytest
.
raises
(
ValueError
,
match
=
"Cannot change the dtype of a normal RV from int32 to float32"
,
):
assert
rv
(
dtype
=
"float32"
)
.
dtype
==
"float32"
def
test_RandomVariable_bcast
(
strict_test_value_flags
):
...
...
@@ -238,70 +232,70 @@ def test_multivariate_rv_infer_static_shape():
assert
mv_op
(
param1
,
param2
,
size
=
(
10
,
2
))
.
type
.
shape
==
(
10
,
2
,
3
)
def
test_vectorize
_node
():
def
test_vectorize
():
vec
=
tensor
(
shape
=
(
None
,))
mat
=
tensor
(
shape
=
(
None
,
None
))
# Test without size
node
=
normal
(
vec
)
.
owner
new_inputs
=
node
.
inputs
.
copy
()
new_inputs
[
3
]
=
mat
# mu
vect_node
=
vectorize_node
(
node
,
*
new_inputs
)
assert
vect_node
.
op
is
normal
assert
vect_node
.
inputs
[
3
]
is
mat
out
=
normal
(
vec
)
vect_node
=
vectorize_graph
(
out
,
{
vec
:
mat
})
.
owner
assert
isinstance
(
vect_node
.
op
,
NormalRV
)
assert
vect_node
.
op
.
dist_params
(
vect_node
)[
0
]
is
mat
# Test with size, new size provided
node
=
normal
(
vec
,
size
=
(
3
,))
.
owner
new_inputs
=
node
.
inputs
.
copy
()
new_inputs
[
1
]
=
(
2
,
3
)
# size
new_inputs
[
3
]
=
mat
# mu
vect_node
=
vectorize_node
(
node
,
*
new_inputs
)
assert
vect_node
.
op
is
normal
assert
tuple
(
vect_node
.
inputs
[
1
]
.
eval
())
==
(
2
,
3
)
assert
vect_node
.
inputs
[
3
]
is
mat
size
=
pt
.
as_tensor
(
np
.
array
((
3
,),
dtype
=
"int64"
))
out
=
normal
(
vec
,
size
=
size
)
vect_node
=
vectorize_graph
(
out
,
{
vec
:
mat
,
size
:
(
2
,
3
)})
.
owner
assert
isinstance
(
vect_node
.
op
,
NormalRV
)
assert
tuple
(
vect_node
.
op
.
size_param
(
vect_node
)
.
eval
())
==
(
2
,
3
)
assert
vect_node
.
op
.
dist_params
(
vect_node
)[
0
]
is
mat
# Test with size, new size not provided
node
=
normal
(
vec
,
size
=
(
3
,))
.
owner
new_inputs
=
node
.
inputs
.
copy
()
new_inputs
[
3
]
=
mat
# mu
vect_node
=
vectorize_node
(
node
,
*
new_inputs
)
assert
vect_node
.
op
is
normal
assert
vect_node
.
inputs
[
3
]
is
mat
out
=
normal
(
vec
,
size
=
(
3
,))
vect_node
=
vectorize_graph
(
out
,
{
vec
:
mat
})
.
owner
assert
isinstance
(
vect_node
.
op
,
NormalRV
)
assert
vect_node
.
op
.
dist_params
(
vect_node
)[
0
]
is
mat
assert
tuple
(
vect_node
.
inputs
[
1
]
.
eval
({
mat
:
np
.
zeros
((
2
,
3
),
dtype
=
config
.
floatX
)})
vect_node
.
op
.
size_param
(
vect_node
)
.
eval
(
{
mat
:
np
.
zeros
((
2
,
3
),
dtype
=
config
.
floatX
)}
)
)
==
(
2
,
3
)
# Test parameter broadcasting
node
=
normal
(
vec
)
.
owner
new_inputs
=
node
.
inputs
.
copy
()
new_inputs
[
3
]
=
tensor
(
"mu"
,
shape
=
(
10
,
5
))
# mu
new_inputs
[
4
]
=
tensor
(
"sigma"
,
shape
=
(
10
,))
# sigma
vect_node
=
vectorize_node
(
node
,
*
new_inputs
)
assert
vect_node
.
op
is
normal
mu
=
vec
sigma
=
pt
.
as_tensor
(
np
.
array
(
1.0
))
out
=
normal
(
mu
,
sigma
)
new_mu
=
tensor
(
"mu"
,
shape
=
(
10
,
5
))
new_sigma
=
tensor
(
"sigma"
,
shape
=
(
10
,))
vect_node
=
vectorize_graph
(
out
,
{
mu
:
new_mu
,
sigma
:
new_sigma
})
.
owner
assert
isinstance
(
vect_node
.
op
,
NormalRV
)
assert
vect_node
.
default_output
()
.
type
.
shape
==
(
10
,
5
)
# Test parameter broadcasting with non-expanding size
node
=
normal
(
vec
,
size
=
(
5
,))
.
owner
new_inputs
=
node
.
inputs
.
copy
()
new_inputs
[
3
]
=
tensor
(
"mu"
,
shape
=
(
10
,
5
))
# mu
new_inputs
[
4
]
=
tensor
(
"sigma"
,
shape
=
(
10
,))
# sigma
vect_node
=
vectorize_node
(
node
,
*
new_inputs
)
assert
vect_node
.
op
is
normal
mu
=
vec
sigma
=
pt
.
as_tensor
(
np
.
array
(
1.0
))
out
=
normal
(
mu
,
sigma
,
size
=
(
5
,))
new_mu
=
tensor
(
"mu"
,
shape
=
(
10
,
5
))
new_sigma
=
tensor
(
"sigma"
,
shape
=
(
10
,))
vect_node
=
vectorize_graph
(
out
,
{
mu
:
new_mu
,
sigma
:
new_sigma
})
.
owner
assert
isinstance
(
vect_node
.
op
,
NormalRV
)
assert
vect_node
.
default_output
()
.
type
.
shape
==
(
10
,
5
)
node
=
normal
(
vec
,
size
=
(
5
,))
.
owner
new_inputs
=
node
.
inputs
.
copy
()
new_inputs
[
3
]
=
tensor
(
"mu"
,
shape
=
(
1
,
5
))
# mu
new_inputs
[
4
]
=
tensor
(
"sigma"
,
shape
=
(
10
,))
# sigma
vect_node
=
vectorize_node
(
node
,
*
new_inputs
)
assert
vect_node
.
op
is
normal
mu
=
vec
sigma
=
pt
.
as_tensor
(
np
.
array
(
1.0
))
out
=
normal
(
mu
,
sigma
,
size
=
(
5
,))
new_mu
=
tensor
(
"mu"
,
shape
=
(
1
,
5
))
# mu
new_sigma
=
tensor
(
"sigma"
,
shape
=
(
10
,))
# sigma
vect_node
=
vectorize_graph
(
out
,
{
mu
:
new_mu
,
sigma
:
new_sigma
})
.
owner
assert
isinstance
(
vect_node
.
op
,
NormalRV
)
assert
vect_node
.
default_output
()
.
type
.
shape
==
(
10
,
5
)
# Test parameter broadcasting with expanding size
node
=
normal
(
vec
,
size
=
(
2
,
5
))
.
owner
new_inputs
=
node
.
inputs
.
copy
()
new_inputs
[
3
]
=
tensor
(
"mu"
,
shape
=
(
10
,
5
))
# mu
new_inputs
[
4
]
=
tensor
(
"sigma"
,
shape
=
(
10
,))
# sigma
vect_node
=
vectorize_node
(
node
,
*
new_inputs
)
assert
vect_node
.
op
is
normal
mu
=
vec
sigma
=
pt
.
as_tensor
(
np
.
array
(
1.0
))
out
=
normal
(
mu
,
sigma
,
size
=
(
2
,
5
))
new_mu
=
tensor
(
"mu"
,
shape
=
(
1
,
5
))
new_sigma
=
tensor
(
"sigma"
,
shape
=
(
10
,))
vect_node
=
vectorize_graph
(
out
,
{
mu
:
new_mu
,
sigma
:
new_sigma
})
.
owner
assert
isinstance
(
vect_node
.
op
,
NormalRV
)
assert
vect_node
.
default_output
()
.
type
.
shape
==
(
10
,
2
,
5
)
编写
预览
Markdown
格式
0%
重试
或
添加新文件
添加附件
取消
您添加了
0
人
到此讨论。请谨慎行事。
请先完成此评论的编辑!
取消
请
注册
或者
登录
后发表评论