提交 65826e7e 作者: Ricardo Vieira 提交者: Luciano Paz

Handle invalid BroadcastTo shape in C backend

上级 24b67a86
...@@ -1643,6 +1643,11 @@ class BroadcastTo(COp): ...@@ -1643,6 +1643,11 @@ class BroadcastTo(COp):
shape, static_shape = at.infer_static_shape(shape) shape, static_shape = at.infer_static_shape(shape)
if len(shape) < a.ndim:
raise ValueError(
f"Broadcast target shape has {len(shape)} dims, which is shorter than input with {a.ndim} dims"
)
out = TensorType(dtype=a.type.dtype, shape=static_shape)() out = TensorType(dtype=a.type.dtype, shape=static_shape)()
# Attempt to prevent in-place operations on this view-based output # Attempt to prevent in-place operations on this view-based output
...@@ -1686,9 +1691,12 @@ class BroadcastTo(COp): ...@@ -1686,9 +1691,12 @@ class BroadcastTo(COp):
return [node.inputs[1:]] return [node.inputs[1:]]
def c_code(self, node, name, inputs, outputs, sub): def c_code(self, node, name, inputs, outputs, sub):
inp_dims = node.inputs[0].ndim
out_dims = node.outputs[0].ndim
new_dims = out_dims - inp_dims
(x, *shape) = inputs (x, *shape) = inputs
(out,) = outputs (out,) = outputs
ndims = len(shape)
fail = sub["fail"] fail = sub["fail"]
# TODO: Could just use `PyArray_Return`, no? # TODO: Could just use `PyArray_Return`, no?
...@@ -1701,20 +1709,34 @@ class BroadcastTo(COp): ...@@ -1701,20 +1709,34 @@ class BroadcastTo(COp):
src = ( src = (
""" """
npy_intp itershape[%(ndims)s] = {%(dims_array)s}; npy_intp itershape[%(out_dims)s] = {%(dims_array)s};
NpyIter *iter;
PyArrayObject *ops[1] = {%(x)s}; PyArrayObject *ops[1] = {%(x)s};
npy_uint32 flags = NPY_ITER_MULTI_INDEX | NPY_ITER_REFS_OK | NPY_ITER_ZEROSIZE_OK; npy_uint32 flags = NPY_ITER_MULTI_INDEX | NPY_ITER_REFS_OK | NPY_ITER_ZEROSIZE_OK;
npy_uint32 op_flags[1] = {NPY_ITER_READONLY}; npy_uint32 op_flags[1] = {NPY_ITER_READONLY};
PyArray_Descr *op_dtypes[1] = {NULL}; PyArray_Descr *op_dtypes[1] = {NULL};
int oa_ndim = %(ndims)s; int oa_ndim = %(out_dims)s;
int* op_axes[1] = {NULL}; int* op_axes[1] = {NULL};
npy_intp buffersize = 0; npy_intp buffersize = 0;
NpyIter *iter = NpyIter_AdvancedNew( for(int i = 0; i < %(inp_dims)s; i++)
1, ops, flags, NPY_CORDER, NPY_NO_CASTING, op_flags, op_dtypes, oa_ndim, op_axes, itershape, buffersize {
if ((PyArray_DIMS(%(x)s)[i] != 1) && (PyArray_DIMS(%(x)s)[i] != itershape[i + %(new_dims)s]))
{
PyErr_Format(PyExc_ValueError,
"Shape mismatch in broadcast_to: target shape[%%i] = %%lld is incompatible with input shape = %%lld.",
i,
(long long int) itershape[i + %(new_dims)s],
(long long int) PyArray_DIMS(%(x)s)[i]
); );
%(fail)s
}
}
iter = NpyIter_AdvancedNew(
1, ops, flags, NPY_CORDER, NPY_NO_CASTING, op_flags, op_dtypes, oa_ndim, op_axes, itershape, buffersize
);
%(out)s = NpyIter_GetIterView(iter, 0); %(out)s = NpyIter_GetIterView(iter, 0);
if(%(out)s == NULL){ if(%(out)s == NULL){
...@@ -1733,7 +1755,7 @@ class BroadcastTo(COp): ...@@ -1733,7 +1755,7 @@ class BroadcastTo(COp):
return src return src
def c_code_cache_version(self): def c_code_cache_version(self):
return (1,) return (2,)
broadcast_to_ = BroadcastTo() broadcast_to_ = BroadcastTo()
......
...@@ -1253,41 +1253,52 @@ class TestBroadcastTo(utt.InferShapeTester): ...@@ -1253,41 +1253,52 @@ class TestBroadcastTo(utt.InferShapeTester):
@pytest.mark.parametrize("linker", ["cvm", "py"]) @pytest.mark.parametrize("linker", ["cvm", "py"])
def test_perform(self, linker): def test_perform(self, linker):
a = pytensor.shared(5) a = pytensor.shared(np.full((3, 1, 1), 5))
s_0 = iscalar("s_0")
s_1 = iscalar("s_1") s_1 = iscalar("s_1")
shape = (s_1, 1) shape = (s_0, s_1, 1)
bcast_res = broadcast_to(a, shape) bcast_res = broadcast_to(a, shape)
assert bcast_res.broadcastable == (False, True) assert bcast_res.broadcastable == (False, False, True)
bcast_fn = pytensor.function( bcast_fn = pytensor.function(
[s_1], bcast_res, mode=Mode(optimizer=None, linker=linker) [s_0, s_1], bcast_res, mode=Mode(optimizer=None, linker=linker)
) )
bcast_fn.vm.allow_gc = False bcast_fn.vm.allow_gc = False
bcast_at = bcast_fn(4) bcast_at = bcast_fn(3, 4)
bcast_np = np.broadcast_to(5, (4, 1)) bcast_np = np.broadcast_to(5, (3, 4, 1))
assert np.array_equal(bcast_at, bcast_np) assert np.array_equal(bcast_at, bcast_np)
with pytest.raises(ValueError):
bcast_fn(5, 4)
if linker != "py":
bcast_var = bcast_fn.maker.fgraph.outputs[0].owner.inputs[0] bcast_var = bcast_fn.maker.fgraph.outputs[0].owner.inputs[0]
bcast_in = bcast_fn.vm.storage_map[a] bcast_in = bcast_fn.vm.storage_map[a]
bcast_out = bcast_fn.vm.storage_map[bcast_var] bcast_out = bcast_fn.vm.storage_map[bcast_var]
if linker != "py":
assert np.shares_memory(bcast_out[0], bcast_in[0]) assert np.shares_memory(bcast_out[0], bcast_in[0])
def test_make_node_error_handling(self):
    """Check that a too-short target shape is rejected with a ``ValueError``.

    A 2-d input is paired with a 1-d target shape; only the graph-building
    call to ``broadcast_to`` is exercised here, no function is compiled.
    """
    too_short_shape = (5,)
    expected_msg = "Broadcast target shape has 1 dims, which is shorter than input with 2 dims"
    with pytest.raises(ValueError, match=expected_msg):
        broadcast_to(at.zeros((3, 4)), too_short_shape)
@pytest.mark.skipif( @pytest.mark.skipif(
not config.cxx, reason="G++ not available, so we need to skip this test." not config.cxx, reason="G++ not available, so we need to skip this test."
) )
def test_memory_leak(self): @pytest.mark.parametrize("valid", (True, False))
def test_memory_leak(self, valid):
import gc import gc
import tracemalloc import tracemalloc
from pytensor.link.c.cvm import CVM from pytensor.link.c.cvm import CVM
n = 100_000 n = 100_000
x = pytensor.shared(np.ones(n, dtype=np.float64)) x = pytensor.shared(np.ones((1, n), dtype=np.float64))
y = broadcast_to(x, (5, n)) y = broadcast_to(x, (5, n))
f = pytensor.function([], y, mode=Mode(optimizer=None, linker="cvm")) f = pytensor.function([], y, mode=Mode(optimizer=None, linker="cvm"))
...@@ -1303,8 +1314,17 @@ class TestBroadcastTo(utt.InferShapeTester): ...@@ -1303,8 +1314,17 @@ class TestBroadcastTo(utt.InferShapeTester):
blocks_last = None blocks_last = None
block_diffs = [] block_diffs = []
for i in range(1, 50): for i in range(1, 50):
x.set_value(np.ones(n)) if valid:
x.set_value(np.ones((1, n)))
_ = f() _ = f()
else:
x.set_value(np.ones((2, n)))
try:
_ = f()
except ValueError:
pass
else:
raise RuntimeError("Should have failed")
_ = gc.collect() _ = gc.collect()
blocks_i, _ = tracemalloc.get_traced_memory() blocks_i, _ = tracemalloc.get_traced_memory()
if blocks_last is not None: if blocks_last is not None:
...@@ -1313,7 +1333,7 @@ class TestBroadcastTo(utt.InferShapeTester): ...@@ -1313,7 +1333,7 @@ class TestBroadcastTo(utt.InferShapeTester):
blocks_last = blocks_i blocks_last = blocks_i
tracemalloc.stop() tracemalloc.stop()
assert np.allclose(np.mean(block_diffs), 0) assert np.all(np.array(block_diffs) <= (0 + 1e-8))
@pytest.mark.parametrize( @pytest.mark.parametrize(
"fn,input_dims", "fn,input_dims",
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论