提交 cc054868 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Apply casting in as_tensor_variable in normalize_size_param

This allows PyTensor to infer more broadcastable patterns, by placing the casting inside the MakeVector Op
上级 da5281be
......@@ -134,7 +134,7 @@ def normalize_size_param(
"Parameter size must be None, an integer, or a sequence with integers."
)
else:
size = cast(as_tensor_variable(size, ndim=1), "int64")
size = cast(as_tensor_variable(size, ndim=1, dtype="int64"), "int64")
if not isinstance(size, Constant):
# This should help ensure that the length of non-constant `size`s
......
......@@ -148,6 +148,9 @@ def test_RandomVariable_bcast():
res = rv(0, 1, size=at.as_tensor(1, dtype=np.int64))
assert res.broadcastable == (True,)
res = rv(0, 1, size=(at.as_tensor(1, dtype=np.int32), s3))
assert res.broadcastable == (True, False)
def test_RandomVariable_bcast_specify_shape():
rv = RandomVariable("normal", 0, [0, 0], config.floatX, inplace=True)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论