提交 b3f686f7 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Do not add SpecifyShape to constant sizes in RandomVariable

上级 ed759a09
...@@ -5,7 +5,7 @@ from itertools import zip_longest ...@@ -5,7 +5,7 @@ from itertools import zip_longest
import numpy as np import numpy as np
from aesara.compile.sharedvalue import shared from aesara.compile.sharedvalue import shared
from aesara.graph.basic import Variable from aesara.graph.basic import Constant, Variable
from aesara.tensor import get_vector_length from aesara.tensor import get_vector_length
from aesara.tensor.basic import as_tensor_variable, cast, constant from aesara.tensor.basic import as_tensor_variable, cast, constant
from aesara.tensor.extra_ops import broadcast_to from aesara.tensor.extra_ops import broadcast_to
...@@ -123,8 +123,11 @@ def normalize_size_param(size): ...@@ -123,8 +123,11 @@ def normalize_size_param(size):
) )
else: else:
size = cast(as_tensor_variable(size, ndim=1), "int64") size = cast(as_tensor_variable(size, ndim=1), "int64")
# This should help ensure that the length of `size` will be available
# after certain types of cloning (e.g. the kind `Scan` performs) if not isinstance(size, Constant):
# This should help ensure that the length of non-constant `size`s
# will be available after certain types of cloning (e.g. the kind
# `Scan` performs)
size = specify_shape(size, (get_vector_length(size),)) size = specify_shape(size, (get_vector_length(size),))
assert size.dtype in int_dtypes assert size.dtype in int_dtypes
......
...@@ -23,7 +23,6 @@ from aesara.tensor.random.opt import ( ...@@ -23,7 +23,6 @@ from aesara.tensor.random.opt import (
local_rv_size_lift, local_rv_size_lift,
local_subtensor_rv_lift, local_subtensor_rv_lift,
) )
from aesara.tensor.shape import SpecifyShape
from aesara.tensor.subtensor import AdvancedSubtensor, AdvancedSubtensor1, Subtensor from aesara.tensor.subtensor import AdvancedSubtensor, AdvancedSubtensor1, Subtensor
from aesara.tensor.type import iscalar, vector from aesara.tensor.type import iscalar, vector
...@@ -84,9 +83,7 @@ def test_inplace_optimization(): ...@@ -84,9 +83,7 @@ def test_inplace_optimization():
np.array_equal(a.data, b.data) np.array_equal(a.data, b.data)
for a, b in zip(new_out.owner.inputs[2:], out.owner.inputs[2:]) for a, b in zip(new_out.owner.inputs[2:], out.owner.inputs[2:])
) )
# A `SpecifyShape` is added assert np.array_equal(new_out.owner.inputs[1].data, [])
assert isinstance(new_out.owner.inputs[1].owner.op, SpecifyShape)
assert new_out.owner.inputs[1].owner.inputs[0].equals(out.owner.inputs[1])
@config.change_flags(compute_test_value="raise") @config.change_flags(compute_test_value="raise")
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论