提交 69037dbb authored 作者: Virgile Andreani's avatar Virgile Andreani 提交者: Ricardo Vieira

Make non-strict zip strict in tensor/random/utils

上级 e200cb5c
......@@ -141,7 +141,7 @@ def explicit_expand_dims(
batch_dims = [
param.type.ndim - ndim_param
for param, ndim_param in zip(params, ndim_params, strict=False)
for param, ndim_param in zip(params, ndim_params, strict=True)
]
if size_length is not None:
......
......@@ -74,16 +74,16 @@ def test_RandomVariable_basics(strict_test_value_flags):
# `dtype` is respected
rv = RandomVariable("normal", signature="(),()->()", dtype="int32")
with config.change_flags(compute_test_value="off"):
rv_out = rv()
rv_out = rv(0, 0)
assert rv_out.dtype == "int32"
rv_out = rv(dtype="int64")
rv_out = rv(0, 0, dtype="int64")
assert rv_out.dtype == "int64"
with pytest.raises(
ValueError,
match="Cannot change the dtype of a normal RV from int32 to float32",
):
assert rv(dtype="float32").dtype == "float32"
assert rv(0, 0, dtype="float32").dtype == "float32"
def test_RandomVariable_bcast(strict_test_value_flags):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论