Commit d9fd640e, authored by Brandon T. Willard, committed by Brandon T. Willard

Use constant folding to determine RandomVariable broadcast dimensions

Parent: bbd4e643
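The commit replaces the ad hoc handling of `Cast` and `Assert` nodes (plus `get_scalar_constant_value`) in `compute_bcast` with a graph-level pass: the inferred shape entries are wrapped in a `FunctionGraph` and constant-folded, and a dimension is marked broadcastable exactly when its entry folds to the literal 1. Below is a minimal sketch of that idea, not part of the commit; it uses only the APIs this diff imports, assumes the same Aesara version, and the toy shape expression is made up for illustration.

```python
from aesara.graph.fg import FunctionGraph
from aesara.graph.opt_utils import optimize_graph
from aesara.tensor.basic import as_tensor_variable
from aesara.tensor.basic_opt import ShapeFeature, topo_constant_folding

# A toy "shape": the first entry only becomes a constant after folding
# (1 * 1 -> 1), the second is already the constant 3.
shape = [as_tensor_variable(1) * as_tensor_variable(1), as_tensor_variable(3)]

# Wrap the shape entries in a FunctionGraph and run topological constant
# folding, as the new `compute_bcast` does.
shape_fg = FunctionGraph(
    outputs=[as_tensor_variable(s, ndim=0) for s in shape],
    features=[ShapeFeature()],
    clone=True,
)
folded_shape = optimize_graph(shape_fg, custom_opt=topo_constant_folding).outputs

# Constants expose their value through `.data`; anything still symbolic falls
# through the `getattr` default and is treated as non-broadcastable.
bcast = [getattr(s, "data", s) == 1 for s in folded_shape]
print(bcast)  # expected: [True, False]
```

Because the folding runs over the whole shape graph, constant-preserving ops such as `Cast` and the `Assert`s added by `broadcast_shape` no longer need to be special-cased one by one.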
@@ -4,20 +4,14 @@ from copy import copy
 import numpy as np
 import aesara
-from aesara.assert_op import Assert
 from aesara.configdefaults import config
 from aesara.graph.basic import Apply, Variable
+from aesara.graph.fg import FunctionGraph
 from aesara.graph.op import Op
+from aesara.graph.opt_utils import optimize_graph
 from aesara.misc.safe_asarray import _asarray
-from aesara.scalar.basic import Cast
-from aesara.tensor.basic import (
-    as_tensor_variable,
-    constant,
-    get_scalar_constant_value,
-    get_vector_length,
-)
-from aesara.tensor.elemwise import Elemwise
-from aesara.tensor.exceptions import NotScalarConstantError
+from aesara.tensor.basic import as_tensor_variable, constant, get_vector_length
+from aesara.tensor.basic_opt import ShapeFeature, topo_constant_folding
 from aesara.tensor.random.type import RandomType
 from aesara.tensor.random.utils import normalize_size_param, params_broadcast_shapes
 from aesara.tensor.shape import shape_tuple
@@ -287,30 +281,16 @@ class RandomVariable(Op):
         """
         shape = self._infer_shape(size, dist_params)

-        # Ignore `Cast`s, since they do not affect broadcastables
-        if getattr(shape, "owner", None) and (
-            isinstance(shape.owner.op, Elemwise)
-            and isinstance(shape.owner.op.scalar_op, Cast)
-        ):
-            shape = shape.owner.inputs[0]
-
-        # Let's try to do a better job than `_infer_ndim_bcast` when
-        # dimension sizes are symbolic.
-        bcast = []
-        for s in shape:
-            s_owner = getattr(s, "owner", None)
-
-            # Get rid of the `Assert`s added by `broadcast_shape`
-            if s_owner and isinstance(s_owner.op, Assert):
-                s = s_owner.inputs[0]
-
-            try:
-                s_val = get_scalar_constant_value(s)
-            except NotScalarConstantError:
-                s_val = False
-
-            bcast += [s_val == 1]
-
-        return bcast
+        shape_fg = FunctionGraph(
+            outputs=[as_tensor_variable(s, ndim=0) for s in shape],
+            features=[ShapeFeature()],
+            clone=True,
+        )
+        folded_shape = optimize_graph(
+            shape_fg, custom_opt=topo_constant_folding
+        ).outputs
+
+        return [getattr(s, "data", s) == 1 for s in folded_shape]

     def infer_shape(self, fgraph, node, input_shapes):
         _, size, _, *dist_params = node.inputs
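The practical effect, exercised by the test added in the next hunk, is that broadcast information is recovered even when `size` is a constant tensor rather than a Python integer. A rough usage sketch follows; the `normal` import is an assumption on my part and does not appear in this diff.

```python
import numpy as np
import aesara.tensor as aet
from aesara.tensor.random.basic import normal

# `size` is a constant tensor, not a Python int; constant folding still
# reduces the inferred shape entry to the literal 1, so the single output
# dimension is marked broadcastable.
x = normal(0, 1, size=aet.as_tensor(1, dtype=np.int64))
print(x.broadcastable)  # expected: (True,)
```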
@@ -135,6 +135,9 @@ def test_RandomVariable_bcast():
     res = rv.compute_bcast([mu, sd], size)
     assert res == [True, False, False]

+    res = rv(0, 1, size=aet.as_tensor(1, dtype=np.int64))
+    assert res.broadcastable == (True,)
+

 def test_RandomVariable_floatX():
     test_rv_op = RandomVariable(