Commit 528b8d4b authored by ricardoV94, committed by Ricardo Vieira

Specialized C-impl for vector AdvancedIncSubtensor1

Also add checks for runtime broadcast

Parent: 4311f893
...@@ -67,6 +67,9 @@ def jax_funcify_IncSubtensor(op, node, **kwargs): ...@@ -67,6 +67,9 @@ def jax_funcify_IncSubtensor(op, node, **kwargs):
if len(indices) == 1: if len(indices) == 1:
indices = indices[0] indices = indices[0]
if isinstance(op, AdvancedIncSubtensor1):
op._check_runtime_broadcasting(node, x, y, indices)
return jax_fn(x, indices, y) return jax_fn(x, indices, y)
return incsubtensor return incsubtensor
......
...@@ -287,11 +287,11 @@ def numba_funcify_AdvancedIncSubtensor1(op, node, **kwargs): ...@@ -287,11 +287,11 @@ def numba_funcify_AdvancedIncSubtensor1(op, node, **kwargs):
inplace = op.inplace inplace = op.inplace
set_instead_of_inc = op.set_instead_of_inc set_instead_of_inc = op.set_instead_of_inc
x, vals, idxs = node.inputs x, vals, idxs = node.inputs
# TODO: Add explicit expand_dims in make_node so we don't need to worry about this here broadcast_with_index = vals.type.ndim < x.type.ndim or vals.type.broadcastable[0]
broadcast = vals.type.ndim < x.type.ndim or vals.type.broadcastable[0] # TODO: Add runtime_broadcast check
if set_instead_of_inc: if set_instead_of_inc:
if broadcast: if broadcast_with_index:
@numba_njit(boundscheck=True) @numba_njit(boundscheck=True)
def advancedincsubtensor1_inplace(x, val, idxs): def advancedincsubtensor1_inplace(x, val, idxs):
...@@ -318,7 +318,7 @@ def numba_funcify_AdvancedIncSubtensor1(op, node, **kwargs): ...@@ -318,7 +318,7 @@ def numba_funcify_AdvancedIncSubtensor1(op, node, **kwargs):
x[idx] = val x[idx] = val
return x return x
else: else:
if broadcast: if broadcast_with_index:
@numba_njit(boundscheck=True) @numba_njit(boundscheck=True)
def advancedincsubtensor1_inplace(x, val, idxs): def advancedincsubtensor1_inplace(x, val, idxs):
......
...@@ -109,6 +109,8 @@ def pytorch_funcify_AdvancedIncSubtensor(op, node, **kwargs): ...@@ -109,6 +109,8 @@ def pytorch_funcify_AdvancedIncSubtensor(op, node, **kwargs):
def adv_set_subtensor(x, y, *indices): def adv_set_subtensor(x, y, *indices):
check_negative_steps(indices) check_negative_steps(indices)
if isinstance(op, AdvancedIncSubtensor1):
op._check_runtime_broadcasting(node, x, y, indices)
if not inplace: if not inplace:
x = x.clone() x = x.clone()
x[indices] = y.type_as(x) x[indices] = y.type_as(x)
...@@ -120,6 +122,8 @@ def pytorch_funcify_AdvancedIncSubtensor(op, node, **kwargs): ...@@ -120,6 +122,8 @@ def pytorch_funcify_AdvancedIncSubtensor(op, node, **kwargs):
def adv_inc_subtensor_no_duplicates(x, y, *indices): def adv_inc_subtensor_no_duplicates(x, y, *indices):
check_negative_steps(indices) check_negative_steps(indices)
if isinstance(op, AdvancedIncSubtensor1):
op._check_runtime_broadcasting(node, x, y, indices)
if not inplace: if not inplace:
x = x.clone() x = x.clone()
x[indices] += y.type_as(x) x[indices] += y.type_as(x)
......
...@@ -2262,6 +2262,12 @@ class AdvancedIncSubtensor1(COp): ...@@ -2262,6 +2262,12 @@ class AdvancedIncSubtensor1(COp):
check_input = False check_input = False
params_type = ParamsType(inplace=ps.bool, set_instead_of_inc=ps.bool) params_type = ParamsType(inplace=ps.bool, set_instead_of_inc=ps.bool)
_runtime_broadcast_error_msg = (
"Runtime broadcasting not allowed. "
"AdvancedIncSubtensor1 was asked to broadcast the second input (y) along a dimension that was not marked as broadcastable. "
"If broadcasting was intended, use `specify_broadcastable` on the relevant dimension(s)."
)
def __init__(self, inplace=False, set_instead_of_inc=False): def __init__(self, inplace=False, set_instead_of_inc=False):
self.inplace = bool(inplace) self.inplace = bool(inplace)
self.set_instead_of_inc = bool(set_instead_of_inc) self.set_instead_of_inc = bool(set_instead_of_inc)
...@@ -2333,6 +2339,9 @@ class AdvancedIncSubtensor1(COp): ...@@ -2333,6 +2339,9 @@ class AdvancedIncSubtensor1(COp):
NPY_ARRAY_ENSURECOPY, NULL)""" NPY_ARRAY_ENSURECOPY, NULL)"""
def c_support_code(self, **kwargs): def c_support_code(self, **kwargs):
if numpy_version < "1.8.0" or using_numpy_2:
return None
types = [ types = [
"npy_" + t "npy_" + t
for t in [ for t in [
...@@ -2523,15 +2532,117 @@ class AdvancedIncSubtensor1(COp): ...@@ -2523,15 +2532,117 @@ class AdvancedIncSubtensor1(COp):
return code return code
def c_code(self, node, name, input_names, output_names, sub): def c_code(self, node, name, input_names, output_names, sub):
if numpy_version < "1.8.0" or using_numpy_2:
raise NotImplementedError
x, y, idx = input_names x, y, idx = input_names
out = output_names[0] [out] = output_names
copy_of_x = self.copy_of_x(x) copy_of_x = self.copy_of_x(x)
params = sub["params"] params = sub["params"]
fail = sub["fail"] fail = sub["fail"]
x_, y_, idx_ = node.inputs
y_cdtype = y_.type.dtype_specs()[1]
idx_cdtype = idx_.type.dtype_specs()[1]
out_cdtype = node.outputs[0].type.dtype_specs()[1]
y_bcast = y_.type.broadcastable != idx_.type.broadcastable
if (
x_.type.ndim == 1
and y_.type.ndim == 1
and not y_bcast
and x_.type.dtype not in complex_dtypes
and y_.type.dtype not in complex_dtypes
):
# Simple implementation for vector x, y cases
idx_may_be_neg = not (isinstance(idx_, Constant) and idx_.data.min() >= 0)
idx_may_be_invalid = AdvancedSubtensor1._idx_may_be_invalid(x_, idx_)
shape0 = x_.type.shape[0]
# This is used to make sure that when we trust the indices to be valid
# we are not fooled by a wrong static shape
# We mention x to the user in error messages but we work (and make checks) on out,
# which should be x or a copy of it
unexpected_shape0 = (
f"PyArray_SHAPE({out})[0] != {shape0}" if shape0 is not None else "0"
)
op = "=" if self.set_instead_of_inc else "+="
code = f"""
if ({params}->inplace)
{{
if ({x} != {out})
{{
Py_XDECREF({out});
Py_INCREF({x});
{out} = {x};
}}
}}
else
{{
Py_XDECREF({out});
{out} = {copy_of_x};
if (!{out}) {{
// Exception already set
{fail}
}}
}}
if (PyArray_NDIM({out}) != 1) {{
PyErr_Format(PyExc_ValueError, "AdvancedIncSubtensor1: first input (x) ndim should be 1, got %d", PyArray_NDIM({out}));
{fail}
}}
if ({unexpected_shape0}) {{
PyErr_Format(PyExc_ValueError, "AdvancedIncSubtensor1: first input (x) shape should be {shape0}, got %d", PyArray_SHAPE({out})[0]);
{fail}
}}
if (PyArray_NDIM({idx}) != 1) {{
PyErr_Format(PyExc_ValueError, "AdvancedIncSubtensor1: indices ndim should be 1, got %d", PyArray_NDIM({idx}));
{fail}
}}
if (PyArray_NDIM({y}) != 1) {{
PyErr_Format(PyExc_ValueError, "AdvancedIncSubtensor1: second input (y) ndim should be 1, got %d", PyArray_NDIM({y}));
{fail}
}}
if (PyArray_SHAPE({y})[0] != PyArray_SHAPE({idx})[0]) {{
if ((PyArray_NDIM({y}) == 1) && (PyArray_SHAPE({y})[0] == 1)){{
PyErr_Format(PyExc_ValueError, "{self._runtime_broadcast_error_msg}");
}} else {{
PyErr_Format(PyExc_ValueError,
"AdvancedIncSubtensor1: Shapes of second input (y) and indices do not match: %d, %d",
PyArray_SHAPE({y})[0], PyArray_SHAPE({idx})[0]);
}}
{fail}
}}
{{
npy_intp out_shape0 = PyArray_SHAPE({out})[0];
{out_cdtype}* out_data = ({out_cdtype}*)PyArray_DATA({out});
{y_cdtype}* y_data = ({y_cdtype}*)PyArray_DATA({y});
{idx_cdtype}* idx_data = ({idx_cdtype}*)PyArray_DATA({idx});
npy_intp n = PyArray_SHAPE({idx})[0];
npy_intp out_jump = PyArray_STRIDES({out})[0] / PyArray_ITEMSIZE({out});
npy_intp y_jump = PyArray_STRIDES({y})[0] / PyArray_ITEMSIZE({y});
npy_intp idx_jump = PyArray_STRIDES({idx})[0] / PyArray_ITEMSIZE({idx});
for(int i = 0; i < n; i++){{
{idx_cdtype} idx = idx_data[i * idx_jump];
if ({int(idx_may_be_neg)}){{
if (idx < 0) {{
idx += out_shape0;
}}
}}
if ({int(idx_may_be_invalid)}){{
if ((idx < 0) || (idx >= out_shape0)) {{
PyErr_Format(PyExc_IndexError,"index %d out of bounds for array with shape %d", idx_data[i * idx_jump], out_shape0);
{fail}
}}
}}
out_data[idx * out_jump] {op} y_data[i * y_jump];
}}
}}
"""
return code
if numpy_version < "1.8.0" or using_numpy_2:
raise NotImplementedError
return f""" return f"""
PyObject* rval = NULL; PyObject* rval = NULL;
if ({params}->inplace) if ({params}->inplace)
...@@ -2559,14 +2670,37 @@ class AdvancedIncSubtensor1(COp): ...@@ -2559,14 +2670,37 @@ class AdvancedIncSubtensor1(COp):
""" """
def c_code_cache_version(self): def c_code_cache_version(self):
return (8,) return (9,)
def _check_runtime_broadcasting(
    self, node: Apply, x: np.ndarray, y: np.ndarray, idx: np.ndarray
) -> None:
    """Raise if `y` would be broadcast at runtime along a dimension not
    statically marked as broadcastable.

    Two situations are rejected (both raise ``ValueError`` with
    ``self._runtime_broadcast_error_msg``):

    1. "broadcast with index": `y`'s leading dimension has runtime length 1
       but differs from the number of indices, while the corresponding
       symbolic dimension was not declared broadcastable.
    2. "broadcast with buffer": some trailing dimension of `y` has runtime
       length 1 and differs from `x`'s matching dimension, again without
       the symbolic dimension being declared broadcastable.

    Scalar `y` (``ndim == 0``) is always accepted — no check applies.

    Parameters
    ----------
    node
        The ``Apply`` node; ``node.inputs[1]`` provides `y`'s static
        broadcastable pattern.
    x, y, idx
        The runtime value of the buffer, the increment values, and the
        integer indices, respectively.
    """
    if y.ndim > 0:
        # Static broadcastable flags of the symbolic `y` input.
        y_pt_bcast = node.inputs[1].broadcastable  # type: ignore
        if not y_pt_bcast[0] and y.shape[0] == 1 and y.shape[0] != idx.shape[0]:
            # Attempting to broadcast with index
            raise ValueError(self._runtime_broadcast_error_msg)
        if any(
            not y_bcast and y_dim == 1 and y_dim != x_dim
            for y_bcast, y_dim, x_dim in zip(
                # Align from the trailing dimensions, NumPy-broadcast style;
                # strict=False because y may have fewer dims than x.
                reversed(y_pt_bcast),
                reversed(y.shape),
                reversed(x.shape),
                strict=False,
            )
        ):
            # Attempting to broadcast with buffer
            raise ValueError(self._runtime_broadcast_error_msg)
def perform(self, node, inputs, output_storage):
x, y, idx = inputs
def perform(self, node, inp, out_):
x, y, idx = inp
(out,) = out_
if not self.inplace: if not self.inplace:
x = x.copy() x = x.copy()
self._check_runtime_broadcasting(node, x, y, idx)
if self.set_instead_of_inc: if self.set_instead_of_inc:
x[idx] = y x[idx] = y
else: else:
...@@ -2574,7 +2708,7 @@ class AdvancedIncSubtensor1(COp): ...@@ -2574,7 +2708,7 @@ class AdvancedIncSubtensor1(COp):
# many times: it does it only once. # many times: it does it only once.
np.add.at(x, idx, y) np.add.at(x, idx, y)
out[0] = x output_storage[0][0] = x
def infer_shape(self, fgraph, node, ishapes): def infer_shape(self, fgraph, node, ishapes):
x, y, ilist = ishapes x, y, ilist = ishapes
......
...@@ -5,6 +5,7 @@ from io import StringIO ...@@ -5,6 +5,7 @@ from io import StringIO
import numpy as np import numpy as np
import pytest import pytest
from numpy.testing import assert_array_equal from numpy.testing import assert_array_equal
from packaging import version
import pytensor import pytensor
import pytensor.scalar as scal import pytensor.scalar as scal
...@@ -26,7 +27,7 @@ from pytensor.tensor.blockwise import Blockwise ...@@ -26,7 +27,7 @@ from pytensor.tensor.blockwise import Blockwise
from pytensor.tensor.elemwise import DimShuffle from pytensor.tensor.elemwise import DimShuffle
from pytensor.tensor.math import exp, isinf, lt, switch from pytensor.tensor.math import exp, isinf, lt, switch
from pytensor.tensor.math import sum as pt_sum from pytensor.tensor.math import sum as pt_sum
from pytensor.tensor.shape import specify_shape from pytensor.tensor.shape import specify_broadcastable, specify_shape
from pytensor.tensor.subtensor import ( from pytensor.tensor.subtensor import (
AdvancedIncSubtensor, AdvancedIncSubtensor,
AdvancedIncSubtensor1, AdvancedIncSubtensor1,
...@@ -1101,9 +1102,9 @@ class TestSubtensor(utt.OptimizationTestMixin): ...@@ -1101,9 +1102,9 @@ class TestSubtensor(utt.OptimizationTestMixin):
n = self.shared(data) n = self.shared(data)
for idx in idxs: for idx in idxs:
# Should stay on the cpu. idx_np = np.asarray(idx)
idx_ = shared(np.asarray(idx)) idx_pt = shared(idx_np, shape=(1 if idx_np.shape[0] == 1 else None,))
t = n[idx_] t = n[idx_pt]
gn = pytensor.grad(pt_sum(exp(t)), n) gn = pytensor.grad(pt_sum(exp(t)), n)
f = self.function([], [gn, gn.shape], op=AdvancedIncSubtensor1) f = self.function([], [gn, gn.shape], op=AdvancedIncSubtensor1)
topo = f.maker.fgraph.toposort() topo = f.maker.fgraph.toposort()
...@@ -1126,13 +1127,13 @@ class TestSubtensor(utt.OptimizationTestMixin): ...@@ -1126,13 +1127,13 @@ class TestSubtensor(utt.OptimizationTestMixin):
assert np.allclose(gshape, data.shape) assert np.allclose(gshape, data.shape)
def fct(t): def fct(t):
return pt_sum(t[idx_]) return pt_sum(t[idx_pt])
utt.verify_grad(fct, [data], mode=self.mode) utt.verify_grad(fct, [data], mode=self.mode)
# Test the grad of the grad (e.i. AdvancedIncSubtensor1.grad) # Test the grad of the grad (e.i. AdvancedIncSubtensor1.grad)
def fct2(t): def fct2(t):
return pytensor.grad(pt_sum(t[idx_]), t) return pytensor.grad(pt_sum(t[idx_pt]), t)
utt.verify_grad(fct2, [data], mode=self.mode) utt.verify_grad(fct2, [data], mode=self.mode)
...@@ -1143,7 +1144,9 @@ class TestSubtensor(utt.OptimizationTestMixin): ...@@ -1143,7 +1144,9 @@ class TestSubtensor(utt.OptimizationTestMixin):
ops = subtensor_ops ops = subtensor_ops
if idx is idxs[0]: if idx is idxs[0]:
# TODO FIXME: This is a very poorly specified test. # TODO FIXME: This is a very poorly specified test.
f = self.function([], [gn.shape, n[idx_].shape], op=ops, N=0, N_fast=0) f = self.function(
[], [gn.shape, n[idx_pt].shape], op=ops, N=0, N_fast=0
)
f() f()
def test_wrong_exception_regression(self): def test_wrong_exception_regression(self):
...@@ -1231,10 +1234,7 @@ class TestSubtensor(utt.OptimizationTestMixin): ...@@ -1231,10 +1234,7 @@ class TestSubtensor(utt.OptimizationTestMixin):
data_num_init = np.arange(data_size, dtype=self.dtype) data_num_init = np.arange(data_size, dtype=self.dtype)
data_num_init = data_num_init.reshape(data_shape) data_num_init = data_num_init.reshape(data_shape)
inc_shapes = [data_shape[i:] for i in range(0, len(data_shape) + 1)] inc_shapes = [data_shape[i:] for i in range(0, len(data_shape) + 1)]
# Test broadcasting of y.
inc_shapes += [(1,) + inc_shapes[-1][1:]]
for inc_shape in inc_shapes: for inc_shape in inc_shapes:
inc_n_dims = len(inc_shape)
# We copy the numeric value to be 100% sure there is no # We copy the numeric value to be 100% sure there is no
# risk of accidentally sharing it. # risk of accidentally sharing it.
data_num = data_num_init.copy() data_num = data_num_init.copy()
...@@ -1263,10 +1263,7 @@ class TestSubtensor(utt.OptimizationTestMixin): ...@@ -1263,10 +1263,7 @@ class TestSubtensor(utt.OptimizationTestMixin):
replace=(not set_instead_of_inc), replace=(not set_instead_of_inc),
) )
idx_num = idx_num.astype("int64") idx_num = idx_num.astype("int64")
# Symbolic variable with increment value.
inc_var = TensorType(
shape=(None,) * inc_n_dims, dtype=self.dtype
)()
# Trick for the case where `inc_shape` is the same as # Trick for the case where `inc_shape` is the same as
# `data_shape`: what we actually want is the first # `data_shape`: what we actually want is the first
# shape element to be equal to the number of rows to # shape element to be equal to the number of rows to
...@@ -1275,6 +1272,15 @@ class TestSubtensor(utt.OptimizationTestMixin): ...@@ -1275,6 +1272,15 @@ class TestSubtensor(utt.OptimizationTestMixin):
len(inc_shapes) == 0 or inc_shape[0] != 1 len(inc_shapes) == 0 or inc_shape[0] != 1
): ):
inc_shape = (n_to_inc,) + inc_shape[1:] inc_shape = (n_to_inc,) + inc_shape[1:]
# Symbolic variable with increment value.
inc_var_static_shape = tuple(
1 if dim_length == 1 else None for dim_length in inc_shape
)
inc_var = TensorType(
shape=inc_var_static_shape, dtype=self.dtype
)()
# The param dtype is needed when inc_shape is empty. # The param dtype is needed when inc_shape is empty.
# By default, it would return a float and rng.uniform # By default, it would return a float and rng.uniform
# with NumPy 1.10 will raise a Deprecation warning. # with NumPy 1.10 will raise a Deprecation warning.
...@@ -1341,6 +1347,31 @@ class TestSubtensor(utt.OptimizationTestMixin): ...@@ -1341,6 +1347,31 @@ class TestSubtensor(utt.OptimizationTestMixin):
# you enable the debug code above. # you enable the debug code above.
assert np.allclose(f_out, output_num), (params, f_out, output_num) assert np.allclose(f_out, output_num), (params, f_out, output_num)
@pytest.mark.skipif(
    version.parse(np.__version__) < version.parse("2.0"),
    reason="Legacy C-implementation did not check for runtime broadcast",
)
@pytest.mark.parametrize("func", (advanced_inc_subtensor1, advanced_set_subtensor1))
def test_advanced1_inc_runtime_broadcast(self, func):
    """Runtime broadcasting of `y` must raise for both inc and set variants.

    `y` is declared with fully unknown shape ``(None, None)``, so a runtime
    dimension of length 1 that does not match the target shape counts as
    disallowed (implicit) broadcasting rather than a declared one.
    """
    y = matrix("y", dtype="float64", shape=(None, None))
    x = ptb.zeros((10, 5))
    # 20 indices into the 10 rows of x (each row indexed twice).
    idxs = np.repeat(np.arange(10), 2)
    out = func(x, y, idxs)
    f = function([y], out)
    f(np.ones((20, 5)))  # Fine: y's shape matches (n_indices, n_cols) exactly
    # Leading dim 1 would broadcast against the 20 indices — rejected.
    with pytest.raises(
        ValueError,
        match="Runtime broadcasting not allowed. AdvancedIncSubtensor1 was asked",
    ):
        f(np.ones((1, 5)))
    # Trailing dim 1 would broadcast against x's 5 columns — rejected.
    with pytest.raises(
        ValueError,
        match="Runtime broadcasting not allowed. AdvancedIncSubtensor1 was asked",
    ):
        f(np.ones((20, 1)))
def test_adv_constant_arg(self): def test_adv_constant_arg(self):
# Test case provided (and bug detected, gh-607) by John Salvatier # Test case provided (and bug detected, gh-607) by John Salvatier
m = matrix("m") m = matrix("m")
...@@ -2398,7 +2429,11 @@ class TestInferShape(utt.InferShapeTester): ...@@ -2398,7 +2429,11 @@ class TestInferShape(utt.InferShapeTester):
aivec_val = [2, 3] aivec_val = [2, 3]
self._compile_and_check( self._compile_and_check(
[admat, bdmat], [admat, bdmat],
[advanced_set_subtensor1(admat, bdmat, aivec_val)], [
advanced_set_subtensor1(
admat, specify_broadcastable(bdmat, 0), aivec_val
)
],
[admat_val, [[1, 2, 3, 4]]], [admat_val, [[1, 2, 3, 4]]],
AdvancedIncSubtensor1, AdvancedIncSubtensor1,
) )
...@@ -2425,7 +2460,11 @@ class TestInferShape(utt.InferShapeTester): ...@@ -2425,7 +2460,11 @@ class TestInferShape(utt.InferShapeTester):
aivec_val = [2, 3] aivec_val = [2, 3]
self._compile_and_check( self._compile_and_check(
[adtens4, bdtens4], [adtens4, bdtens4],
[advanced_set_subtensor1(adtens4, bdtens4, aivec_val)], [
advanced_set_subtensor1(
adtens4, specify_broadcastable(bdtens4, 0, 1, 2), aivec_val
)
],
[adtens4_val, [[[[1, 2, 3, 4, 5]]]]], [adtens4_val, [[[[1, 2, 3, 4, 5]]]]],
AdvancedIncSubtensor1, AdvancedIncSubtensor1,
warn=False, warn=False,
...@@ -2476,7 +2515,11 @@ class TestInferShape(utt.InferShapeTester): ...@@ -2476,7 +2515,11 @@ class TestInferShape(utt.InferShapeTester):
aivec_val = [2, 3] aivec_val = [2, 3]
self._compile_and_check( self._compile_and_check(
[adtens4, bdtens4], [adtens4, bdtens4],
[advanced_set_subtensor1(adtens4, bdtens4, aivec_val)], [
advanced_set_subtensor1(
adtens4, specify_broadcastable(bdtens4, 1, 2), aivec_val
)
],
[adtens4_val, [[[[1, 2, 3, 4, 5]]], [[[6, 7, 8, 9, 10]]]]], [adtens4_val, [[[[1, 2, 3, 4, 5]]], [[[6, 7, 8, 9, 10]]]]],
AdvancedIncSubtensor1, AdvancedIncSubtensor1,
warn=False, warn=False,
...@@ -3028,3 +3071,29 @@ class TestBenchmarks: ...@@ -3028,3 +3071,29 @@ class TestBenchmarks:
) )
fn.vm.allow_gc = gc fn.vm.allow_gc = gc
benchmark(fn, x_values, idxs_values) benchmark(fn, x_values, idxs_values)
@pytest.mark.parametrize(
    "static_shape", (False, True), ids=lambda x: f"static_shape={x}"
)
@pytest.mark.parametrize("gc", (False, True), ids=lambda x: f"gc={x}")
@pytest.mark.parametrize("func", (inc_subtensor, set_subtensor))
def test_advanced_incsubtensor1(self, func, static_shape, gc, benchmark):
    """Benchmark AdvancedIncSubtensor1 (inc and set) on a vector buffer.

    Covers static vs. unknown shape and VM garbage collection on/off.
    Uses the pytest-benchmark ``benchmark`` fixture to time the compiled
    function.
    """
    x = vector("x", shape=(85 if static_shape else None,))
    x_values = np.zeros((85,))
    buffer = ptb.zeros_like(x)
    # 85 * 11 increments: each of the 85 positions is indexed 11 times.
    y_values = np.random.normal(size=(85 * 11,))
    idxs_values = np.arange(85).repeat(11)
    # With static shape and constant indices we know all idxs are valid
    # Reuse same buffer of zeros, to check we rather allocate twice than copy inside IncSubtensor
    out1 = func(buffer[idxs_values], y_values)
    out2 = func(buffer[idxs_values[::-1]], y_values)
    fn = pytensor.function(
        [x],
        # borrow=True lets outputs alias internal storage (no defensive copy).
        [pytensor.Out(out1, borrow=True), pytensor.Out(out2, borrow=True)],
        on_unused_input="ignore",
        trust_input=True,
    )
    fn.vm.allow_gc = gc
    benchmark(fn, x_values)
Markdown format supported
0%
You are adding 0 people to this discussion. Please proceed with caution.
Please finish editing this comment first!
Register or sign in to post a comment