提交 05093821 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Fix RandomVariable rewrite failures arising from "output" strings

上级 f627b639
......@@ -8,7 +8,7 @@ from aesara.tensor.extra_ops import broadcast_to
from aesara.tensor.math import sum as aet_sum
from aesara.tensor.random.op import RandomVariable
from aesara.tensor.random.utils import broadcast_params
from aesara.tensor.shape import Shape
from aesara.tensor.shape import Shape, Shape_i
from aesara.tensor.subtensor import (
AdvancedSubtensor,
AdvancedSubtensor1,
......@@ -19,6 +19,26 @@ from aesara.tensor.subtensor import (
)
def is_rv_used_in_graph(base_rv, node, fgraph):
    """Determine whether or not `base_rv` is used by a node other than `node` in `fgraph`.

    If a node uses `Shape` or `Shape_i` on the `base_rv`, we ignore it, because
    those `Op`s don't rely on the actual sample values of `base_rv`.

    TODO: We should apply all the shape rewrites before these rewrites, since
    that would properly remove the unnecessary dependencies on `base_rv` (when
    possible).
    """

    def _allowed_client(client, idx):
        # `FunctionGraph` represents graph outputs with the placeholder string
        # "output"; map that back to the node that produces the output.
        if client == "output":
            client = fgraph.outputs[idx].owner
        if client == node:
            return True
        # Shape-only consumers don't depend on the sampled values.
        return isinstance(client.op, (Shape, Shape_i))

    clients = fgraph.clients.get(base_rv, ())
    return any(not _allowed_client(client, idx) for client, idx in clients)
@local_optimizer([RandomVariable], inplace=True)
def random_make_inplace(fgraph, node):
op = node.op
......@@ -118,10 +138,7 @@ def local_dimshuffle_rv_lift(fgraph, node):
# If no one else is using the underlying `RandomVariable`, then we can
# do this; otherwise, the graph would be internally inconsistent.
if not all(
(n == node or isinstance(n.op, Shape))
for n, i in fgraph.clients.get(base_rv, ())
):
if is_rv_used_in_graph(base_rv, node, fgraph):
return False
rv_op = rv_node.op
......@@ -273,10 +290,7 @@ def local_subtensor_rv_lift(fgraph, node):
# If no one else is using the underlying `RandomVariable`, then we can
# do this; otherwise, the graph would be internally inconsistent.
if not all(
(n == node or isinstance(n.op, Shape))
for n, i in fgraph.clients.get(base_rv, ())
):
if is_rv_used_in_graph(base_rv, node, fgraph):
return False
rv_op = rv_node.op
......
......@@ -463,10 +463,18 @@ def test_Subtensor_lift_restrictions():
assert isinstance(subtensor_node.op, Subtensor)
assert subtensor_node.inputs[0].owner.op == normal
# The non-`Subtensor` client doesn't depend on the RNG state, so we can
# perform the lift
z = aet.ones(x.shape) - x[1]
# We add `x` as an output to make sure that `is_rv_used_in_graph` handles
# `"output"` "nodes" correctly.
fg = FunctionGraph([rng], [z, x], clone=False)
EquilibriumOptimizer([local_subtensor_rv_lift], max_use_ratio=100).apply(fg)
assert fg.outputs[0] == z
assert fg.outputs[1] == x
# The non-`Subtensor` client doesn't depend on the RNG state, so we can
# perform the lift
fg = FunctionGraph([rng], [z], clone=False)
EquilibriumOptimizer([local_subtensor_rv_lift], max_use_ratio=100).apply(fg)
......@@ -485,7 +493,7 @@ def test_Dimshuffle_lift_restrictions():
# perform the lift
z = x - y
fg = FunctionGraph([rng], [z], clone=False)
fg = FunctionGraph([rng], [z, y], clone=False)
_ = EquilibriumOptimizer([local_dimshuffle_rv_lift], max_use_ratio=100).apply(fg)
dimshuffle_node = fg.outputs[0].owner.inputs[1].owner
......@@ -493,10 +501,18 @@ def test_Dimshuffle_lift_restrictions():
assert isinstance(dimshuffle_node.op, DimShuffle)
assert dimshuffle_node.inputs[0].owner.op == normal
# The non-`Dimshuffle` client doesn't depend on the RNG state, so we can
# perform the lift
z = aet.ones(x.shape) - y
# We add `x` as an output to make sure that `is_rv_used_in_graph` handles
# `"output"` "nodes" correctly.
fg = FunctionGraph([rng], [z, x], clone=False)
EquilibriumOptimizer([local_dimshuffle_rv_lift], max_use_ratio=100).apply(fg)
assert fg.outputs[0] == z
assert fg.outputs[1] == x
# The non-`Dimshuffle` client doesn't depend on the RNG state, so we can
# perform the lift
fg = FunctionGraph([rng], [z], clone=False)
EquilibriumOptimizer([local_dimshuffle_rv_lift], max_use_ratio=100).apply(fg)
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论