提交 d9fd640e authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Use constant folding to determine RandomVariable broadcast dimensions

上级 bbd4e643
......@@ -4,20 +4,14 @@ from copy import copy
import numpy as np
import aesara
from aesara.assert_op import Assert
from aesara.configdefaults import config
from aesara.graph.basic import Apply, Variable
from aesara.graph.fg import FunctionGraph
from aesara.graph.op import Op
from aesara.graph.opt_utils import optimize_graph
from aesara.misc.safe_asarray import _asarray
from aesara.scalar.basic import Cast
from aesara.tensor.basic import (
as_tensor_variable,
constant,
get_scalar_constant_value,
get_vector_length,
)
from aesara.tensor.elemwise import Elemwise
from aesara.tensor.exceptions import NotScalarConstantError
from aesara.tensor.basic import as_tensor_variable, constant, get_vector_length
from aesara.tensor.basic_opt import ShapeFeature, topo_constant_folding
from aesara.tensor.random.type import RandomType
from aesara.tensor.random.utils import normalize_size_param, params_broadcast_shapes
from aesara.tensor.shape import shape_tuple
......@@ -287,30 +281,16 @@ class RandomVariable(Op):
"""
shape = self._infer_shape(size, dist_params)
# Ignore `Cast`s, since they do not affect broadcastables
if getattr(shape, "owner", None) and (
isinstance(shape.owner.op, Elemwise)
and isinstance(shape.owner.op.scalar_op, Cast)
):
shape = shape.owner.inputs[0]
# Let's try to do a better job than `_infer_ndim_bcast` when
# dimension sizes are symbolic.
bcast = []
for s in shape:
s_owner = getattr(s, "owner", None)
# Get rid of the `Assert`s added by `broadcast_shape`
if s_owner and isinstance(s_owner.op, Assert):
s = s_owner.inputs[0]
try:
s_val = get_scalar_constant_value(s)
except NotScalarConstantError:
s_val = False
shape_fg = FunctionGraph(
outputs=[as_tensor_variable(s, ndim=0) for s in shape],
features=[ShapeFeature()],
clone=True,
)
folded_shape = optimize_graph(
shape_fg, custom_opt=topo_constant_folding
).outputs
bcast += [s_val == 1]
return bcast
return [getattr(s, "data", s) == 1 for s in folded_shape]
def infer_shape(self, fgraph, node, input_shapes):
_, size, _, *dist_params = node.inputs
......
......@@ -135,6 +135,9 @@ def test_RandomVariable_bcast():
res = rv.compute_bcast([mu, sd], size)
assert res == [True, False, False]
res = rv(0, 1, size=aet.as_tensor(1, dtype=np.int64))
assert res.broadcastable == (True,)
def test_RandomVariable_floatX():
test_rv_op = RandomVariable(
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论