提交 117b40c9 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Add a test for a shape inference issue between Scan and RandomVariable

上级 3d4ef668
...@@ -5036,3 +5036,76 @@ def test_mintap_onestep(): ...@@ -5036,3 +5036,76 @@ def test_mintap_onestep():
_seq = np.arange(20).astype("int32") _seq = np.arange(20).astype("int32")
_sum = f(_seq) _sum = f(_seq)
assert _sum == 2 assert _sum == 2
def test_inner_get_vector_length():
"""Make sure we can handle/preserve fixed shape terms when cloning the body of a `Scan`."""
rng_at = RandomStream()
s1 = lscalar("s1")
s2 = lscalar("s2")
size_at = aet.as_tensor([s1, s2])
def scan_body(size):
# `size` will be cloned and replaced with an ownerless `TensorVariable`.
# This will cause `RandomVariable.infer_shape` to fail, because it expects
# `get_vector_length` to work on all `size` arguments.
return rng_at.normal(0, 1, size=size)
res, _ = scan(
scan_body,
non_sequences=[size_at],
n_steps=10,
strict=True,
)
assert isinstance(res.owner.op, Scan)
# Make sure the `size` in `scan_body` is a plain `Variable` instance
# carrying no information with which we can derive its length
size_clone = res.owner.op.inputs[1]
assert size_clone.owner is None
# Make sure the cloned `size` maps to the original `size_at`
inner_outer_map = res.owner.op.get_oinp_iinp_iout_oout_mappings()
outer_input_idx = inner_outer_map["outer_inp_from_inner_inp"][1]
original_size = res.owner.inputs[outer_input_idx]
assert original_size == size_at
with config.change_flags(on_opt_error="raise", on_shape_error="raise"):
res_fn = function([size_at], res.shape)
assert np.array_equal(res_fn((1, 2)), (10, 1, 2))
# Second case has an empty size non-sequence
size_at = aet.as_tensor([], dtype=np.int64)
res, _ = scan(
scan_body,
non_sequences=[size_at],
n_steps=10,
strict=True,
)
assert isinstance(res.owner.op, Scan)
with config.change_flags(on_opt_error="raise", on_shape_error="raise"):
res_fn = function([], res.shape)
assert np.array_equal(res_fn(), (10,))
# Third case has a constant size non-sequence
size_at = aet.as_tensor([3], dtype=np.int64)
res, _ = scan(
scan_body,
non_sequences=[size_at],
n_steps=10,
strict=True,
)
assert isinstance(res.owner.op, Scan)
with config.change_flags(on_opt_error="raise", on_shape_error="raise"):
res_fn = function([], res.shape)
assert np.array_equal(res_fn(), (10, 3))
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论