提交 528b8d4b authored 作者: ricardoV94's avatar ricardoV94 提交者: Ricardo Vieira

Specialized C-impl for vector AdvancedIncSubtensor1

Also add checks for runtime broadcast
上级 4311f893
......@@ -67,6 +67,9 @@ def jax_funcify_IncSubtensor(op, node, **kwargs):
if len(indices) == 1:
indices = indices[0]
if isinstance(op, AdvancedIncSubtensor1):
op._check_runtime_broadcasting(node, x, y, indices)
return jax_fn(x, indices, y)
return incsubtensor
......
......@@ -287,11 +287,11 @@ def numba_funcify_AdvancedIncSubtensor1(op, node, **kwargs):
inplace = op.inplace
set_instead_of_inc = op.set_instead_of_inc
x, vals, idxs = node.inputs
# TODO: Add explicit expand_dims in make_node so we don't need to worry about this here
broadcast = vals.type.ndim < x.type.ndim or vals.type.broadcastable[0]
broadcast_with_index = vals.type.ndim < x.type.ndim or vals.type.broadcastable[0]
# TODO: Add runtime_broadcast check
if set_instead_of_inc:
if broadcast:
if broadcast_with_index:
@numba_njit(boundscheck=True)
def advancedincsubtensor1_inplace(x, val, idxs):
......@@ -318,7 +318,7 @@ def numba_funcify_AdvancedIncSubtensor1(op, node, **kwargs):
x[idx] = val
return x
else:
if broadcast:
if broadcast_with_index:
@numba_njit(boundscheck=True)
def advancedincsubtensor1_inplace(x, val, idxs):
......
......@@ -109,6 +109,8 @@ def pytorch_funcify_AdvancedIncSubtensor(op, node, **kwargs):
def adv_set_subtensor(x, y, *indices):
check_negative_steps(indices)
if isinstance(op, AdvancedIncSubtensor1):
op._check_runtime_broadcasting(node, x, y, indices)
if not inplace:
x = x.clone()
x[indices] = y.type_as(x)
......@@ -120,6 +122,8 @@ def pytorch_funcify_AdvancedIncSubtensor(op, node, **kwargs):
def adv_inc_subtensor_no_duplicates(x, y, *indices):
check_negative_steps(indices)
if isinstance(op, AdvancedIncSubtensor1):
op._check_runtime_broadcasting(node, x, y, indices)
if not inplace:
x = x.clone()
x[indices] += y.type_as(x)
......
......@@ -2262,6 +2262,12 @@ class AdvancedIncSubtensor1(COp):
check_input = False
params_type = ParamsType(inplace=ps.bool, set_instead_of_inc=ps.bool)
_runtime_broadcast_error_msg = (
"Runtime broadcasting not allowed. "
"AdvancedIncSubtensor1 was asked to broadcast the second input (y) along a dimension that was not marked as broadcastable. "
"If broadcasting was intended, use `specify_broadcastable` on the relevant dimension(s)."
)
def __init__(self, inplace=False, set_instead_of_inc=False):
self.inplace = bool(inplace)
self.set_instead_of_inc = bool(set_instead_of_inc)
......@@ -2333,6 +2339,9 @@ class AdvancedIncSubtensor1(COp):
NPY_ARRAY_ENSURECOPY, NULL)"""
def c_support_code(self, **kwargs):
if numpy_version < "1.8.0" or using_numpy_2:
return None
types = [
"npy_" + t
for t in [
......@@ -2523,15 +2532,117 @@ class AdvancedIncSubtensor1(COp):
return code
def c_code(self, node, name, input_names, output_names, sub):
if numpy_version < "1.8.0" or using_numpy_2:
raise NotImplementedError
x, y, idx = input_names
out = output_names[0]
[out] = output_names
copy_of_x = self.copy_of_x(x)
params = sub["params"]
fail = sub["fail"]
x_, y_, idx_ = node.inputs
y_cdtype = y_.type.dtype_specs()[1]
idx_cdtype = idx_.type.dtype_specs()[1]
out_cdtype = node.outputs[0].type.dtype_specs()[1]
y_bcast = y_.type.broadcastable != idx_.type.broadcastable
if (
x_.type.ndim == 1
and y_.type.ndim == 1
and not y_bcast
and x_.type.dtype not in complex_dtypes
and y_.type.dtype not in complex_dtypes
):
# Simple implementation for vector x, y cases
idx_may_be_neg = not (isinstance(idx_, Constant) and idx_.data.min() >= 0)
idx_may_be_invalid = AdvancedSubtensor1._idx_may_be_invalid(x_, idx_)
shape0 = x_.type.shape[0]
# This is used to make sure that when we trust the indices to be valid
# we are not fooled by a wrong static shape
# We mention x to the user in error messages but we work (and make checks) on out,
# which should be x or a copy of it
unexpected_shape0 = (
f"PyArray_SHAPE({out})[0] != {shape0}" if shape0 is not None else "0"
)
op = "=" if self.set_instead_of_inc else "+="
code = f"""
if ({params}->inplace)
{{
if ({x} != {out})
{{
Py_XDECREF({out});
Py_INCREF({x});
{out} = {x};
}}
}}
else
{{
Py_XDECREF({out});
{out} = {copy_of_x};
if (!{out}) {{
// Exception already set
{fail}
}}
}}
if (PyArray_NDIM({out}) != 1) {{
PyErr_Format(PyExc_ValueError, "AdvancedIncSubtensor1: first input (x) ndim should be 1, got %d", PyArray_NDIM({out}));
{fail}
}}
if ({unexpected_shape0}) {{
PyErr_Format(PyExc_ValueError, "AdvancedIncSubtensor1: first input (x) shape should be {shape0}, got %d", PyArray_SHAPE({out})[0]);
{fail}
}}
if (PyArray_NDIM({idx}) != 1) {{
PyErr_Format(PyExc_ValueError, "AdvancedIncSubtensor1: indices ndim should be 1, got %d", PyArray_NDIM({idx}));
{fail}
}}
if (PyArray_NDIM({y}) != 1) {{
PyErr_Format(PyExc_ValueError, "AdvancedIncSubtensor1: second input (y) ndim should be 1, got %d", PyArray_NDIM({y}));
{fail}
}}
if (PyArray_SHAPE({y})[0] != PyArray_SHAPE({idx})[0]) {{
if ((PyArray_NDIM({y}) == 1) && (PyArray_SHAPE({y})[0] == 1)){{
PyErr_Format(PyExc_ValueError, "{self._runtime_broadcast_error_msg}");
}} else {{
PyErr_Format(PyExc_ValueError,
"AdvancedIncSubtensor1: Shapes of second input (y) and indices do not match: %d, %d",
PyArray_SHAPE({y})[0], PyArray_SHAPE({idx})[0]);
}}
{fail}
}}
{{
npy_intp out_shape0 = PyArray_SHAPE({out})[0];
{out_cdtype}* out_data = ({out_cdtype}*)PyArray_DATA({out});
{y_cdtype}* y_data = ({y_cdtype}*)PyArray_DATA({y});
{idx_cdtype}* idx_data = ({idx_cdtype}*)PyArray_DATA({idx});
npy_intp n = PyArray_SHAPE({idx})[0];
npy_intp out_jump = PyArray_STRIDES({out})[0] / PyArray_ITEMSIZE({out});
npy_intp y_jump = PyArray_STRIDES({y})[0] / PyArray_ITEMSIZE({y});
npy_intp idx_jump = PyArray_STRIDES({idx})[0] / PyArray_ITEMSIZE({idx});
for(int i = 0; i < n; i++){{
{idx_cdtype} idx = idx_data[i * idx_jump];
if ({int(idx_may_be_neg)}){{
if (idx < 0) {{
idx += out_shape0;
}}
}}
if ({int(idx_may_be_invalid)}){{
if ((idx < 0) || (idx >= out_shape0)) {{
PyErr_Format(PyExc_IndexError,"index %d out of bounds for array with shape %d", idx_data[i * idx_jump], out_shape0);
{fail}
}}
}}
out_data[idx * out_jump] {op} y_data[i * y_jump];
}}
}}
"""
return code
if numpy_version < "1.8.0" or using_numpy_2:
raise NotImplementedError
return f"""
PyObject* rval = NULL;
if ({params}->inplace)
......@@ -2559,14 +2670,37 @@ class AdvancedIncSubtensor1(COp):
"""
def c_code_cache_version(self):
return (8,)
return (9,)
def _check_runtime_broadcasting(
    self, node: Apply, x: np.ndarray, y: np.ndarray, idx: np.ndarray
) -> None:
    """Reject runtime broadcasting of `y` that was not declared statically.

    Raises
    ------
    ValueError
        If `y` has runtime length 1 along a dimension that is not marked
        broadcastable in its static type, i.e. the operation would have
        to silently broadcast `y` against the indices or against `x`.
    """
    # A 0d `y` cannot trigger an undeclared broadcast; nothing to check.
    if y.ndim > 0:
        y_pt_bcast = node.inputs[1].broadcastable  # type: ignore

        # Leading dim of `y` must match the number of indices unless it
        # was statically declared broadcastable.
        if not y_pt_bcast[0] and y.shape[0] == 1 and y.shape[0] != idx.shape[0]:
            # Attempting to broadcast with index
            raise ValueError(self._runtime_broadcast_error_msg)
        # Trailing dims of `y` must match `x` unless declared broadcastable;
        # zip from the right since `y` may have fewer dims than `x`.
        if any(
            not y_bcast and y_dim == 1 and y_dim != x_dim
            for y_bcast, y_dim, x_dim in zip(
                reversed(y_pt_bcast),
                reversed(y.shape),
                reversed(x.shape),
                strict=False,
            )
        ):
            # Attempting to broadcast with buffer
            raise ValueError(self._runtime_broadcast_error_msg)
def perform(self, node, inputs, output_storage):
x, y, idx = inputs
def perform(self, node, inp, out_):
x, y, idx = inp
(out,) = out_
if not self.inplace:
x = x.copy()
self._check_runtime_broadcasting(node, x, y, idx)
if self.set_instead_of_inc:
x[idx] = y
else:
......@@ -2574,7 +2708,7 @@ class AdvancedIncSubtensor1(COp):
# many times: it does it only once.
np.add.at(x, idx, y)
out[0] = x
output_storage[0][0] = x
def infer_shape(self, fgraph, node, ishapes):
x, y, ilist = ishapes
......
......@@ -5,6 +5,7 @@ from io import StringIO
import numpy as np
import pytest
from numpy.testing import assert_array_equal
from packaging import version
import pytensor
import pytensor.scalar as scal
......@@ -26,7 +27,7 @@ from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.elemwise import DimShuffle
from pytensor.tensor.math import exp, isinf, lt, switch
from pytensor.tensor.math import sum as pt_sum
from pytensor.tensor.shape import specify_shape
from pytensor.tensor.shape import specify_broadcastable, specify_shape
from pytensor.tensor.subtensor import (
AdvancedIncSubtensor,
AdvancedIncSubtensor1,
......@@ -1101,9 +1102,9 @@ class TestSubtensor(utt.OptimizationTestMixin):
n = self.shared(data)
for idx in idxs:
# Should stay on the cpu.
idx_ = shared(np.asarray(idx))
t = n[idx_]
idx_np = np.asarray(idx)
idx_pt = shared(idx_np, shape=(1 if idx_np.shape[0] == 1 else None,))
t = n[idx_pt]
gn = pytensor.grad(pt_sum(exp(t)), n)
f = self.function([], [gn, gn.shape], op=AdvancedIncSubtensor1)
topo = f.maker.fgraph.toposort()
......@@ -1126,13 +1127,13 @@ class TestSubtensor(utt.OptimizationTestMixin):
assert np.allclose(gshape, data.shape)
def fct(t):
return pt_sum(t[idx_])
return pt_sum(t[idx_pt])
utt.verify_grad(fct, [data], mode=self.mode)
# Test the grad of the grad (e.i. AdvancedIncSubtensor1.grad)
def fct2(t):
return pytensor.grad(pt_sum(t[idx_]), t)
return pytensor.grad(pt_sum(t[idx_pt]), t)
utt.verify_grad(fct2, [data], mode=self.mode)
......@@ -1143,7 +1144,9 @@ class TestSubtensor(utt.OptimizationTestMixin):
ops = subtensor_ops
if idx is idxs[0]:
# TODO FIXME: This is a very poorly specified test.
f = self.function([], [gn.shape, n[idx_].shape], op=ops, N=0, N_fast=0)
f = self.function(
[], [gn.shape, n[idx_pt].shape], op=ops, N=0, N_fast=0
)
f()
def test_wrong_exception_regression(self):
......@@ -1231,10 +1234,7 @@ class TestSubtensor(utt.OptimizationTestMixin):
data_num_init = np.arange(data_size, dtype=self.dtype)
data_num_init = data_num_init.reshape(data_shape)
inc_shapes = [data_shape[i:] for i in range(0, len(data_shape) + 1)]
# Test broadcasting of y.
inc_shapes += [(1,) + inc_shapes[-1][1:]]
for inc_shape in inc_shapes:
inc_n_dims = len(inc_shape)
# We copy the numeric value to be 100% sure there is no
# risk of accidentally sharing it.
data_num = data_num_init.copy()
......@@ -1263,10 +1263,7 @@ class TestSubtensor(utt.OptimizationTestMixin):
replace=(not set_instead_of_inc),
)
idx_num = idx_num.astype("int64")
# Symbolic variable with increment value.
inc_var = TensorType(
shape=(None,) * inc_n_dims, dtype=self.dtype
)()
# Trick for the case where `inc_shape` is the same as
# `data_shape`: what we actually want is the first
# shape element to be equal to the number of rows to
......@@ -1275,6 +1272,15 @@ class TestSubtensor(utt.OptimizationTestMixin):
len(inc_shapes) == 0 or inc_shape[0] != 1
):
inc_shape = (n_to_inc,) + inc_shape[1:]
# Symbolic variable with increment value.
inc_var_static_shape = tuple(
1 if dim_length == 1 else None for dim_length in inc_shape
)
inc_var = TensorType(
shape=inc_var_static_shape, dtype=self.dtype
)()
# The param dtype is needed when inc_shape is empty.
# By default, it would return a float and rng.uniform
# with NumPy 1.10 will raise a Deprecation warning.
......@@ -1341,6 +1347,31 @@ class TestSubtensor(utt.OptimizationTestMixin):
# you enable the debug code above.
assert np.allclose(f_out, output_num), (params, f_out, output_num)
@pytest.mark.skipif(
    version.parse(np.__version__) < version.parse("2.0"),
    reason="Legacy C-implementation did not check for runtime broadcast",
)
@pytest.mark.parametrize("func", (advanced_inc_subtensor1, advanced_set_subtensor1))
def test_advanced1_inc_runtime_broadcast(self, func):
    """Inc/set-subtensor must raise on undeclared runtime broadcasting of y."""
    # `y` has no statically broadcastable dims, so any runtime length-1
    # dim that would need broadcasting must be rejected.
    y = matrix("y", dtype="float64", shape=(None, None))
    x = ptb.zeros((10, 5))
    idxs = np.repeat(np.arange(10), 2)
    out = func(x, y, idxs)
    f = function([y], out)

    f(np.ones((20, 5)))  # Fine: y matches the 20 indices exactly

    # Broadcasting y along the indexed (leading) dimension is rejected.
    with pytest.raises(
        ValueError,
        match="Runtime broadcasting not allowed. AdvancedIncSubtensor1 was asked",
    ):
        f(np.ones((1, 5)))
    # Broadcasting y along a trailing (buffer) dimension is rejected too.
    with pytest.raises(
        ValueError,
        match="Runtime broadcasting not allowed. AdvancedIncSubtensor1 was asked",
    ):
        f(np.ones((20, 1)))
def test_adv_constant_arg(self):
# Test case provided (and bug detected, gh-607) by John Salvatier
m = matrix("m")
......@@ -2398,7 +2429,11 @@ class TestInferShape(utt.InferShapeTester):
aivec_val = [2, 3]
self._compile_and_check(
[admat, bdmat],
[advanced_set_subtensor1(admat, bdmat, aivec_val)],
[
advanced_set_subtensor1(
admat, specify_broadcastable(bdmat, 0), aivec_val
)
],
[admat_val, [[1, 2, 3, 4]]],
AdvancedIncSubtensor1,
)
......@@ -2425,7 +2460,11 @@ class TestInferShape(utt.InferShapeTester):
aivec_val = [2, 3]
self._compile_and_check(
[adtens4, bdtens4],
[advanced_set_subtensor1(adtens4, bdtens4, aivec_val)],
[
advanced_set_subtensor1(
adtens4, specify_broadcastable(bdtens4, 0, 1, 2), aivec_val
)
],
[adtens4_val, [[[[1, 2, 3, 4, 5]]]]],
AdvancedIncSubtensor1,
warn=False,
......@@ -2476,7 +2515,11 @@ class TestInferShape(utt.InferShapeTester):
aivec_val = [2, 3]
self._compile_and_check(
[adtens4, bdtens4],
[advanced_set_subtensor1(adtens4, bdtens4, aivec_val)],
[
advanced_set_subtensor1(
adtens4, specify_broadcastable(bdtens4, 1, 2), aivec_val
)
],
[adtens4_val, [[[[1, 2, 3, 4, 5]]], [[[6, 7, 8, 9, 10]]]]],
AdvancedIncSubtensor1,
warn=False,
......@@ -3028,3 +3071,29 @@ class TestBenchmarks:
)
fn.vm.allow_gc = gc
benchmark(fn, x_values, idxs_values)
@pytest.mark.parametrize(
    "static_shape", (False, True), ids=lambda x: f"static_shape={x}"
)
@pytest.mark.parametrize("gc", (False, True), ids=lambda x: f"gc={x}")
@pytest.mark.parametrize("func", (inc_subtensor, set_subtensor))
def test_advanced_incsubtensor1(self, func, static_shape, gc, benchmark):
    """Benchmark AdvancedIncSubtensor1 with duplicated integer indices."""
    x = vector("x", shape=(85 if static_shape else None,))
    x_values = np.zeros((85,))
    # Increment into a zeros buffer rather than `x` itself.
    buffer = ptb.zeros_like(x)
    y_values = np.random.normal(size=(85 * 11,))
    # Each of the 85 positions is targeted 11 times (duplicate indices).
    idxs_values = np.arange(85).repeat(11)

    # With static shape and constant indices we know all idxs are valid
    # Reuse same buffer of zeros, to check we rather allocate twice than copy inside IncSubtensor
    out1 = func(buffer[idxs_values], y_values)
    out2 = func(buffer[idxs_values[::-1]], y_values)
    fn = pytensor.function(
        [x],
        # borrow=True so outputs can reuse internal storage across calls.
        [pytensor.Out(out1, borrow=True), pytensor.Out(out2, borrow=True)],
        on_unused_input="ignore",
        trust_input=True,
    )
    fn.vm.allow_gc = gc
    benchmark(fn, x_values)
Markdown 格式
0%
您将 0 人添加到了此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论