提交 05093821 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Fix RandomVariable rewrite failures arising from "output" strings

上级 f627b639
...@@ -8,7 +8,7 @@ from aesara.tensor.extra_ops import broadcast_to ...@@ -8,7 +8,7 @@ from aesara.tensor.extra_ops import broadcast_to
from aesara.tensor.math import sum as aet_sum from aesara.tensor.math import sum as aet_sum
from aesara.tensor.random.op import RandomVariable from aesara.tensor.random.op import RandomVariable
from aesara.tensor.random.utils import broadcast_params from aesara.tensor.random.utils import broadcast_params
from aesara.tensor.shape import Shape from aesara.tensor.shape import Shape, Shape_i
from aesara.tensor.subtensor import ( from aesara.tensor.subtensor import (
AdvancedSubtensor, AdvancedSubtensor,
AdvancedSubtensor1, AdvancedSubtensor1,
...@@ -19,6 +19,26 @@ from aesara.tensor.subtensor import ( ...@@ -19,6 +19,26 @@ from aesara.tensor.subtensor import (
) )
def is_rv_used_in_graph(base_rv, node, fgraph):
"""Determine whether or not `base_rv` is used by a node other than `node` in `fgraph`.
If a node uses `Shape` or `Shape_i` on the `base_rv`, we ignore it, because
those `Op`s don't rely on the actual sample values of `base_rv`.
TODO: We should apply all the shape rewrites before these rewrites, since
that would properly remove the unnecessary dependencies on `base_rv` (when
possible).
"""
def _node_check(n, i):
if n == "output":
n = fgraph.outputs[i].owner
return n == node or isinstance(n.op, (Shape, Shape_i))
return not all(_node_check(n, i) for n, i in fgraph.clients.get(base_rv, ()))
@local_optimizer([RandomVariable], inplace=True) @local_optimizer([RandomVariable], inplace=True)
def random_make_inplace(fgraph, node): def random_make_inplace(fgraph, node):
op = node.op op = node.op
...@@ -118,10 +138,7 @@ def local_dimshuffle_rv_lift(fgraph, node): ...@@ -118,10 +138,7 @@ def local_dimshuffle_rv_lift(fgraph, node):
# If no one else is using the underlying `RandomVariable`, then we can # If no one else is using the underlying `RandomVariable`, then we can
# do this; otherwise, the graph would be internally inconsistent. # do this; otherwise, the graph would be internally inconsistent.
if not all( if is_rv_used_in_graph(base_rv, node, fgraph):
(n == node or isinstance(n.op, Shape))
for n, i in fgraph.clients.get(base_rv, ())
):
return False return False
rv_op = rv_node.op rv_op = rv_node.op
...@@ -273,10 +290,7 @@ def local_subtensor_rv_lift(fgraph, node): ...@@ -273,10 +290,7 @@ def local_subtensor_rv_lift(fgraph, node):
# If no one else is using the underlying `RandomVariable`, then we can # If no one else is using the underlying `RandomVariable`, then we can
# do this; otherwise, the graph would be internally inconsistent. # do this; otherwise, the graph would be internally inconsistent.
if not all( if is_rv_used_in_graph(base_rv, node, fgraph):
(n == node or isinstance(n.op, Shape))
for n, i in fgraph.clients.get(base_rv, ())
):
return False return False
rv_op = rv_node.op rv_op = rv_node.op
......
...@@ -463,10 +463,18 @@ def test_Subtensor_lift_restrictions(): ...@@ -463,10 +463,18 @@ def test_Subtensor_lift_restrictions():
assert isinstance(subtensor_node.op, Subtensor) assert isinstance(subtensor_node.op, Subtensor)
assert subtensor_node.inputs[0].owner.op == normal assert subtensor_node.inputs[0].owner.op == normal
# The non-`Subtensor` client doesn't depend on the RNG state, so we can
# perform the lift
z = aet.ones(x.shape) - x[1] z = aet.ones(x.shape) - x[1]
# We add `x` as an output to make sure that `is_rv_used_in_graph` handles
# `"output"` "nodes" correctly.
fg = FunctionGraph([rng], [z, x], clone=False)
EquilibriumOptimizer([local_subtensor_rv_lift], max_use_ratio=100).apply(fg)
assert fg.outputs[0] == z
assert fg.outputs[1] == x
# The non-`Subtensor` client doesn't depend on the RNG state, so we can
# perform the lift
fg = FunctionGraph([rng], [z], clone=False) fg = FunctionGraph([rng], [z], clone=False)
EquilibriumOptimizer([local_subtensor_rv_lift], max_use_ratio=100).apply(fg) EquilibriumOptimizer([local_subtensor_rv_lift], max_use_ratio=100).apply(fg)
...@@ -485,7 +493,7 @@ def test_Dimshuffle_lift_restrictions(): ...@@ -485,7 +493,7 @@ def test_Dimshuffle_lift_restrictions():
# perform the lift # perform the lift
z = x - y z = x - y
fg = FunctionGraph([rng], [z], clone=False) fg = FunctionGraph([rng], [z, y], clone=False)
_ = EquilibriumOptimizer([local_dimshuffle_rv_lift], max_use_ratio=100).apply(fg) _ = EquilibriumOptimizer([local_dimshuffle_rv_lift], max_use_ratio=100).apply(fg)
dimshuffle_node = fg.outputs[0].owner.inputs[1].owner dimshuffle_node = fg.outputs[0].owner.inputs[1].owner
...@@ -493,10 +501,18 @@ def test_Dimshuffle_lift_restrictions(): ...@@ -493,10 +501,18 @@ def test_Dimshuffle_lift_restrictions():
assert isinstance(dimshuffle_node.op, DimShuffle) assert isinstance(dimshuffle_node.op, DimShuffle)
assert dimshuffle_node.inputs[0].owner.op == normal assert dimshuffle_node.inputs[0].owner.op == normal
# The non-`Dimshuffle` client doesn't depend on the RNG state, so we can
# perform the lift
z = aet.ones(x.shape) - y z = aet.ones(x.shape) - y
# We add `x` as an output to make sure that `is_rv_used_in_graph` handles
# `"output"` "nodes" correctly.
fg = FunctionGraph([rng], [z, x], clone=False)
EquilibriumOptimizer([local_dimshuffle_rv_lift], max_use_ratio=100).apply(fg)
assert fg.outputs[0] == z
assert fg.outputs[1] == x
# The non-`Dimshuffle` client doesn't depend on the RNG state, so we can
# perform the lift
fg = FunctionGraph([rng], [z], clone=False) fg = FunctionGraph([rng], [z], clone=False)
EquilibriumOptimizer([local_dimshuffle_rv_lift], max_use_ratio=100).apply(fg) EquilibriumOptimizer([local_dimshuffle_rv_lift], max_use_ratio=100).apply(fg)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论