Commit 5bbfc964 authored by Ricardo Vieira, committed by Ricardo Vieira

Forbid runtime broadcasting in Elemwise

Parent: e8bd0d7d
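As a minimal sketch of the new behavior (not part of the diff; assumes the post-commit API, where `specify_broadcastable` is re-exported from `pytensor.tensor`):

import numpy as np
import pytensor
import pytensor.tensor as at

x = at.matrix("x")  # static shape (None, None)
m = at.vector("m")  # static shape (None,)
f = pytensor.function([x, m], x - m)

f(np.ones((2, 3)), np.zeros(3))  # fine: shapes line up exactly

# Previously this silently broadcast m along the columns; it now raises
# ValueError: Runtime broadcasting not allowed. ...
# f(np.ones((2, 3)), np.zeros(1))

# Declaring the length-1 dimension statically restores broadcasting
m_b = at.specify_broadcastable(m, 0)
g = pytensor.function([x, m], x - m_b)
g(np.ones((2, 3)), np.zeros(1))  # broadcasts as before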
......@@ -7,9 +7,17 @@ from pytensor.tensor.special import LogSoftmax, Softmax, SoftmaxGrad
@jax_funcify.register(Elemwise)
def jax_funcify_Elemwise(op, **kwargs):
def jax_funcify_Elemwise(op, node, **kwargs):
scalar_op = op.scalar_op
return jax_funcify(scalar_op, **kwargs)
base_fn = jax_funcify(scalar_op, node=node, **kwargs)
def elemwise_fn(*inputs):
# ScalarVariables in JAX are passed as int/float.
# We wrap them in arrays just for the broadcast check
Elemwise._check_runtime_broadcast(node, tuple(map(jnp.asarray, inputs)))
return base_fn(*inputs)
return elemwise_fn
@jax_funcify.register(CAReduce)
......
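Because the wrapped check is the same `Elemwise._check_runtime_broadcast` used by the Python implementation, the JAX backend now raises the identical error before JAX's own broadcasting semantics can kick in. A sketch (assumes a working JAX install):

import numpy as np
import pytensor
import pytensor.tensor as at

x = at.matrix("x")
m = at.vector("m")
f = pytensor.function([x, m], x - m, mode="JAX")

# ValueError: Runtime broadcasting not allowed. ...
# f(np.ones((2, 3)), np.zeros(1))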
......@@ -35,15 +35,20 @@ def compute_itershape(
with builder.if_then(
builder.icmp_unsigned("!=", length, shape[i]), likely=False
):
with builder.if_else(builder.icmp_unsigned("==", length, one)) as (
with builder.if_else(
builder.or_(
builder.icmp_unsigned("==", length, one),
builder.icmp_unsigned("==", shape[i], one),
)
) as (
then,
otherwise,
):
with then:
msg = (
f"Incompatible shapes for input {j} and axis {i} of "
f"elemwise. Input {j} has shape 1, but is not statically "
"known to have shape 1, and thus not broadcastable."
"Runtime broadcasting not allowed. "
"One input had a distinct dimension length of 1, but was not marked as broadcastable.\n"
"If broadcasting was intended, use `specify_broadcastable` on the relevant input."
)
ctx.call_conv.return_user_exc(builder, ValueError, (msg,))
with otherwise:
......
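In plain Python, the branch the generated IR takes for each input length is roughly the following (illustrative helper, not part of the diff; the plain shape-mismatch error lives in the elided `otherwise` block):

def check_dim(i, j, length, input_dim_length):
    # `length`: the iteration length gathered for axis i so far;
    # `input_dim_length`: input j's runtime length on that axis
    if length != input_dim_length:  # compiled as an unlikely branch
        if length == 1 or input_dim_length == 1:
            raise ValueError(
                "Runtime broadcasting not allowed. "
                "One input had a distinct dimension length of 1, "
                "but was not marked as broadcastable."
            )
        # elided `otherwise` branch: a plain shape-mismatch error
        raise ValueError(f"Input {j} has an incompatible length on axis {i}")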
......@@ -132,6 +132,7 @@ from pytensor.tensor.shape import ( # noqa
shape_padaxis,
shape_padleft,
shape_padright,
specify_broadcastable,
specify_shape,
)
......
......@@ -19,9 +19,9 @@ from pytensor.scalar import get_scalar_type
from pytensor.scalar.basic import bool as scalar_bool
from pytensor.scalar.basic import identity as scalar_identity
from pytensor.scalar.basic import transfer_type, upcast
from pytensor.tensor import _get_vector_length, as_tensor_variable
from pytensor.tensor import elemwise_cgen as cgen
from pytensor.tensor import get_vector_length
from pytensor.tensor.basic import _get_vector_length, as_tensor_variable
from pytensor.tensor.type import (
TensorType,
continuous_dtypes,
......@@ -740,9 +740,7 @@ class Elemwise(OpenMPOp):
# FIXME: This no longer calls the C implementation!
super().perform(node, inputs, output_storage)
for d, dim_shapes in enumerate(zip(*(i.shape for i in inputs))):
if len(set(dim_shapes) - {1}) > 1:
raise ValueError(f"Shapes on dimension {d} do not match: {dim_shapes}")
self._check_runtime_broadcast(node, inputs)
ufunc_args = inputs
ufunc_kwargs = {}
......@@ -818,18 +816,26 @@ class Elemwise(OpenMPOp):
else:
storage[0] = variable
def infer_shape(self, fgraph, node, i_shapes) -> List[Tuple[TensorVariable, ...]]:
if len(node.outputs) > 1:
from pytensor.tensor.exceptions import ShapeError
raise ShapeError(
"Multiple outputs are not supported by the default `Elemwise.infer_shape`"
@staticmethod
def _check_runtime_broadcast(node, inputs):
for dims_and_bcast in zip(
*[
zip(input.shape, sinput.type.broadcastable)
for input, sinput in zip(inputs, node.inputs)
]
):
if any(d != 1 for d, _ in dims_and_bcast) and (1, False) in dims_and_bcast:
raise ValueError(
"Runtime broadcasting not allowed. "
"At least one input has a distinct dimension length of 1, but was not marked as broadcastable.\n"
"If broadcasting was intended, use `specify_broadcastable` on the relevant input."
)
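The rule: a dimension may only have runtime length 1 alongside longer lengths if it was statically marked broadcastable. A standalone rendering (hypothetical helper mirroring the staticmethod above):

def check_runtime_broadcast(shapes, broadcastable):
    for dims_and_bcast in zip(
        *[zip(shape, bcast) for shape, bcast in zip(shapes, broadcastable)]
    ):
        if any(d != 1 for d, _ in dims_and_bcast) and (1, False) in dims_and_bcast:
            raise ValueError("Runtime broadcasting not allowed. ...")

# Passes: the length-1 dimension is statically marked broadcastable
check_runtime_broadcast([(2, 3), (1, 3)], [(False, False), (True, False)])

# Raises: a runtime 1 next to a 2 with no broadcastable flag
# check_runtime_broadcast([(2, 3), (1, 3)], [(False, False), (False, False)])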
out_shape = pytensor.tensor.broadcast_shape(*i_shapes, arrays_are_shapes=True)
def infer_shape(self, fgraph, node, i_shapes) -> List[Tuple[TensorVariable, ...]]:
from pytensor.tensor.extra_ops import broadcast_shape
# The `as_tensor_variable` should convert `ScalarType`s to `TensorType`s
return [tuple(as_tensor_variable(s) for s in out_shape)]
out_shape = broadcast_shape(*i_shapes, arrays_are_shapes=True)
return [tuple(as_tensor_variable(s) for s in out_shape)] * len(node.outputs)
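`infer_shape` now broadcasts the per-input shapes symbolically and repeats the result once per output, rather than rejecting multi-output `Op`s. Roughly (shapes chosen for illustration, following the pattern of the updated multi-output test later in the diff):

import pytensor.scalar as aes
from pytensor.tensor import as_tensor_variable
from pytensor.tensor.extra_ops import broadcast_shape

i_shapes = [
    (aes.constant(1), aes.constant(3)),
    (aes.constant(2), aes.constant(1)),
]
out_shape = broadcast_shape(*i_shapes, arrays_are_shapes=True)
assert tuple(as_tensor_variable(s).eval() for s in out_shape) == (2, 3)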
def _c_all(self, node, nodename, inames, onames, sub):
# Some `Op`s directly call `Elemwise._c_all` or `Elemwise.c_code`
......@@ -1193,7 +1199,7 @@ class Elemwise(OpenMPOp):
return support_code
def c_code_cache_version_apply(self, node):
version = [14] # the version corresponding to the c code in this Op
version = [15] # the version corresponding to the c code in this Op
# now we insert versions for the ops on which we depend...
scalar_node = Apply(
......
......@@ -66,12 +66,10 @@ def make_checks(loop_orders, dtypes, sub):
if index != "x":
# Initialize the variables associated to the jth loop
# jump = stride - adjust
# If the variable has size 1 in that dim, we set the stride to zero to
# emulate broadcasting
jump = f"({var}_stride{index}) - ({adjust})"
init += f"""
{var}_n{index} = PyArray_DIMS({var})[{index}];
{var}_stride{index} = ({var}_n{index} == 1)? 0 : PyArray_STRIDES({var})[{index}] / sizeof({dtype});
{var}_stride{index} = PyArray_STRIDES({var})[{index}] / sizeof({dtype});
{var}_jump{index}_{j} = {jump};
"""
adjust = f"{var}_n{index}*{var}_stride{index}"
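The deleted stride-zeroing is exactly NumPy's broadcasting trick: stepping a size-1 dimension with stride 0 rereads the same element. With runtime broadcasting forbidden, the true stride is always used. For reference, what the old C code emulated:

import numpy as np

m = np.zeros((1, 3))
view = np.lib.stride_tricks.as_strided(
    m, shape=(2, 3), strides=(0, m.strides[1])
)
assert view.shape == (2, 3)  # every "row" aliases the single row of m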
......@@ -86,6 +84,14 @@ def make_checks(loop_orders, dtypes, sub):
# This loop builds multiple if conditions to verify that the
# dimensions of the inputs match, and the first one that is true
# raises an informative error message
runtime_broadcast_error_msg = (
"Runtime broadcasting not allowed. "
"One input had a distinct dimension length of 1, but was not marked as broadcastable: "
"(input[%%i].shape[%%i] = %%lld, input[%%i].shape[%%i] = %%lld). "
"If broadcasting was intended, use `specify_broadcastable` on the relevant input."
)
for matches in zip(*loop_orders):
to_compare = [(j, x) for j, x in enumerate(matches) if x != "x"]
......@@ -93,81 +99,58 @@ def make_checks(loop_orders, dtypes, sub):
if len(to_compare) < 2:
continue
# Find first dimension size that is != 1
jl, xl = to_compare[-1]
non1size_dim_check = f"""
npy_intp non1size_dim{xl};
non1size_dim{xl} = """
for j, x in to_compare[:-1]:
non1size_dim_check += f"(%(lv{j})s_n{x} != 1) ? %(lv{j})s_n{x} : "
non1size_dim_check += f"%(lv{jl})s_n{xl};"
check += non1size_dim_check
# Check the nonsize1 dims match
# TODO: This is a bit inefficient because we are comparing one dimension against itself
j0, x0 = to_compare[0]
for j, x in to_compare[1:]:
check += f"""
if (non1size_dim{xl} != 1)
if (%(lv{j0})s_n{x0} != %(lv{j})s_n{x})
{{
"""
for j, x in to_compare:
check += f"""
if ((%(lv{j})s_n{x} != non1size_dim{x}) && (%(lv{j})s_n{x} != 1))
if (%(lv{j0})s_n{x0} == 1 || %(lv{j})s_n{x} == 1)
{{
PyErr_Format(PyExc_ValueError, "Input dimension mismatch. One other input has shape[%%i] = %%lld, but input[%%i].shape[%%i] = %%lld.",
PyErr_Format(PyExc_ValueError, "{runtime_broadcast_error_msg}",
{j0},
{x0},
(long long int) %(lv{j0})s_n{x0},
{j},
{x},
(long long int) non1size_dim{x},
(long long int) %(lv{j})s_n{x}
);
}} else {{
PyErr_Format(PyExc_ValueError, "Input dimension mismatch: (input[%%i].shape[%%i] = %%lld, input[%%i].shape[%%i] = %%lld)",
{j0},
{x0},
(long long int) %(lv{j0})s_n{x0},
{j},
{x},
(long long int) %(lv{j})s_n{x}
);
}}
%(fail)s
}}
"""
check += """
}
"""
return init % sub + check % sub
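The emitted checks are now simple pairwise comparisons against the first participating input (`j0`), replacing the old `non1size_dim` reference scheme; in Python terms (illustrative, the generated code is C):

def check_dims(lengths):
    # `lengths`: runtime lengths of one output dimension, across the inputs
    # that are not statically broadcastable there
    l0 = lengths[0]
    for length in lengths[1:]:
        if length != l0:
            if l0 == 1 or length == 1:
                raise ValueError("Runtime broadcasting not allowed. ...")
            raise ValueError("Input dimension mismatch. ...")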
def compute_broadcast_dimensions(array_name: str, loop_orders, sub) -> str:
"""Create c_code to compute broadcasted dimensions of multiple arrays, arising from
Elemwise operations.
def compute_output_dims_lengths(array_name: str, loop_orders, sub) -> str:
"""Create c_code to compute the output dimensions of an Elemwise operation.
The code returned by this function populates the array `array_name`, but does not
initialize it.
TODO: We can decide to either specialize C code even further given the input types
or make it general, regardless of whether static broadcastable information is given
Note: We could specialize C code even further with the known static output shapes
"""
dims_c_code = ""
for i, candidates in enumerate(zip(*loop_orders)):
# TODO: Are candidates always either "x" or "i"? If that's the case we can
# simplify some logic here (e.g., we don't need to track the `idx`).
nonx_candidates = tuple(
(idx, c) for idx, c in enumerate(candidates) if c != "x"
)
# All inputs are known to be broadcastable
if not nonx_candidates:
dims_c_code += f"{array_name}[{i}] = 1;\n"
continue
# There is only one informative source of size
if len(nonx_candidates) == 1:
idx, candidate = nonx_candidates[0]
var = sub[f"lv{int(idx)}"]
# Borrow the length of the first non-broadcastable input dimension
for j, candidate in enumerate(candidates):
if candidate != "x":
var = sub[f"lv{int(j)}"]
dims_c_code += f"{array_name}[{i}] = {var}_n{candidate};\n"
continue
break
# If none is non-broadcastable, the output dimension has a length of 1
else: # no-break
dims_c_code += f"{array_name}[{i}] = 1;\n"
# In this case any non-size 1 variable will define the right size
dims_c_code += f"{array_name}[{i}] = "
for idx, candidate in nonx_candidates[:-1]:
var = sub[f"lv{int(idx)}"]
dims_c_code += f"({var}_n{candidate} != 1)? {var}_n{candidate}: "
idx, candidate = nonx_candidates[-1]
var = sub[f"lv{idx}"]
dims_c_code += f"{var}_n{candidate};\n"
return dims_c_code
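Since runtime broadcasting is forbidden, the first input that is not statically broadcastable along a dimension fixes that output length, which is what the simplified loop exploits. A pure-Python rendering (hypothetical helper):

def output_dims(loop_orders, shapes):
    out = []
    for candidates in zip(*loop_orders):
        for j, candidate in enumerate(candidates):
            if candidate != "x":  # "x" marks a statically broadcastable dim
                out.append(shapes[j][candidate])
                break
        else:  # all inputs broadcastable -> output length is 1
            out.append(1)
    return out

# x declared shape (None, 1) -> loop order (0, "x")
# y declared shape (1, None) -> loop order ("x", 1)
assert output_dims([(0, "x"), ("x", 1)], [(2, 1), (1, 3)]) == [2, 3]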
......@@ -186,7 +169,7 @@ def make_alloc(loop_orders, dtype, sub, fortran="0"):
if type.startswith("PYTENSOR_COMPLEX"):
type = type.replace("PYTENSOR_COMPLEX", "NPY_COMPLEX")
nd = len(loop_orders[0])
init_dims = compute_broadcast_dimensions("dims", loop_orders, sub)
init_dims = compute_output_dims_lengths("dims", loop_orders, sub)
# TODO: it would be interesting to allocate the output in such a
# way that its contiguous dimensions match one of the input's
......@@ -359,7 +342,7 @@ def make_reordered_loop(
# Get the (sorted) total number of iterations of each loop
declare_totals = f"int init_totals[{nnested}];\n"
declare_totals += compute_broadcast_dimensions("init_totals", init_loop_orders, sub)
declare_totals += compute_output_dims_lengths("init_totals", init_loop_orders, sub)
# Sort totals to match the new order that was computed by sorting
# the loop vector. One integer variable per loop is declared.
......
......@@ -1439,7 +1439,7 @@ def ravel_multi_index(multi_index, dims, mode="raise", order="C"):
_broadcast_assert = Assert(
"Could not broadcast dimensions. Broadcasting is only allowed along "
"axes that have a statically known length 1. Use `specify_shape` to "
"axes that have a statically known length 1. Use `specify_broadcastable` to "
"inform PyTensor of a known shape."
)
......
......@@ -4,6 +4,7 @@ import scipy.special
import pytensor
import pytensor.tensor as at
from pytensor.compile import get_mode
from pytensor.configdefaults import config
from pytensor.graph.fg import FunctionGraph
from pytensor.graph.op import get_test_value
......@@ -14,6 +15,11 @@ from pytensor.tensor.math import sum as at_sum
from pytensor.tensor.special import SoftmaxGrad, log_softmax, softmax
from pytensor.tensor.type import matrix, tensor, vector
from tests.link.jax.test_basic import compare_jax_and_py
from tests.tensor.test_elemwise import TestElemwise
def test_elemwise_runtime_shape_error():
TestElemwise.check_runtime_shapes_error(get_mode("JAX"))
def test_jax_Dimshuffle():
......
......@@ -9,6 +9,7 @@ import pytensor.tensor as at
import pytensor.tensor.inplace as ati
import pytensor.tensor.math as aem
from pytensor import config, function
from pytensor.compile import get_mode
from pytensor.compile.ops import deep_copy_op
from pytensor.compile.sharedvalue import SharedVariable
from pytensor.gradient import grad
......@@ -22,6 +23,7 @@ from tests.link.numba.test_basic import (
scalar_my_multi_out,
set_test_value,
)
from tests.tensor.test_elemwise import TestElemwise
rng = np.random.default_rng(42849)
......@@ -119,6 +121,10 @@ def test_Elemwise(inputs, input_vals, output_fn, exc):
compare_numba_and_py(out_fg, input_vals)
def test_elemwise_runtime_shape_error():
TestElemwise.check_runtime_shapes_error(get_mode("NUMBA"))
def test_elemwise_speed(benchmark):
x = at.dmatrix("y")
y = at.dvector("z")
......
......@@ -1671,12 +1671,7 @@ class TestLocalElemwiseAlloc:
(),
(),
),
pytest.param(
lambda x, y: at.mul(y, at.alloc(1, x)),
(),
(),
marks=pytest.mark.xfail(reason="Not implemented"),
),
(lambda x, y: at.mul(y, at.alloc(1, x)), (), ()),
(lambda x, y: at.mul(at.alloc(x, 15, 1), y), (15, 1), (15, 1)),
(lambda x, y: at.mul(at.alloc(x, 15, 2), y), (15, 2), (15, 2)),
(
......
......@@ -607,8 +607,7 @@ class TestAlgebraicCanonizer:
((fx / fy) / fx, [fx, fy], [fxv, fyv], 1, "float32"),
((dv / dy) / dv, [dv, dy], [dvv, dyv], 1, "float64"),
((fv / fy) / fv, [fv, fy], [fvv, fyv], 1, "float32"),
# must broadcast as their is a dimshuffle in the computation
# The broadcast leads to an extra elemwise to check compatibility
# must broadcast as there is a dimshuffle in the computation
((dx / dv) / dx, [dx, dv], [dxv, dvv], 2, "float64"),
# topo: [Shape_i, Shape_i, Elemwise{reciprocal,no_inplace}(<TensorType(float64, row)>), Alloc]
((fx / fv) / fx, [fx, fv], [fxv, fvv], 2, "float32"),
......
......@@ -428,12 +428,11 @@ class TestSameShape:
# could have shapes `(1,)` and/or `(n,)`, where `n != 1`, or any
# combination of the two.
assert not shape_feature.same_shape(x, o)
# The following case isn't implemented
assert not shape_feature.same_shape(y, o)
@pytest.mark.parametrize(
"y_dim_0",
[2, pytest.param(None, marks=pytest.mark.xfail(reason="Not implemented"))],
[2, None],
)
def test_vector_dim(self, y_dim_0):
x = at.tensor(dtype="floatX", shape=(2, None))
......
......@@ -18,7 +18,6 @@ from pytensor.link.c.basic import CLinker, OpWiseCLinker
from pytensor.tensor import as_tensor_variable
from pytensor.tensor.basic import second
from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise
from pytensor.tensor.exceptions import ShapeError
from pytensor.tensor.math import all as at_all
from pytensor.tensor.math import any as at_any
from pytensor.tensor.math import exp
......@@ -216,7 +215,6 @@ class TestBroadcast:
return np.asarray(np.random.random(shp), dtype=pytensor.config.floatX)
def with_linker(self, linker, op, type, rand_val):
for shape_info in ("complete", "only_broadcastable", "none"):
for xsh, ysh in [
((3, 5), (3, 5)),
((3, 5), (1, 5)),
......@@ -233,12 +231,6 @@ class TestBroadcast:
((2, 3, 4, 5), (1, 1, 1, 1)),
((), ()),
]:
if shape_info == "complete":
x_type = type(pytensor.config.floatX, shape=xsh)
y_type = type(pytensor.config.floatX, shape=ysh)
elif shape_info == "only_broadcastable":
# This condition is here for backwards compatibility, when the only
# type shape provided by PyTensor was broadcastable/non-broadcastable
x_type = type(
pytensor.config.floatX,
shape=tuple(s if s == 1 else None for s in xsh),
......@@ -247,9 +239,6 @@ class TestBroadcast:
pytensor.config.floatX,
shape=tuple(s if s == 1 else None for s in ysh),
)
else:
x_type = type(pytensor.config.floatX, shape=[None for _ in xsh])
y_type = type(pytensor.config.floatX, shape=[None for _ in ysh])
x = x_type("x")
y = y_type("y")
......@@ -267,13 +256,10 @@ class TestBroadcast:
x = x_type("x")
y = y_type("y")
e = op(aes.add)(x, y)
f = make_function(
copy(linker).accept(FunctionGraph([x, y], [e.shape]))
)
f = make_function(copy(linker).accept(FunctionGraph([x, y], [e.shape])))
assert tuple(f(xv, yv)) == tuple(zv.shape)
def with_linker_inplace(self, linker, op, type, rand_val):
for shape_info in ("complete", "only_broadcastable", "none"):
for xsh, ysh in [
((5, 5), (5, 5)),
((5, 5), (1, 5)),
......@@ -284,12 +270,6 @@ class TestBroadcast:
((2, 3, 4, 5), (1, 1, 1, 1)),
((), ()),
]:
if shape_info == "complete":
x_type = type(pytensor.config.floatX, shape=xsh)
y_type = type(pytensor.config.floatX, shape=ysh)
elif shape_info == "only_broadcastable":
# This condition is here for backwards compatibility, when the only
# type shape provided by PyTensor was broadcastable/non-broadcastable
x_type = type(
pytensor.config.floatX,
shape=tuple(s if s == 1 else None for s in xsh),
......@@ -298,9 +278,6 @@ class TestBroadcast:
pytensor.config.floatX,
shape=tuple(s if s == 1 else None for s in ysh),
)
else:
x_type = type(pytensor.config.floatX, shape=[None for _ in xsh])
y_type = type(pytensor.config.floatX, shape=[None for _ in ysh])
x = x_type("x")
y = y_type("y")
......@@ -319,9 +296,7 @@ class TestBroadcast:
x = x_type("x")
y = y_type("y")
e = op(aes.Add(aes.transfer_type(0)), {0: 0})(x, y)
f = make_function(
copy(linker).accept(FunctionGraph([x, y], [e.shape]))
)
f = make_function(copy(linker).accept(FunctionGraph([x, y], [e.shape])))
xv = rand_val(xsh)
yv = rand_val(ysh)
zv = xv + yv
......@@ -775,32 +750,42 @@ class TestElemwise(unittest_tools.InferShapeTester):
g = pytensor.function([a, b, c, d, e, f], s, mode=Mode(linker="py"))
g(*[np.zeros(2**11, config.floatX) for i in range(6)])
def check_input_dimensions_match(self, mode):
"""Make sure that our input validation works correctly and doesn't
throw erroneous broadcast-based errors.
"""
@staticmethod
def check_runtime_shapes_error(mode):
"""Check we emmit a clear error when runtime broadcasting would occur according to Numpy rules."""
x_v = matrix("x")
m_v = vector("m")
x = np.array([[-1.32720483], [0.23442016]]).astype(config.floatX)
m = np.array([0.0, 0.0]).astype(config.floatX)
z_v = x_v - m_v
f = pytensor.function([x_v, m_v], z_v, mode=mode)
res = f(x, m)
# Test invalid broadcasting by either x or m
for x_sh, m_sh in [((2, 1), (3,)), ((2, 3), (1,))]:
x = np.ones(x_sh).astype(config.floatX)
m = np.zeros(m_sh).astype(config.floatX)
assert np.array_equal(res, x - m)
# This error is introduced by PyTensor, so it's the same across different backends
with pytest.raises(ValueError, match="Runtime broadcasting not allowed"):
f(x, m)
def test_input_dimensions_match_python(self):
self.check_input_dimensions_match(Mode(linker="py"))
x = np.ones((2, 3)).astype(config.floatX)
m = np.zeros((1,)).astype(config.floatX)
x = np.ones((2, 4)).astype(config.floatX)
m = np.zeros((3,)).astype(config.floatX)
# This error is backend specific, and may have different types
with pytest.raises((ValueError, TypeError)):
f(x, m)
def test_runtime_shapes_error_python(self):
self.check_runtime_shapes_error(Mode(linker="py"))
@pytest.mark.skipif(
not pytensor.config.cxx,
reason="G++ not available, so we need to skip this test.",
)
def test_input_dimensions_match_c(self):
self.check_input_dimensions_match(Mode(linker="c"))
def test_runtime_shapes_error_c(self):
self.check_runtime_shapes_error(Mode(linker="c"))
def test_str(self):
op = Elemwise(aes.add, inplace_pattern={0: 0}, name=None)
......@@ -825,7 +810,7 @@ class TestElemwise(unittest_tools.InferShapeTester):
assert pytensor.get_underlying_scalar_constant(res_shape[0][0]) == 1
assert pytensor.get_underlying_scalar_constant(res_shape[0][1]) == 1
def test_multi_output(self):
def test_infer_shape_multi_output(self):
class CustomElemwise(Elemwise):
def make_node(self, *args):
res = super().make_node(*args)
......@@ -839,14 +824,26 @@ class TestElemwise(unittest_tools.InferShapeTester):
],
)
z_1, z_2 = CustomElemwise(aes.add)(
as_tensor_variable(np.eye(1)), as_tensor_variable(np.eye(1))
)
custom_elemwise = CustomElemwise(aes.add)
z_1, z_2 = custom_elemwise(
as_tensor_variable(np.eye(1)),
as_tensor_variable(np.eye(1)),
)
in_1_shape = (aes.constant(1), aes.constant(1))
outs = z_1.owner.op.infer_shape(None, z_1.owner, [in_1_shape, in_1_shape])
for out in outs:
assert out[0].eval() == 1
assert out[1].eval() == 1
with pytest.raises(ShapeError):
z_1.owner.op.infer_shape(None, z_1.owner, [in_1_shape, in_1_shape])
z_1, z_2 = custom_elemwise(
as_tensor_variable(np.eye(1)), as_tensor_variable(np.eye(3))
)
in_2_shape = (aes.constant(3), aes.constant(3))
outs = z_1.owner.op.infer_shape(None, z_1.owner, [in_1_shape, in_2_shape])
for out in outs:
assert out[0].eval() == 3
assert out[1].eval() == 3
def test_shape_types(self):
x = tensor(dtype=np.float64, shape=(None, 1))
......