提交 3d4ef668 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Use SpecifyShape to track length of RandomVariable's size

上级 6bb459ad
...@@ -37,7 +37,7 @@ def make_numba_random_fn(node, np_random_func): ...@@ -37,7 +37,7 @@ def make_numba_random_fn(node, np_random_func):
argument to the Numba-supported scalar ``np.random`` functions. argument to the Numba-supported scalar ``np.random`` functions.
""" """
tuple_size = get_vector_length(node.inputs[1]) tuple_size = int(get_vector_length(node.inputs[1]))
size_dims = tuple_size - max(i.ndim for i in node.inputs[3:]) size_dims = tuple_size - max(i.ndim for i in node.inputs[3:])
# Make a broadcast-capable version of the Numba supported scalar sampling # Make a broadcast-capable version of the Numba supported scalar sampling
......
...@@ -6,9 +6,11 @@ import numpy as np ...@@ -6,9 +6,11 @@ import numpy as np
from aesara.compile.sharedvalue import shared from aesara.compile.sharedvalue import shared
from aesara.graph.basic import Variable from aesara.graph.basic import Variable
from aesara.tensor import get_vector_length
from aesara.tensor.basic import as_tensor_variable, cast, constant from aesara.tensor.basic import as_tensor_variable, cast, constant
from aesara.tensor.extra_ops import broadcast_to from aesara.tensor.extra_ops import broadcast_to
from aesara.tensor.math import maximum from aesara.tensor.math import maximum
from aesara.tensor.shape import specify_shape
from aesara.tensor.type import int_dtypes from aesara.tensor.type import int_dtypes
...@@ -121,6 +123,9 @@ def normalize_size_param(size): ...@@ -121,6 +123,9 @@ def normalize_size_param(size):
) )
else: else:
size = cast(as_tensor_variable(size, ndim=1), "int64") size = cast(as_tensor_variable(size, ndim=1), "int64")
# This should help ensure that the length of `size` will be available
# after certain types of cloning (e.g. the kind `Scan` performs)
size = specify_shape(size, (get_vector_length(size),))
assert size.dtype in int_dtypes assert size.dtype in int_dtypes
......
...@@ -7,6 +7,7 @@ from aesara.assert_op import Assert ...@@ -7,6 +7,7 @@ from aesara.assert_op import Assert
from aesara.gradient import NullTypeGradError, grad from aesara.gradient import NullTypeGradError, grad
from aesara.tensor.math import eq from aesara.tensor.math import eq
from aesara.tensor.random.op import RandomVariable, default_shape_from_params from aesara.tensor.random.op import RandomVariable, default_shape_from_params
from aesara.tensor.shape import specify_shape
from aesara.tensor.type import all_dtypes, iscalar, tensor from aesara.tensor.type import all_dtypes, iscalar, tensor
...@@ -139,6 +140,27 @@ def test_RandomVariable_bcast(): ...@@ -139,6 +140,27 @@ def test_RandomVariable_bcast():
assert res.broadcastable == (True,) assert res.broadcastable == (True,)
def test_RandomVariable_bcast_specify_shape():
    """Check the broadcast pattern of a `RandomVariable` given a `specify_shape`-wrapped size.

    The ``size`` argument is a length-5 vector assembled from a constant
    ``1`` and symbolic integer scalars, then passed through
    ``specify_shape`` with shape ``(5,)``.  The expected pattern marks only
    the positions holding the constant ``1`` (first and last) as
    broadcastable.
    """
    normal_rv = RandomVariable("normal", 0, [0, 0], config.floatX, inplace=True)

    # Constant entry used at both ends of the size vector.
    const_one = aet.as_tensor(1, dtype=np.int64)

    # Symbolic entry used twice in the middle of the size vector.
    sym_two = iscalar()
    sym_two.tag.test_value = 2

    # Symbolic entry that is additionally routed through an `Assert` op.
    sym_three = iscalar()
    sym_three.tag.test_value = 3
    sym_three = Assert("testing")(sym_three, eq(const_one, 1))

    size_entries = [const_one, sym_three, sym_two, sym_two, const_one]
    size = specify_shape(aet.as_tensor(size_entries), (5,))

    mu = tensor(config.floatX, [False, False, True])
    mu.tag.test_value = np.random.normal(size=(2, 2, 1)).astype(config.floatX)

    std = tensor(config.floatX, [False, True, True])
    std.tag.test_value = np.ones((2, 1, 1)).astype(config.floatX)

    res = normal_rv(mu, std, size=size)

    expected_pattern = (True, False, False, False, True)
    assert res.broadcastable == expected_pattern
def test_RandomVariable_floatX(): def test_RandomVariable_floatX():
test_rv_op = RandomVariable( test_rv_op = RandomVariable(
"normal", "normal",
......
...@@ -23,6 +23,7 @@ from aesara.tensor.random.opt import ( ...@@ -23,6 +23,7 @@ from aesara.tensor.random.opt import (
local_rv_size_lift, local_rv_size_lift,
local_subtensor_rv_lift, local_subtensor_rv_lift,
) )
from aesara.tensor.shape import SpecifyShape
from aesara.tensor.subtensor import AdvancedSubtensor, AdvancedSubtensor1, Subtensor from aesara.tensor.subtensor import AdvancedSubtensor, AdvancedSubtensor1, Subtensor
from aesara.tensor.type import iscalar, vector from aesara.tensor.type import iscalar, vector
...@@ -81,8 +82,11 @@ def test_inplace_optimization(): ...@@ -81,8 +82,11 @@ def test_inplace_optimization():
assert new_out.owner.op.inplace is True assert new_out.owner.op.inplace is True
assert all( assert all(
np.array_equal(a.data, b.data) np.array_equal(a.data, b.data)
for a, b in zip(new_out.owner.inputs[1:], out.owner.inputs[1:]) for a, b in zip(new_out.owner.inputs[2:], out.owner.inputs[2:])
) )
# A `SpecifyShape` is added
assert isinstance(new_out.owner.inputs[1].owner.op, SpecifyShape)
assert new_out.owner.inputs[1].owner.inputs[0].equals(out.owner.inputs[1])
@config.change_flags(compute_test_value="raise") @config.change_flags(compute_test_value="raise")
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论