Commit 248ce6d9, authored by Brandon T. Willard, committed by Brandon T. Willard

Ignore Cast in RandomVariable.compute_bcast

Closes #390
Parent 6228d023
@@ -9,12 +9,14 @@ from aesara.configdefaults import config
 from aesara.graph.basic import Apply, Variable
 from aesara.graph.op import Op
 from aesara.misc.safe_asarray import _asarray
+from aesara.scalar.basic import Cast
 from aesara.tensor.basic import (
     as_tensor_variable,
     constant,
     get_scalar_constant_value,
     get_vector_length,
 )
+from aesara.tensor.elemwise import Elemwise
 from aesara.tensor.exceptions import NotScalarConstantError
 from aesara.tensor.random.type import RandomStateType
 from aesara.tensor.random.utils import normalize_size_param, params_broadcast_shapes
@@ -284,6 +286,13 @@ class RandomVariable(Op):
         """
         shape = self._infer_shape(size, dist_params)

+        # Ignore `Cast`s, since they do not affect broadcastables
+        if getattr(shape, "owner", None) and (
+            isinstance(shape.owner.op, Elemwise)
+            and isinstance(shape.owner.op.scalar_op, Cast)
+        ):
+            shape = shape.owner.inputs[0]
+
         # Let's try to do a better job than `_infer_ndim_bcast` when
         # dimension sizes are symbolic.
         bcast = []
@@ -111,6 +111,8 @@ def test_RandomVariable_basics():
     with raises(NullTypeGradError):
         grad(rv_out, [rv_node.inputs[0]])

+def test_RandomVariable_bcast():
     rv = RandomVariable("normal", 0, [0, 0], config.floatX, inplace=True)
     mu = tensor(config.floatX, [True, False, False])
@@ -129,6 +131,10 @@ def test_RandomVariable_basics():
     res = rv.compute_bcast([mu, sd], (s1, s2, s3))
     assert res == [False] * 3

+    size = aet.as_tensor((1, 2, 3), dtype=np.int32).astype(np.int64)
+    res = rv.compute_bcast([mu, sd], size)
+    assert res == [True, False, False]

 def test_RandomVariable_floatX():
     test_rv_op = RandomVariable(
Markdown is supported
0%
You are about to add 0 people to this discussion. Please proceed with caution.
Please finish editing this comment first!
Register or sign in to comment