提交 34eaaa53 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Forbid runtime broadcasting in Alloc

上级 69418200
......@@ -41,9 +41,10 @@ def jax_funcify_AllocEmpty(op, **kwargs):
@jax_funcify.register(Alloc)
def jax_funcify_Alloc(op, node, **kwargs):
    """Build the JAX implementation of ``Alloc``.

    The returned callable broadcasts the value to the requested shape and then
    delegates to ``Alloc._check_runtime_broadcast`` so that broadcasting a
    runtime dimension that was not statically marked as broadcastable raises
    instead of silently succeeding.
    """

    def alloc(x, *shape):
        broadcasted = jnp.broadcast_to(x, shape)
        Alloc._check_runtime_broadcast(node, jnp.asarray(x), broadcasted.shape)
        return broadcasted

    return alloc
......
......@@ -77,16 +77,24 @@ def numba_funcify_Alloc(op, node, **kwargs):
" " * 4,
)
check_runtime_broadcast = []
for i, val_static_dim in enumerate(node.inputs[0].type.shape[::-1]):
if val_static_dim is None:
check_runtime_broadcast.append(
f'if val.shape[{-i - 1}] == 1 and scalar_shape[{-i - 1}] != 1: raise ValueError("{Alloc._runtime_broadcast_error_msg}")'
)
check_runtime_broadcast_src = indent("\n".join(check_runtime_broadcast), " " * 4)
alloc_def_src = f"""
def alloc(val, {", ".join(shape_var_names)}):
val_np = np.asarray(val)
{shapes_to_items_src}
scalar_shape = {create_tuple_string(shape_var_item_names)}
{check_runtime_broadcast_src}
res = np.empty(scalar_shape, dtype=val_np.dtype)
res[...] = val_np
return res
"""
alloc_fn = compile_function_src(alloc_def_src, "alloc", {**globals(), **global_env})
return numba_basic.numba_njit(alloc_fn)
......
......@@ -1431,6 +1431,12 @@ class Alloc(COp):
# Alloc carries no Op-level parameters that distinguish instances.
__props__ = ()

# Shared error text raised when the input value would be broadcast at runtime
# along a dimension that was not statically marked as broadcastable; used by
# `perform`, `c_code`, and the JAX/Numba backend implementations.
_runtime_broadcast_error_msg = (
    "Runtime broadcasting not allowed. "
    "The output of Alloc requires broadcasting a dimension of the input value, which was not marked as broadcastable. "
    "If broadcasting was intended, use `specify_broadcastable` on the relevant input."
)
def make_node(self, value, *shape):
value = as_tensor_variable(value)
shape, static_shape = infer_static_shape(shape)
......@@ -1468,10 +1474,21 @@ class Alloc(COp):
otype = TensorType(dtype=value.dtype, shape=combined_static_shape)
return Apply(self, [value] + shape, [otype()])
@staticmethod
def _check_runtime_broadcast(node, value, shape):
    """Raise ``ValueError`` if filling `shape` from `value` would broadcast
    a runtime dimension of `value` that is not statically known to be 1.

    Dimensions are compared right-to-left so the value's trailing dims line
    up with the output's trailing dims, matching NumPy broadcast alignment.
    """
    static_dims = node.inputs[0].type.shape
    for static_dim, val_dim, target_dim in zip(
        reversed(static_dims), reversed(value.shape), reversed(shape)
    ):
        # Only dims of unknown static length (None) can broadcast "by
        # accident"; statically-1 dims were declared broadcastable.
        if static_dim is None and val_dim == 1 and target_dim != 1:
            raise ValueError(Alloc._runtime_broadcast_error_msg)
def perform(self, node, inputs, out_):
(out,) = out_
v = inputs[0]
sh = tuple([int(i) for i in inputs[1:]])
self._check_runtime_broadcast(node, v, sh)
if out[0] is None or out[0].shape != sh:
if v.size == 1 and v.item() == 0:
out[0] = np.zeros(sh, dtype=v.dtype)
......@@ -1484,12 +1501,19 @@ class Alloc(COp):
def c_code(self, node, name, inp, out, sub):
vv = inp[0]
ndim = len(inp[1:])
(zz,) = out
fail = sub["fail"]
v_static_shape = node.inputs[0].type.shape
o_static_shape = node.outputs[0].type.shape
v_ndim = len(v_static_shape)
o_ndim = len(o_static_shape)
assert o_ndim == len(inp[1:])
# Declare variables
code = f"""
npy_intp shape[{ndim}];
npy_intp shape[{o_ndim}];
int need_new_out;
"""
# Initialize shape
......@@ -1498,15 +1522,26 @@ class Alloc(COp):
shape[{i}] = ((dtype_{shp_i}*) PyArray_DATA({shp_i}))[0];
"""
# Add checks for runtime broadcasting
for i, v_static_dim in enumerate(v_static_shape[::-1]):
if v_static_dim is None:
code += f"""
if (PyArray_DIMS({vv})[{v_ndim - i - 1}] == 1 && shape[{o_ndim - i - 1}] != 1)
{{
PyErr_Format(PyExc_ValueError, "{self._runtime_broadcast_error_msg}");
{fail}
}}
"""
code += f"""
int need_new_out = (NULL == {zz});
for (int i = 0; i < {ndim}; i++)
need_new_out = (NULL == {zz});
for (int i = 0; i < {o_ndim}; i++)
need_new_out = (need_new_out || (PyArray_DIMS({zz})[i] != shape[i]));
if (need_new_out)
{{
Py_XDECREF({zz});
{zz} = (PyArrayObject*) PyArray_SimpleNew({ndim}, shape, PyArray_TYPE({vv}));
{zz} = (PyArrayObject*) PyArray_SimpleNew({o_ndim}, shape, PyArray_TYPE({vv}));
if (!{zz})
{{
PyErr_SetString(PyExc_MemoryError, "alloc failed");
......@@ -1522,7 +1557,7 @@ class Alloc(COp):
return code
def c_code_cache_version(self):
    # Version 4: runtime-broadcast checks were added to the C implementation,
    # so previously cached compiled code must be invalidated.
    return (4,)
def infer_shape(self, fgraph, node, input_shapes):
    # The output shape of Alloc is given directly by its shape inputs —
    # every input after the first (the value being allocated).
    shape_inputs = node.inputs[1:]
    return [shape_inputs]
......
import numpy as np
import pytest
from pytensor.compile import get_mode
jax = pytest.importorskip("jax")
import jax.errors
......@@ -12,6 +14,7 @@ from pytensor.graph.fg import FunctionGraph
from pytensor.graph.op import get_test_value
from pytensor.tensor.type import iscalar, matrix, scalar, vector
from tests.link.jax.test_basic import compare_jax_and_py
from tests.tensor.test_basic import TestAlloc
def test_jax_Alloc():
......@@ -50,6 +53,10 @@ def test_jax_Alloc():
compare_jax_and_py(x_fg, [np.ones(10, dtype=config.floatX)])
def test_alloc_runtime_broadcast():
    # Run the shared backend-agnostic runtime-broadcast checks under JAX.
    TestAlloc.check_runtime_broadcast(get_mode("JAX"))
def test_jax_MakeVector():
x = at.make_vector(1, 2, 3)
x_fg = FunctionGraph([], [x])
......
......@@ -5,6 +5,7 @@ import pytensor.scalar as aes
import pytensor.tensor as at
import pytensor.tensor.basic as atb
from pytensor import config, function
from pytensor.compile import get_mode
from pytensor.compile.sharedvalue import SharedVariable
from pytensor.graph.basic import Constant
from pytensor.graph.fg import FunctionGraph
......@@ -15,6 +16,7 @@ from tests.link.numba.test_basic import (
compare_shape_dtype,
set_test_value,
)
from tests.tensor.test_basic import TestAlloc
pytest.importorskip("numba")
......@@ -49,6 +51,10 @@ def test_Alloc(v, shape):
assert numba_res.shape == shape
def test_alloc_runtime_broadcast():
    # Run the shared backend-agnostic runtime-broadcast checks under Numba.
    TestAlloc.check_runtime_broadcast(get_mode("NUMBA"))
def test_AllocEmpty():
x = at.empty((2, 3), dtype="float32")
x_fg = FunctionGraph([], [x])
......
......@@ -719,6 +719,38 @@ class TestAlloc:
# Shared-variable constructor exposed to the tests in this class.
shared = staticmethod(pytensor.shared)
# NOTE(review): three identical Alloc instances — presumably consumed by a
# shared test mixin/parametrization; confirm against the base test helpers.
allocs = [Alloc()] * 3
@staticmethod
def check_allocs_in_fgraph(fgraph, n):
    """Assert that `fgraph` contains exactly `n` ``Alloc`` nodes."""
    alloc_nodes = [
        apply_node
        for apply_node in fgraph.apply_nodes
        if isinstance(apply_node.op, Alloc)
    ]
    assert len(alloc_nodes) == n
@staticmethod
def check_runtime_broadcast(mode):
    """Check we emit a clear error when runtime broadcasting would occur according to NumPy rules."""
    floatX = config.floatX
    x_v = vector("x", shape=(None,))

    # Input of unknown static length allocated to (5, 3): valid only when
    # the runtime length is exactly 3 (no broadcasting needed).
    out = alloc(x_v, 5, 3)
    f = pytensor.function([x_v], out, mode=mode)
    TestAlloc.check_allocs_in_fgraph(f.maker.fgraph, 1)

    np.testing.assert_array_equal(
        f(x=np.zeros((3,), dtype=floatX)),
        np.zeros((5, 3), dtype=floatX),
    )
    # A runtime length-1 input would require silent broadcasting -> must raise.
    with pytest.raises(ValueError, match="Runtime broadcasting not allowed"):
        f(x=np.zeros((1,), dtype=floatX))

    # When the length is statically declared as 1, broadcasting is intended
    # and therefore allowed.
    out = alloc(specify_shape(x_v, (1,)), 5, 3)
    f = pytensor.function([x_v], out, mode=mode)
    TestAlloc.check_allocs_in_fgraph(f.maker.fgraph, 1)

    np.testing.assert_array_equal(
        f(x=np.zeros((1,), dtype=floatX)),
        np.zeros((5, 3), dtype=floatX),
    )
def setup_method(self):
    # Fresh deterministic RNG per test, seeded from the suite-wide seed.
    self.rng = np.random.default_rng(seed=utt.fetch_seed())
......@@ -853,6 +885,8 @@ class TestAlloc:
def test_alloc_of_view_linker(self):
"""Check we can allocate a new array properly in the C linker when input is a view."""
floatX = config.floatX
x_v = vector("x", shape=(None,))
dim_len = scalar("dim_len", dtype=int)
out = alloc(specify_shape(x_v, (1,)), 5, dim_len)
......@@ -862,7 +896,14 @@ class TestAlloc:
f.maker.fgraph.outputs, [alloc(specify_shape(x_v, (1,)), 5, dim_len)]
)
np.testing.assert_array_equal(f(x=np.zeros((1,)), dim_len=3), np.zeros((5, 3)))
np.testing.assert_array_equal(
f(x=np.zeros((1,), dtype=floatX), dim_len=3),
np.zeros((5, 3), dtype=floatX),
)
@pytest.mark.parametrize("mode", (Mode("py"), Mode("c")))
def test_runtime_broadcast(self, mode):
    # Exercise both the Python and C implementations of Alloc.
    self.check_runtime_broadcast(mode)
def test_infer_shape():
......
Markdown 格式
0%
您将添加 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论