提交 b3f686f7 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Do not add SpecifyShape to constant sizes in RandomVariable

上级 ed759a09
......@@ -5,7 +5,7 @@ from itertools import zip_longest
import numpy as np
from aesara.compile.sharedvalue import shared
from aesara.graph.basic import Variable
from aesara.graph.basic import Constant, Variable
from aesara.tensor import get_vector_length
from aesara.tensor.basic import as_tensor_variable, cast, constant
from aesara.tensor.extra_ops import broadcast_to
......@@ -123,8 +123,11 @@ def normalize_size_param(size):
)
else:
size = cast(as_tensor_variable(size, ndim=1), "int64")
# This should help ensure that the length of `size` will be available
# after certain types of cloning (e.g. the kind `Scan` performs)
if not isinstance(size, Constant):
# This should help ensure that the length of non-constant `size`s
# will be available after certain types of cloning (e.g. the kind
# `Scan` performs)
size = specify_shape(size, (get_vector_length(size),))
assert size.dtype in int_dtypes
......
......@@ -23,7 +23,6 @@ from aesara.tensor.random.opt import (
local_rv_size_lift,
local_subtensor_rv_lift,
)
from aesara.tensor.shape import SpecifyShape
from aesara.tensor.subtensor import AdvancedSubtensor, AdvancedSubtensor1, Subtensor
from aesara.tensor.type import iscalar, vector
......@@ -84,9 +83,7 @@ def test_inplace_optimization():
np.array_equal(a.data, b.data)
for a, b in zip(new_out.owner.inputs[2:], out.owner.inputs[2:])
)
# A `SpecifyShape` is added
assert isinstance(new_out.owner.inputs[1].owner.op, SpecifyShape)
assert new_out.owner.inputs[1].owner.inputs[0].equals(out.owner.inputs[1])
assert np.array_equal(new_out.owner.inputs[1].data, [])
@config.change_flags(compute_test_value="raise")
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论