提交 69037dbb authored 作者: Virgile Andreani's avatar Virgile Andreani 提交者: Ricardo Vieira

Make non-strict zip strict in tensor/random/utils

上级 e200cb5c
...@@ -141,7 +141,7 @@ def explicit_expand_dims( ...@@ -141,7 +141,7 @@ def explicit_expand_dims(
batch_dims = [ batch_dims = [
param.type.ndim - ndim_param param.type.ndim - ndim_param
for param, ndim_param in zip(params, ndim_params, strict=False) for param, ndim_param in zip(params, ndim_params, strict=True)
] ]
if size_length is not None: if size_length is not None:
......
...@@ -74,16 +74,16 @@ def test_RandomVariable_basics(strict_test_value_flags): ...@@ -74,16 +74,16 @@ def test_RandomVariable_basics(strict_test_value_flags):
# `dtype` is respected # `dtype` is respected
rv = RandomVariable("normal", signature="(),()->()", dtype="int32") rv = RandomVariable("normal", signature="(),()->()", dtype="int32")
with config.change_flags(compute_test_value="off"): with config.change_flags(compute_test_value="off"):
rv_out = rv() rv_out = rv(0, 0)
assert rv_out.dtype == "int32" assert rv_out.dtype == "int32"
rv_out = rv(dtype="int64") rv_out = rv(0, 0, dtype="int64")
assert rv_out.dtype == "int64" assert rv_out.dtype == "int64"
with pytest.raises( with pytest.raises(
ValueError, ValueError,
match="Cannot change the dtype of a normal RV from int32 to float32", match="Cannot change the dtype of a normal RV from int32 to float32",
): ):
assert rv(dtype="float32").dtype == "float32" assert rv(0, 0, dtype="float32").dtype == "float32"
def test_RandomVariable_bcast(strict_test_value_flags): def test_RandomVariable_bcast(strict_test_value_flags):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论