提交 0a10de29 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Simplify random infer_shape/ShapeFeature tests

上级 2ec735a6
......@@ -230,32 +230,17 @@ def test_beta_samples(a, b, size):
def test_normal_infer_shape(make_args):
M_pt = iscalar("M")
sd_pt = scalar("sd")
test_values = {M_pt: 3, sd_pt: np.array(1.0, dtype=config.floatX)}
M, sd, size = make_args(M_pt, sd_pt)
rv = normal(M, sd, size=size)
size_pt = rv.owner.op.size_param(rv.owner)
rv_shape = list(normal._infer_shape(size_pt, [M, sd], None))
all_args = (M, sd, *(() if size is None else size))
fn_inputs = [
i
for i in graph_inputs([a for a in all_args if isinstance(a, Variable)])
if not isinstance(i, Constant | SharedVariable)
]
pytensor_fn = function(
fn_inputs, [pt.as_tensor(o) for o in [*rv_shape, rv]], mode=py_mode
)
size_from_node = rv.owner.op.size_param(rv.owner)
params_from_node = rv.owner.op.dist_params(rv.owner)
rv_shape = pt.as_tensor(normal._infer_shape(size_from_node, params_from_node, None))
*rv_shape_val, rv_val = pytensor_fn(
*[
test_values[i]
for i in fn_inputs
if not isinstance(i, SharedVariable | Constant)
]
pytensor_fn = function(
[M_pt, sd_pt], [rv_shape, rv], mode=py_mode, on_unused_input="ignore"
)
rv_shape_val, rv_val = pytensor_fn(M=3, sd=np.array(1.0, dtype=config.floatX))
assert tuple(rv_shape_val) == tuple(rv_val.shape)
......@@ -275,31 +260,16 @@ def test_normal_infer_shape(make_args):
],
)
def test_normal_infer_shape_params(M_val, sd_val, size):
M = pt.as_tensor_variable(M_val).type()
sd = pt.as_tensor_variable(sd_val).type()
M = pt.as_tensor_variable(M_val).type("M")
sd = pt.as_tensor_variable(sd_val).type("sd")
rv = normal(M, sd, size=size)
size_pt = rv.owner.op.size_param(rv.owner)
rv_shape = list(normal._infer_shape(size_pt, [M, sd], None))
all_args = (M, sd, *(() if size is None else size))
fn_inputs = [
i
for i in graph_inputs([a for a in all_args if isinstance(a, Variable)])
if not isinstance(i, Constant | SharedVariable)
]
pytensor_fn = function(
fn_inputs, [pt.as_tensor(o) for o in [*rv_shape, rv]], mode=py_mode
)
*rv_shape_val, rv_val = pytensor_fn(
*[
{M: M_val, sd: sd_val}[i]
for i in fn_inputs
if not isinstance(i, SharedVariable | Constant)
]
)
size_from_node = rv.owner.op.size_param(rv.owner)
params_from_node = rv.owner.op.dist_params(rv.owner)
rv_shape = pt.as_tensor(normal._infer_shape(size_from_node, params_from_node, None))
pytensor_fn = function([M, sd], [rv_shape, rv], mode=py_mode)
rv_shape_val, rv_val = pytensor_fn(M=M_val, sd=sd_val)
assert tuple(rv_shape_val) == tuple(rv_val.shape)
......@@ -310,8 +280,7 @@ def test_normal_ShapeFeature():
d_rv = normal(pt.ones((M_pt,)), sd_pt, size=(2, M_pt))
fg = FunctionGraph(
[i for i in graph_inputs([d_rv]) if not isinstance(i, Constant)],
[d_rv],
outputs=[d_rv],
clone=False,
features=[ShapeFeature()],
)
......@@ -683,8 +652,7 @@ def test_mvnormal_ShapeFeature():
d_rv = multivariate_normal(pt.ones((M_pt,)), pt.eye(M_pt), size=2)
fg = FunctionGraph(
[i for i in graph_inputs([d_rv]) if not isinstance(i, Constant)],
[d_rv],
outputs=[d_rv],
clone=False,
features=[ShapeFeature()],
)
......@@ -803,7 +771,7 @@ def test_dirichlet_rng():
@pytest.mark.parametrize(
"make_args",
"make_alpha_size",
[
lambda M_pt: (pt.ones((M_pt,)), None),
lambda M_pt: (pt.ones((M_pt,)), (M_pt + 1,)),
......@@ -813,34 +781,19 @@ def test_dirichlet_rng():
lambda M_pt: (pt.ones((M_pt, M_pt + 1)), (2, M_pt + 2, M_pt + 3, M_pt)),
],
)
def test_dirichlet_infer_shape(make_args):
M_pt = iscalar("M")
test_values = {M_pt: 3}
M, size = make_args(M_pt)
rv = dirichlet(M, size=size)
size_pt = rv.owner.op.size_param(rv.owner)
rv_shape = list(dirichlet._infer_shape(size_pt, [M], None))
all_args = (M, *(() if size is None else size))
fn_inputs = [
i
for i in graph_inputs([a for a in all_args if isinstance(a, Variable)])
if not isinstance(i, Constant | SharedVariable)
]
pytensor_fn = function(
fn_inputs, [pt.as_tensor(o) for o in [*rv_shape, rv]], mode=py_mode
)
*rv_shape_val, rv_val = pytensor_fn(
*[
test_values[i]
for i in fn_inputs
if not isinstance(i, SharedVariable | Constant)
]
def test_dirichlet_infer_shape(make_alpha_size):
M = iscalar("M")
alpha, size = make_alpha_size(M)
rv = dirichlet(alpha, size=size)
size_from_node = rv.owner.op.size_param(rv.owner)
params_from_node = rv.owner.op.dist_params(rv.owner)
rv_shape = pt.as_tensor(
dirichlet._infer_shape(size_from_node, params_from_node, None)
)
pytensor_fn = function([M], [rv_shape, rv], mode=py_mode)
rv_shape_val, rv_val = pytensor_fn(M=3)
assert tuple(rv_shape_val) == tuple(rv_val.shape)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论