提交 0a10de29 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Simplify random infer_shape/ShapeFeature tests

上级 2ec735a6
...@@ -230,32 +230,17 @@ def test_beta_samples(a, b, size): ...@@ -230,32 +230,17 @@ def test_beta_samples(a, b, size):
def test_normal_infer_shape(make_args): def test_normal_infer_shape(make_args):
M_pt = iscalar("M") M_pt = iscalar("M")
sd_pt = scalar("sd") sd_pt = scalar("sd")
test_values = {M_pt: 3, sd_pt: np.array(1.0, dtype=config.floatX)}
M, sd, size = make_args(M_pt, sd_pt) M, sd, size = make_args(M_pt, sd_pt)
rv = normal(M, sd, size=size) rv = normal(M, sd, size=size)
size_pt = rv.owner.op.size_param(rv.owner)
rv_shape = list(normal._infer_shape(size_pt, [M, sd], None))
all_args = (M, sd, *(() if size is None else size)) size_from_node = rv.owner.op.size_param(rv.owner)
fn_inputs = [ params_from_node = rv.owner.op.dist_params(rv.owner)
i rv_shape = pt.as_tensor(normal._infer_shape(size_from_node, params_from_node, None))
for i in graph_inputs([a for a in all_args if isinstance(a, Variable)])
if not isinstance(i, Constant | SharedVariable)
]
pytensor_fn = function(
fn_inputs, [pt.as_tensor(o) for o in [*rv_shape, rv]], mode=py_mode
)
*rv_shape_val, rv_val = pytensor_fn( pytensor_fn = function(
*[ [M_pt, sd_pt], [rv_shape, rv], mode=py_mode, on_unused_input="ignore"
test_values[i]
for i in fn_inputs
if not isinstance(i, SharedVariable | Constant)
]
) )
rv_shape_val, rv_val = pytensor_fn(M=3, sd=np.array(1.0, dtype=config.floatX))
assert tuple(rv_shape_val) == tuple(rv_val.shape) assert tuple(rv_shape_val) == tuple(rv_val.shape)
...@@ -275,31 +260,16 @@ def test_normal_infer_shape(make_args): ...@@ -275,31 +260,16 @@ def test_normal_infer_shape(make_args):
], ],
) )
def test_normal_infer_shape_params(M_val, sd_val, size): def test_normal_infer_shape_params(M_val, sd_val, size):
M = pt.as_tensor_variable(M_val).type() M = pt.as_tensor_variable(M_val).type("M")
sd = pt.as_tensor_variable(sd_val).type() sd = pt.as_tensor_variable(sd_val).type("sd")
rv = normal(M, sd, size=size) rv = normal(M, sd, size=size)
size_pt = rv.owner.op.size_param(rv.owner)
rv_shape = list(normal._infer_shape(size_pt, [M, sd], None))
all_args = (M, sd, *(() if size is None else size))
fn_inputs = [
i
for i in graph_inputs([a for a in all_args if isinstance(a, Variable)])
if not isinstance(i, Constant | SharedVariable)
]
pytensor_fn = function(
fn_inputs, [pt.as_tensor(o) for o in [*rv_shape, rv]], mode=py_mode
)
*rv_shape_val, rv_val = pytensor_fn( size_from_node = rv.owner.op.size_param(rv.owner)
*[ params_from_node = rv.owner.op.dist_params(rv.owner)
{M: M_val, sd: sd_val}[i] rv_shape = pt.as_tensor(normal._infer_shape(size_from_node, params_from_node, None))
for i in fn_inputs
if not isinstance(i, SharedVariable | Constant)
]
)
pytensor_fn = function([M, sd], [rv_shape, rv], mode=py_mode)
rv_shape_val, rv_val = pytensor_fn(M=M_val, sd=sd_val)
assert tuple(rv_shape_val) == tuple(rv_val.shape) assert tuple(rv_shape_val) == tuple(rv_val.shape)
...@@ -310,8 +280,7 @@ def test_normal_ShapeFeature(): ...@@ -310,8 +280,7 @@ def test_normal_ShapeFeature():
d_rv = normal(pt.ones((M_pt,)), sd_pt, size=(2, M_pt)) d_rv = normal(pt.ones((M_pt,)), sd_pt, size=(2, M_pt))
fg = FunctionGraph( fg = FunctionGraph(
[i for i in graph_inputs([d_rv]) if not isinstance(i, Constant)], outputs=[d_rv],
[d_rv],
clone=False, clone=False,
features=[ShapeFeature()], features=[ShapeFeature()],
) )
...@@ -683,8 +652,7 @@ def test_mvnormal_ShapeFeature(): ...@@ -683,8 +652,7 @@ def test_mvnormal_ShapeFeature():
d_rv = multivariate_normal(pt.ones((M_pt,)), pt.eye(M_pt), size=2) d_rv = multivariate_normal(pt.ones((M_pt,)), pt.eye(M_pt), size=2)
fg = FunctionGraph( fg = FunctionGraph(
[i for i in graph_inputs([d_rv]) if not isinstance(i, Constant)], outputs=[d_rv],
[d_rv],
clone=False, clone=False,
features=[ShapeFeature()], features=[ShapeFeature()],
) )
...@@ -803,7 +771,7 @@ def test_dirichlet_rng(): ...@@ -803,7 +771,7 @@ def test_dirichlet_rng():
@pytest.mark.parametrize( @pytest.mark.parametrize(
"make_args", "make_alpha_size",
[ [
lambda M_pt: (pt.ones((M_pt,)), None), lambda M_pt: (pt.ones((M_pt,)), None),
lambda M_pt: (pt.ones((M_pt,)), (M_pt + 1,)), lambda M_pt: (pt.ones((M_pt,)), (M_pt + 1,)),
...@@ -813,34 +781,19 @@ def test_dirichlet_rng(): ...@@ -813,34 +781,19 @@ def test_dirichlet_rng():
lambda M_pt: (pt.ones((M_pt, M_pt + 1)), (2, M_pt + 2, M_pt + 3, M_pt)), lambda M_pt: (pt.ones((M_pt, M_pt + 1)), (2, M_pt + 2, M_pt + 3, M_pt)),
], ],
) )
def test_dirichlet_infer_shape(make_args): def test_dirichlet_infer_shape(make_alpha_size):
M_pt = iscalar("M") M = iscalar("M")
test_values = {M_pt: 3} alpha, size = make_alpha_size(M)
rv = dirichlet(alpha, size=size)
M, size = make_args(M_pt)
size_from_node = rv.owner.op.size_param(rv.owner)
rv = dirichlet(M, size=size) params_from_node = rv.owner.op.dist_params(rv.owner)
size_pt = rv.owner.op.size_param(rv.owner) rv_shape = pt.as_tensor(
rv_shape = list(dirichlet._infer_shape(size_pt, [M], None)) dirichlet._infer_shape(size_from_node, params_from_node, None)
all_args = (M, *(() if size is None else size))
fn_inputs = [
i
for i in graph_inputs([a for a in all_args if isinstance(a, Variable)])
if not isinstance(i, Constant | SharedVariable)
]
pytensor_fn = function(
fn_inputs, [pt.as_tensor(o) for o in [*rv_shape, rv]], mode=py_mode
)
*rv_shape_val, rv_val = pytensor_fn(
*[
test_values[i]
for i in fn_inputs
if not isinstance(i, SharedVariable | Constant)
]
) )
pytensor_fn = function([M], [rv_shape, rv], mode=py_mode)
rv_shape_val, rv_val = pytensor_fn(M=3)
assert tuple(rv_shape_val) == tuple(rv_val.shape) assert tuple(rv_shape_val) == tuple(rv_val.shape)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论