Commit 4829455b, authored by Ricardo Vieira, committed by Jesse Grabowski

Fix bug in local_blockwise_advanced_inc_subtensor

Parent 709f745c
......@@ -25,6 +25,7 @@ from pytensor.tensor.basic import (
alloc,
cast,
concatenate,
expand_dims,
get_scalar_constant_value,
get_underlying_scalar_constant_value,
register_infer_shape,
......@@ -1576,7 +1577,15 @@ def local_blockwise_advanced_inc_subtensor(fgraph, node):
x = alloc(x, *batch_shape, *core_shape)
new_idxs = [slice(None)] * batch_ndim + new_idxs
symbolic_idxs = x[tuple(new_idxs)].owner.inputs[1:]
x_view = x[tuple(new_idxs)]
# We need to introduce any implicit expand_dims on core dimension of y
y_core_ndim = y.type.ndim - batch_ndim
if (missing_y_core_ndim := x_view.type.ndim - batch_ndim - y_core_ndim) > 0:
missing_axes = tuple(range(batch_ndim, batch_ndim + missing_y_core_ndim))
y = expand_dims(y, missing_axes)
symbolic_idxs = x_view.owner.inputs[1:]
new_out = op.core_op.make_node(x, y, *symbolic_idxs).outputs
copy_stack_trace(node.outputs, new_out)
return new_out
......
......@@ -1788,10 +1788,24 @@ def test_local_uint_constant_indices():
assert new_index.type.dtype == "uint8"
@pytest.mark.parametrize("core_y_implicitly_batched", (False, True))
@pytest.mark.parametrize("set_instead_of_inc", (True, False))
def test_local_blockwise_advanced_inc_subtensor(set_instead_of_inc):
def test_local_blockwise_advanced_inc_subtensor(
set_instead_of_inc, core_y_implicitly_batched
):
rng = np.random.default_rng([1764, set_instead_of_inc, core_y_implicitly_batched])
def np_inplace_f(x, idx, y):
    # Reference NumPy implementation of the operation under test:
    # mutate `x` in place at `idx`, either overwriting (set_subtensor)
    # or accumulating (inc_subtensor).
    # NOTE(review): reads `core_y_implicitly_batched` and
    # `set_instead_of_inc` from the enclosing test's parametrization.
    if core_y_implicitly_batched:
        # Core `y` is scalar-like here; append a length-1 trailing axis so
        # it broadcasts against the indexed slice (the implicit expand_dims
        # the rewrite must introduce).
        y = y[..., None]
    if set_instead_of_inc:
        x[idx] = y
    else:
        x[idx] += y
core_y_shape = () if core_y_implicitly_batched else (3,)
core_x = tensor("x", shape=(6,))
core_y = tensor("y", shape=(3,))
core_y = tensor("y", shape=core_y_shape, dtype=int)
core_idxs = [0, 2, 4]
if set_instead_of_inc:
core_graph = set_subtensor(core_x[core_idxs], core_y)
......@@ -1800,7 +1814,7 @@ def test_local_blockwise_advanced_inc_subtensor(set_instead_of_inc):
# Only x is batched
x = tensor("x", shape=(5, 2, 6))
y = tensor("y", shape=(3,))
y = tensor("y", shape=core_y_shape, dtype=int)
out = vectorize_graph(core_graph, replace={core_x: x, core_y: y})
assert isinstance(out.owner.op, Blockwise)
......@@ -1810,17 +1824,14 @@ def test_local_blockwise_advanced_inc_subtensor(set_instead_of_inc):
)
test_x = np.ones(x.type.shape, dtype=x.type.dtype)
test_y = np.array([5, 6, 7]).astype(dtype=core_y.type.dtype)
test_y = rng.integers(1, 10, size=y.type.shape, dtype=y.type.dtype)
expected_out = test_x.copy()
if set_instead_of_inc:
expected_out[:, :, core_idxs] = test_y
else:
expected_out[:, :, core_idxs] += test_y
np_inplace_f(expected_out, np.s_[:, :, core_idxs], test_y)
np.testing.assert_allclose(fn(test_x, test_y), expected_out)
# Only y is batched
x = tensor("y", shape=(6,))
y = tensor("y", shape=(2, 3))
y = tensor("y", shape=(2, *core_y_shape), dtype=int)
out = vectorize_graph(core_graph, replace={core_x: x, core_y: y})
assert isinstance(out.owner.op, Blockwise)
......@@ -1830,17 +1841,14 @@ def test_local_blockwise_advanced_inc_subtensor(set_instead_of_inc):
)
test_x = np.ones(x.type.shape, dtype=x.type.dtype)
test_y = np.array([[3, 3, 3], [5, 6, 7]]).astype(dtype=core_y.type.dtype)
test_y = rng.integers(1, 10, size=y.type.shape, dtype=y.type.dtype)
expected_out = np.ones((2, *x.type.shape))
if set_instead_of_inc:
expected_out[:, core_idxs] = test_y
else:
expected_out[:, core_idxs] += test_y
np_inplace_f(expected_out, np.s_[:, core_idxs], test_y)
np.testing.assert_allclose(fn(test_x, test_y), expected_out)
# Both x and y are batched, and do not need to be broadcasted
x = tensor("y", shape=(2, 6))
y = tensor("y", shape=(2, 3))
y = tensor("y", shape=(2, *core_y_shape), dtype=int)
out = vectorize_graph(core_graph, replace={core_x: x, core_y: y})
assert isinstance(out.owner.op, Blockwise)
......@@ -1850,17 +1858,14 @@ def test_local_blockwise_advanced_inc_subtensor(set_instead_of_inc):
)
test_x = np.ones(x.type.shape, dtype=x.type.dtype)
test_y = np.array([[5, 6, 7], [3, 3, 3]]).astype(dtype=core_y.type.dtype)
test_y = rng.integers(1, 10, size=y.type.shape, dtype=y.type.dtype)
expected_out = test_x.copy()
if set_instead_of_inc:
expected_out[:, core_idxs] = test_y
else:
expected_out[:, core_idxs] += test_y
np_inplace_f(expected_out, np.s_[:, core_idxs], test_y)
np.testing.assert_allclose(fn(test_x, test_y), expected_out)
# Both x and y are batched, but must be broadcasted
x = tensor("y", shape=(5, 1, 6))
y = tensor("y", shape=(1, 2, 3))
y = tensor("y", shape=(1, 2, *core_y_shape), dtype=int)
out = vectorize_graph(core_graph, replace={core_x: x, core_y: y})
assert isinstance(out.owner.op, Blockwise)
......@@ -1870,16 +1875,13 @@ def test_local_blockwise_advanced_inc_subtensor(set_instead_of_inc):
)
test_x = np.ones(x.type.shape, dtype=x.type.dtype)
test_y = np.array([[[5, 6, 7], [3, 3, 3]]]).astype(dtype=core_y.type.dtype)
test_y = rng.integers(1, 10, size=y.type.shape, dtype=y.type.dtype)
final_shape = (
*np.broadcast_shapes(x.type.shape[:-1], y.type.shape[:-1]),
*np.broadcast_shapes(x.type.shape[:2], y.type.shape[:2]),
x.type.shape[-1],
)
expected_out = np.broadcast_to(test_x, final_shape).copy()
if set_instead_of_inc:
expected_out[:, :, core_idxs] = test_y
else:
expected_out[:, :, core_idxs] += test_y
np_inplace_f(expected_out, np.s_[:, :, core_idxs], test_y)
np.testing.assert_allclose(fn(test_x, test_y), expected_out)
......
Markdown is supported
0%
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Register or sign in to comment