提交 448e5582 作者: Ricardo Vieira 提交者: Ricardo Vieira

Replace RNG update output in RV lift rewrites

Otherwise, we end up with multiple RVs in the graph when the RNG update is an output of the function or is used elsewhere in the function.
上级 2d81ccae
......@@ -130,6 +130,7 @@ def local_dimshuffle_rv_lift(fgraph, node):
if ds_op.drop:
return False
[ds_rv] = node.outputs
rv_node = node.inputs[0].owner
if not (rv_node and isinstance(rv_node.op, RandomVariable)):
......@@ -182,10 +183,17 @@ def local_dimshuffle_rv_lift(fgraph, node):
if config.compute_test_value != "off":
compute_test_value(new_node)
new_rv = new_node.default_output()
new_next_rng, new_rv = new_node.outputs
if rv.name:
new_rv.name = f"{rv.name}_lifted"
return [new_rv]
# We replace uses of the dimshuffled RV by the new RV
# And uses of the old RNG update by the new RNG update
return {
ds_rv: new_rv,
next_rng: new_next_rng,
}
@node_rewriter([Subtensor, AdvancedSubtensor1, AdvancedSubtensor])
......@@ -217,7 +225,7 @@ def local_subtensor_rv_lift(fgraph, node):
rv_op = rv_node.op
rng, size, *dist_params = rv_node.inputs
rv = rv_node.default_output()
next_rng, rv = rv_node.outputs
# If no one else is using the underlying `RandomVariable`, then we can
# do this; otherwise, the graph would be internally inconsistent.
......@@ -331,8 +339,13 @@ def local_subtensor_rv_lift(fgraph, node):
# Create new RV
new_node = rv_op.make_node(rng, new_size, *new_dist_params)
new_rv = new_node.default_output()
new_next_rng, new_rv = new_node.outputs
copy_stack_trace(rv, new_rv)
return [new_rv]
# We replace uses of the indexed RV by the new RV
# And uses of the old RNG update by the new RNG update
return {
indexed_rv: new_rv,
next_rng: new_next_rng,
}
from collections.abc import Sequence
import numpy as np
import pytest
......@@ -5,7 +7,7 @@ import pytensor.tensor as pt
from pytensor import config, shared
from pytensor.compile.function import function
from pytensor.compile.mode import Mode
from pytensor.graph.basic import Constant
from pytensor.graph.basic import Constant, Variable, ancestors
from pytensor.graph.fg import FunctionGraph
from pytensor.graph.rewriting.basic import EquilibriumGraphRewriter
from pytensor.graph.rewriting.db import RewriteDatabaseQuery
......@@ -36,6 +38,16 @@ from pytensor.tensor.type_other import NoneConst
no_mode = Mode("py", RewriteDatabaseQuery(include=[], exclude=[]))
def count_rv_nodes_in_graph(outputs: Sequence[Variable]) -> int:
    """Return the number of distinct `RandomVariable` nodes feeding `outputs`.

    Every variable yielded by ``ancestors(outputs)`` is inspected; the owner
    of each variable whose owner's op is a `RandomVariable` is collected into
    a set, so a node that owns several of the visited variables is counted
    only once.
    """
    rv_nodes = set()
    for var in ancestors(outputs):
        node = var.owner
        # Variables without an owner are skipped.
        if node and isinstance(node.op, RandomVariable):
            rv_nodes.add(node)
    return len(rv_nodes)
def apply_local_rewrite_to_rv(
rewrite, op_fn, dist_op, dist_params, size, rng, name=None
):
......@@ -58,7 +70,14 @@ def apply_local_rewrite_to_rv(
s_pt.tag.test_value = s
size_pt.append(s_pt)
dist_st = op_fn(dist_op(*dist_params_pt, size=size_pt, rng=rng, name=name))
next_rng, rv = dist_op(
*dist_params_pt, size=size_pt, rng=rng, name=name
).owner.outputs
dist_st = op_fn(rv)
assert (
count_rv_nodes_in_graph([dist_st, next_rng]) == 1
), "Function expects a single RV in the graph"
f_inputs = [
p
......@@ -72,13 +91,16 @@ def apply_local_rewrite_to_rv(
f_rewritten = function(
f_inputs,
dist_st,
[dist_st, next_rng],
mode=mode,
)
(new_out,) = f_rewritten.maker.fgraph.outputs
new_rv, new_next_rng = f_rewritten.maker.fgraph.outputs
assert (
count_rv_nodes_in_graph([new_rv, new_next_rng]) == 1
), "Rewritten should have a single RV in the graph"
return new_out, f_inputs, dist_st, f_rewritten
return new_rv, f_inputs, dist_st, f_rewritten
class TestRVExpraProps(RandomVariable):
......@@ -422,7 +444,7 @@ def test_DimShuffle_lift(ds_order, lifted, dist_op, dist_params, size, rtol):
arg_values = [p.get_test_value() for p in f_inputs]
res_base = f_base(*arg_values)
res_rewritten = f_rewritten(*arg_values)
res_rewritten, _ = f_rewritten(*arg_values)
np.testing.assert_allclose(res_base, res_rewritten, rtol=rtol)
......@@ -825,7 +847,7 @@ def test_Subtensor_lift(indices, lifted, dist_op, dist_params, size):
arg_values = [p.get_test_value() for p in f_inputs]
res_base = f_base(*arg_values)
res_rewritten = f_rewritten(*arg_values)
res_rewritten, _ = f_rewritten(*arg_values)
np.testing.assert_allclose(res_base, res_rewritten, rtol=1e-3, atol=1e-2)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论