提交 34eaaa53 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Forbid runtime broadcasting in Alloc

上级 69418200
...@@ -41,9 +41,10 @@ def jax_funcify_AllocEmpty(op, **kwargs): ...@@ -41,9 +41,10 @@ def jax_funcify_AllocEmpty(op, **kwargs):
@jax_funcify.register(Alloc)
def jax_funcify_Alloc(op, node, **kwargs):
    """Return the JAX callable implementing ``Alloc`` for ``node``.

    The returned function broadcasts its first argument to the shape given by
    the remaining arguments, then calls ``Alloc._check_runtime_broadcast`` with
    ``node``, the value as a JAX array, and the shape of the broadcast result.
    """

    def alloc(x, *shape):
        broadcasted = jnp.broadcast_to(x, shape)
        # The check runs after the broadcast, using the shape of the result
        # rather than the raw ``shape`` arguments.
        Alloc._check_runtime_broadcast(node, jnp.asarray(x), broadcasted.shape)
        return broadcasted

    return alloc
......
...@@ -77,16 +77,24 @@ def numba_funcify_Alloc(op, node, **kwargs): ...@@ -77,16 +77,24 @@ def numba_funcify_Alloc(op, node, **kwargs):
" " * 4, " " * 4,
) )
check_runtime_broadcast = []
for i, val_static_dim in enumerate(node.inputs[0].type.shape[::-1]):
if val_static_dim is None:
check_runtime_broadcast.append(
f'if val.shape[{-i - 1}] == 1 and scalar_shape[{-i - 1}] != 1: raise ValueError("{Alloc._runtime_broadcast_error_msg}")'
)
check_runtime_broadcast_src = indent("\n".join(check_runtime_broadcast), " " * 4)
alloc_def_src = f""" alloc_def_src = f"""
def alloc(val, {", ".join(shape_var_names)}): def alloc(val, {", ".join(shape_var_names)}):
val_np = np.asarray(val) val_np = np.asarray(val)
{shapes_to_items_src} {shapes_to_items_src}
scalar_shape = {create_tuple_string(shape_var_item_names)} scalar_shape = {create_tuple_string(shape_var_item_names)}
{check_runtime_broadcast_src}
res = np.empty(scalar_shape, dtype=val_np.dtype) res = np.empty(scalar_shape, dtype=val_np.dtype)
res[...] = val_np res[...] = val_np
return res return res
""" """
alloc_fn = compile_function_src(alloc_def_src, "alloc", {**globals(), **global_env}) alloc_fn = compile_function_src(alloc_def_src, "alloc", {**globals(), **global_env})
return numba_basic.numba_njit(alloc_fn) return numba_basic.numba_njit(alloc_fn)
......
...@@ -1431,6 +1431,12 @@ class Alloc(COp): ...@@ -1431,6 +1431,12 @@ class Alloc(COp):
__props__ = () __props__ = ()
_runtime_broadcast_error_msg = (
"Runtime broadcasting not allowed. "
"The output of Alloc requires broadcasting a dimension of the input value, which was not marked as broadcastable. "
"If broadcasting was intended, use `specify_broadcastable` on the relevant input."
)
def make_node(self, value, *shape): def make_node(self, value, *shape):
value = as_tensor_variable(value) value = as_tensor_variable(value)
shape, static_shape = infer_static_shape(shape) shape, static_shape = infer_static_shape(shape)
...@@ -1468,10 +1474,21 @@ class Alloc(COp): ...@@ -1468,10 +1474,21 @@ class Alloc(COp):
otype = TensorType(dtype=value.dtype, shape=combined_static_shape) otype = TensorType(dtype=value.dtype, shape=combined_static_shape)
return Apply(self, [value] + shape, [otype()]) return Apply(self, [value] + shape, [otype()])
@staticmethod
def _check_runtime_broadcast(node, value, shape):
    """Raise if writing ``value`` into ``shape`` would broadcast a non-static dimension.

    Dimensions are compared from the trailing end. A dimension is rejected
    when its static length on ``node.inputs[0].type.shape`` is ``None``, its
    runtime length in ``value.shape`` is 1, and the matching entry of
    ``shape`` is not 1.

    Parameters
    ----------
    node
        Apply node; only ``node.inputs[0].type.shape`` is read.
    value
        Runtime value; only its ``shape`` attribute is read.
    shape
        Requested output shape.

    Raises
    ------
    ValueError
        With ``Alloc._runtime_broadcast_error_msg`` when such a dimension
        is found.
    """
    static_dims = node.inputs[0].type.shape
    # ``zip`` stops at the shortest sequence, so only the trailing dimensions
    # shared by all three shapes are inspected.
    trailing_dims = zip(reversed(static_dims), reversed(value.shape), reversed(shape))
    if any(
        static_dim is None and runtime_dim == 1 and target_dim != 1
        for static_dim, runtime_dim, target_dim in trailing_dims
    ):
        raise ValueError(Alloc._runtime_broadcast_error_msg)
def perform(self, node, inputs, out_): def perform(self, node, inputs, out_):
(out,) = out_ (out,) = out_
v = inputs[0] v = inputs[0]
sh = tuple([int(i) for i in inputs[1:]]) sh = tuple([int(i) for i in inputs[1:]])
self._check_runtime_broadcast(node, v, sh)
if out[0] is None or out[0].shape != sh: if out[0] is None or out[0].shape != sh:
if v.size == 1 and v.item() == 0: if v.size == 1 and v.item() == 0:
out[0] = np.zeros(sh, dtype=v.dtype) out[0] = np.zeros(sh, dtype=v.dtype)
...@@ -1484,12 +1501,19 @@ class Alloc(COp): ...@@ -1484,12 +1501,19 @@ class Alloc(COp):
def c_code(self, node, name, inp, out, sub): def c_code(self, node, name, inp, out, sub):
vv = inp[0] vv = inp[0]
ndim = len(inp[1:])
(zz,) = out (zz,) = out
fail = sub["fail"] fail = sub["fail"]
v_static_shape = node.inputs[0].type.shape
o_static_shape = node.outputs[0].type.shape
v_ndim = len(v_static_shape)
o_ndim = len(o_static_shape)
assert o_ndim == len(inp[1:])
# Declare variables
code = f""" code = f"""
npy_intp shape[{ndim}]; npy_intp shape[{o_ndim}];
int need_new_out;
""" """
# Initialize shape # Initialize shape
...@@ -1498,15 +1522,26 @@ class Alloc(COp): ...@@ -1498,15 +1522,26 @@ class Alloc(COp):
shape[{i}] = ((dtype_{shp_i}*) PyArray_DATA({shp_i}))[0]; shape[{i}] = ((dtype_{shp_i}*) PyArray_DATA({shp_i}))[0];
""" """
# Add checks for runtime broadcasting
for i, v_static_dim in enumerate(v_static_shape[::-1]):
if v_static_dim is None:
code += f"""
if (PyArray_DIMS({vv})[{v_ndim - i - 1}] == 1 && shape[{o_ndim - i - 1}] != 1)
{{
PyErr_Format(PyExc_ValueError, "{self._runtime_broadcast_error_msg}");
{fail}
}}
"""
code += f""" code += f"""
int need_new_out = (NULL == {zz}); need_new_out = (NULL == {zz});
for (int i = 0; i < {ndim}; i++) for (int i = 0; i < {o_ndim}; i++)
need_new_out = (need_new_out || (PyArray_DIMS({zz})[i] != shape[i])); need_new_out = (need_new_out || (PyArray_DIMS({zz})[i] != shape[i]));
if (need_new_out) if (need_new_out)
{{ {{
Py_XDECREF({zz}); Py_XDECREF({zz});
{zz} = (PyArrayObject*) PyArray_SimpleNew({ndim}, shape, PyArray_TYPE({vv})); {zz} = (PyArrayObject*) PyArray_SimpleNew({o_ndim}, shape, PyArray_TYPE({vv}));
if (!{zz}) if (!{zz})
{{ {{
PyErr_SetString(PyExc_MemoryError, "alloc failed"); PyErr_SetString(PyExc_MemoryError, "alloc failed");
...@@ -1522,7 +1557,7 @@ class Alloc(COp): ...@@ -1522,7 +1557,7 @@ class Alloc(COp):
return code return code
def c_code_cache_version(self):
    """Return the version tuple of the generated C code, currently ``(4,)``."""
    return (4,)
def infer_shape(self, fgraph, node, input_shapes):
    """Return the output shape as every input of ``node`` after the first.

    The result is a one-element list holding ``node.inputs[1:]``;
    ``fgraph`` and ``input_shapes`` are not used.
    """
    shape_inputs = node.inputs[1:]
    return [shape_inputs]
......
import numpy as np import numpy as np
import pytest import pytest
from pytensor.compile import get_mode
jax = pytest.importorskip("jax") jax = pytest.importorskip("jax")
import jax.errors import jax.errors
...@@ -12,6 +14,7 @@ from pytensor.graph.fg import FunctionGraph ...@@ -12,6 +14,7 @@ from pytensor.graph.fg import FunctionGraph
from pytensor.graph.op import get_test_value from pytensor.graph.op import get_test_value
from pytensor.tensor.type import iscalar, matrix, scalar, vector from pytensor.tensor.type import iscalar, matrix, scalar, vector
from tests.link.jax.test_basic import compare_jax_and_py from tests.link.jax.test_basic import compare_jax_and_py
from tests.tensor.test_basic import TestAlloc
def test_jax_Alloc(): def test_jax_Alloc():
...@@ -50,6 +53,10 @@ def test_jax_Alloc(): ...@@ -50,6 +53,10 @@ def test_jax_Alloc():
compare_jax_and_py(x_fg, [np.ones(10, dtype=config.floatX)]) compare_jax_and_py(x_fg, [np.ones(10, dtype=config.floatX)])
def test_alloc_runtime_broadcast():
    """Run the shared ``TestAlloc.check_runtime_broadcast`` check under the "JAX" mode."""
    jax_mode = get_mode("JAX")
    TestAlloc.check_runtime_broadcast(jax_mode)
def test_jax_MakeVector(): def test_jax_MakeVector():
x = at.make_vector(1, 2, 3) x = at.make_vector(1, 2, 3)
x_fg = FunctionGraph([], [x]) x_fg = FunctionGraph([], [x])
......
...@@ -5,6 +5,7 @@ import pytensor.scalar as aes ...@@ -5,6 +5,7 @@ import pytensor.scalar as aes
import pytensor.tensor as at import pytensor.tensor as at
import pytensor.tensor.basic as atb import pytensor.tensor.basic as atb
from pytensor import config, function from pytensor import config, function
from pytensor.compile import get_mode
from pytensor.compile.sharedvalue import SharedVariable from pytensor.compile.sharedvalue import SharedVariable
from pytensor.graph.basic import Constant from pytensor.graph.basic import Constant
from pytensor.graph.fg import FunctionGraph from pytensor.graph.fg import FunctionGraph
...@@ -15,6 +16,7 @@ from tests.link.numba.test_basic import ( ...@@ -15,6 +16,7 @@ from tests.link.numba.test_basic import (
compare_shape_dtype, compare_shape_dtype,
set_test_value, set_test_value,
) )
from tests.tensor.test_basic import TestAlloc
pytest.importorskip("numba") pytest.importorskip("numba")
...@@ -49,6 +51,10 @@ def test_Alloc(v, shape): ...@@ -49,6 +51,10 @@ def test_Alloc(v, shape):
assert numba_res.shape == shape assert numba_res.shape == shape
def test_alloc_runtime_broadcast():
    """Run the shared ``TestAlloc.check_runtime_broadcast`` check under the "NUMBA" mode."""
    numba_mode = get_mode("NUMBA")
    TestAlloc.check_runtime_broadcast(numba_mode)
def test_AllocEmpty(): def test_AllocEmpty():
x = at.empty((2, 3), dtype="float32") x = at.empty((2, 3), dtype="float32")
x_fg = FunctionGraph([], [x]) x_fg = FunctionGraph([], [x])
......
...@@ -719,6 +719,38 @@ class TestAlloc: ...@@ -719,6 +719,38 @@ class TestAlloc:
shared = staticmethod(pytensor.shared) shared = staticmethod(pytensor.shared)
allocs = [Alloc()] * 3 allocs = [Alloc()] * 3
@staticmethod
def check_allocs_in_fgraph(fgraph, n):
    """Assert that exactly ``n`` apply nodes of ``fgraph`` have an ``Alloc`` op."""
    # Summing booleans counts the nodes whose op is an Alloc instance.
    alloc_count = sum(
        isinstance(apply_node.op, Alloc) for apply_node in fgraph.apply_nodes
    )
    assert alloc_count == n
@staticmethod
def check_runtime_broadcast(mode):
    """Check that a clear error is emitted when runtime broadcasting would occur.

    Compiles ``alloc(x, 5, 3)`` under ``mode`` for a vector ``x`` of unknown
    static length and verifies that a length-3 input fills the output while a
    length-1 input raises ``ValueError``. Then verifies that the same length-1
    input is accepted once ``specify_shape`` fixes the length to 1.
    """
    floatX = config.floatX
    expected = np.zeros((5, 3), dtype=floatX)
    x_v = vector("x", shape=(None,))

    # Static length unknown: a length-1 input must be rejected at runtime.
    f = pytensor.function([x_v], alloc(x_v, 5, 3), mode=mode)
    TestAlloc.check_allocs_in_fgraph(f.maker.fgraph, 1)
    np.testing.assert_array_equal(f(x=np.zeros((3,), dtype=floatX)), expected)
    with pytest.raises(ValueError, match="Runtime broadcasting not allowed"):
        f(x=np.zeros((1,), dtype=floatX))

    # Static length fixed to 1: the same input is accepted.
    f = pytensor.function(
        [x_v], alloc(specify_shape(x_v, (1,)), 5, 3), mode=mode
    )
    TestAlloc.check_allocs_in_fgraph(f.maker.fgraph, 1)
    np.testing.assert_array_equal(f(x=np.zeros((1,), dtype=floatX)), expected)
def setup_method(self): def setup_method(self):
self.rng = np.random.default_rng(seed=utt.fetch_seed()) self.rng = np.random.default_rng(seed=utt.fetch_seed())
...@@ -853,6 +885,8 @@ class TestAlloc: ...@@ -853,6 +885,8 @@ class TestAlloc:
def test_alloc_of_view_linker(self): def test_alloc_of_view_linker(self):
"""Check we can allocate a new array properly in the C linker when input is a view.""" """Check we can allocate a new array properly in the C linker when input is a view."""
floatX = config.floatX
x_v = vector("x", shape=(None,)) x_v = vector("x", shape=(None,))
dim_len = scalar("dim_len", dtype=int) dim_len = scalar("dim_len", dtype=int)
out = alloc(specify_shape(x_v, (1,)), 5, dim_len) out = alloc(specify_shape(x_v, (1,)), 5, dim_len)
...@@ -862,7 +896,14 @@ class TestAlloc: ...@@ -862,7 +896,14 @@ class TestAlloc:
f.maker.fgraph.outputs, [alloc(specify_shape(x_v, (1,)), 5, dim_len)] f.maker.fgraph.outputs, [alloc(specify_shape(x_v, (1,)), 5, dim_len)]
) )
np.testing.assert_array_equal(f(x=np.zeros((1,)), dim_len=3), np.zeros((5, 3))) np.testing.assert_array_equal(
f(x=np.zeros((1,), dtype=floatX), dim_len=3),
np.zeros((5, 3), dtype=floatX),
)
@pytest.mark.parametrize("mode", (Mode("py"), Mode("c")))
def test_runtime_broadcast(self, mode):
    """Run ``check_runtime_broadcast`` under each of ``Mode("py")`` and ``Mode("c")``."""
    self.check_runtime_broadcast(mode)
def test_infer_shape(): def test_infer_shape():
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论