提交 cc054868 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Apply casting in as_tensor_variable in normalize_size_param

This allows PyTensor to infer more broadcastable patterns, by placing the casting inside the MakeVector Op
上级 da5281be
...@@ -134,7 +134,7 @@ def normalize_size_param( ...@@ -134,7 +134,7 @@ def normalize_size_param(
"Parameter size must be None, an integer, or a sequence with integers." "Parameter size must be None, an integer, or a sequence with integers."
) )
else: else:
size = cast(as_tensor_variable(size, ndim=1), "int64") size = cast(as_tensor_variable(size, ndim=1, dtype="int64"), "int64")
if not isinstance(size, Constant): if not isinstance(size, Constant):
# This should help ensure that the length of non-constant `size`s # This should help ensure that the length of non-constant `size`s
......
...@@ -148,6 +148,9 @@ def test_RandomVariable_bcast(): ...@@ -148,6 +148,9 @@ def test_RandomVariable_bcast():
res = rv(0, 1, size=at.as_tensor(1, dtype=np.int64)) res = rv(0, 1, size=at.as_tensor(1, dtype=np.int64))
assert res.broadcastable == (True,) assert res.broadcastable == (True,)
res = rv(0, 1, size=(at.as_tensor(1, dtype=np.int32), s3))
assert res.broadcastable == (True, False)
def test_RandomVariable_bcast_specify_shape(): def test_RandomVariable_bcast_specify_shape():
rv = RandomVariable("normal", 0, [0, 0], config.floatX, inplace=True) rv = RandomVariable("normal", 0, [0, 0], config.floatX, inplace=True)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论