提交 2ca41998 authored 作者: Ricardo's avatar Ricardo 提交者: Brandon T. Willard

Propagate RV name during Dimshuffle_lift

上级 df7fcb21
......@@ -217,7 +217,10 @@ def local_dimshuffle_rv_lift(fgraph, node):
if config.compute_test_value != "off":
compute_test_value(new_node)
return [new_node.outputs[1]]
out = new_node.outputs[1]
if base_rv.name:
out.name = f"{base_rv.name}_lifted"
return [out]
ds_in_reps_space = ds_reps_new_dims and all(
d < reps_ind_split_idx for n, d in ds_reps_new_dims
......@@ -235,7 +238,10 @@ def local_dimshuffle_rv_lift(fgraph, node):
if config.compute_test_value != "off":
compute_test_value(new_node)
return [new_node.outputs[1]]
out = new_node.outputs[1]
if base_rv.name:
out.name = f"{base_rv.name}_lifted"
return [out]
return False
......
......@@ -30,7 +30,7 @@ from aesara.tensor.type import iscalar, vector
no_mode = Mode("py", OptimizationQuery(include=[], exclude=[]))
def apply_local_opt_to_rv(opt, op_fn, dist_op, dist_params, size, rng):
def apply_local_opt_to_rv(opt, op_fn, dist_op, dist_params, size, rng, name=None):
dist_params_aet = []
for p in dist_params:
p_aet = aet.as_tensor(p).type()
......@@ -43,7 +43,7 @@ def apply_local_opt_to_rv(opt, op_fn, dist_op, dist_params, size, rng):
s_aet.tag.test_value = s
size_aet.append(s_aet)
dist_st = op_fn(dist_op(*dist_params_aet, size=size_aet, rng=rng))
dist_st = op_fn(dist_op(*dist_params_aet, size=size_aet, rng=rng, name=name))
f_inputs = [
p for p in dist_params_aet + size_aet if not isinstance(p, (slice, Constant))
......@@ -521,3 +521,44 @@ def test_Dimshuffle_lift_restrictions():
assert rv_node.op == normal
assert isinstance(rv_node.inputs[-1].owner.op, DimShuffle)
assert isinstance(rv_node.inputs[-2].owner.op, DimShuffle)
@pytest.mark.parametrize(
"ds_order, lifted, dist_op, dist_params, size, rtol",
[
(
("x",),
True,
normal,
(
np.array(-10.0, dtype=np.float64),
np.array(1e-6, dtype=np.float64),
),
(),
1e-7,
),
(
(0, 1, 2),
True,
normal,
(np.array(0).astype(config.floatX), np.array(1e-6).astype(config.floatX)),
(2, 1, 2),
1e-3,
),
],
)
def test_Dimshuffle_lift_rename(ds_order, lifted, dist_op, dist_params, size, rtol):
rng = shared(np.random.default_rng(1233532), borrow=False)
new_out, *_ = apply_local_opt_to_rv(
local_dimshuffle_rv_lift,
lambda rv: rv.dimshuffle(ds_order),
dist_op,
dist_params,
size,
rng,
name="test_name",
)
assert new_out.name == "test_name_lifted"
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论