提交 5bbfc964 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Forbid runtime broadcasting in Elemwise

上级 e8bd0d7d
...@@ -7,9 +7,17 @@ from pytensor.tensor.special import LogSoftmax, Softmax, SoftmaxGrad ...@@ -7,9 +7,17 @@ from pytensor.tensor.special import LogSoftmax, Softmax, SoftmaxGrad
@jax_funcify.register(Elemwise) @jax_funcify.register(Elemwise)
def jax_funcify_Elemwise(op, **kwargs): def jax_funcify_Elemwise(op, node, **kwargs):
scalar_op = op.scalar_op scalar_op = op.scalar_op
return jax_funcify(scalar_op, **kwargs) base_fn = jax_funcify(scalar_op, node=node, **kwargs)
def elemwise_fn(*inputs):
# ScalarVariables in JAX are passed as int/float.
# We wrap them in arrays just for the broadcast check
Elemwise._check_runtime_broadcast(node, tuple(map(jnp.asarray, inputs)))
return base_fn(*inputs)
return elemwise_fn
@jax_funcify.register(CAReduce) @jax_funcify.register(CAReduce)
......
...@@ -35,15 +35,20 @@ def compute_itershape( ...@@ -35,15 +35,20 @@ def compute_itershape(
with builder.if_then( with builder.if_then(
builder.icmp_unsigned("!=", length, shape[i]), likely=False builder.icmp_unsigned("!=", length, shape[i]), likely=False
): ):
with builder.if_else(builder.icmp_unsigned("==", length, one)) as ( with builder.if_else(
builder.or_(
builder.icmp_unsigned("==", length, one),
builder.icmp_unsigned("==", shape[i], one),
)
) as (
then, then,
otherwise, otherwise,
): ):
with then: with then:
msg = ( msg = (
f"Incompatible shapes for input {j} and axis {i} of " "Runtime broadcasting not allowed. "
f"elemwise. Input {j} has shape 1, but is not statically " "One input had a distinct dimension length of 1, but was not marked as broadcastable.\n"
"known to have shape 1, and thus not broadcastable." "If broadcasting was intended, use `specify_broadcastable` on the relevant input."
) )
ctx.call_conv.return_user_exc(builder, ValueError, (msg,)) ctx.call_conv.return_user_exc(builder, ValueError, (msg,))
with otherwise: with otherwise:
......
...@@ -132,6 +132,7 @@ from pytensor.tensor.shape import ( # noqa ...@@ -132,6 +132,7 @@ from pytensor.tensor.shape import ( # noqa
shape_padaxis, shape_padaxis,
shape_padleft, shape_padleft,
shape_padright, shape_padright,
specify_broadcastable,
specify_shape, specify_shape,
) )
......
...@@ -19,9 +19,9 @@ from pytensor.scalar import get_scalar_type ...@@ -19,9 +19,9 @@ from pytensor.scalar import get_scalar_type
from pytensor.scalar.basic import bool as scalar_bool from pytensor.scalar.basic import bool as scalar_bool
from pytensor.scalar.basic import identity as scalar_identity from pytensor.scalar.basic import identity as scalar_identity
from pytensor.scalar.basic import transfer_type, upcast from pytensor.scalar.basic import transfer_type, upcast
from pytensor.tensor import _get_vector_length, as_tensor_variable
from pytensor.tensor import elemwise_cgen as cgen from pytensor.tensor import elemwise_cgen as cgen
from pytensor.tensor import get_vector_length from pytensor.tensor import get_vector_length
from pytensor.tensor.basic import _get_vector_length, as_tensor_variable
from pytensor.tensor.type import ( from pytensor.tensor.type import (
TensorType, TensorType,
continuous_dtypes, continuous_dtypes,
...@@ -740,9 +740,7 @@ class Elemwise(OpenMPOp): ...@@ -740,9 +740,7 @@ class Elemwise(OpenMPOp):
# FIXME: This no longer calls the C implementation! # FIXME: This no longer calls the C implementation!
super().perform(node, inputs, output_storage) super().perform(node, inputs, output_storage)
for d, dim_shapes in enumerate(zip(*(i.shape for i in inputs))): self._check_runtime_broadcast(node, inputs)
if len(set(dim_shapes) - {1}) > 1:
raise ValueError(f"Shapes on dimension {d} do not match: {dim_shapes}")
ufunc_args = inputs ufunc_args = inputs
ufunc_kwargs = {} ufunc_kwargs = {}
...@@ -818,18 +816,26 @@ class Elemwise(OpenMPOp): ...@@ -818,18 +816,26 @@ class Elemwise(OpenMPOp):
else: else:
storage[0] = variable storage[0] = variable
def infer_shape(self, fgraph, node, i_shapes) -> List[Tuple[TensorVariable, ...]]: @staticmethod
if len(node.outputs) > 1: def _check_runtime_broadcast(node, inputs):
from pytensor.tensor.exceptions import ShapeError for dims_and_bcast in zip(
*[
raise ShapeError( zip(input.shape, sinput.type.broadcastable)
"Multiple outputs are not supported by the default `Elemwise.infer_shape`" for input, sinput in zip(inputs, node.inputs)
) ]
):
if any(d != 1 for d, _ in dims_and_bcast) and (1, False) in dims_and_bcast:
raise ValueError(
"Runtime broadcasting not allowed. "
"At least one input has a distinct dimension length of 1, but was not marked as broadcastable.\n"
"If broadcasting was intended, use `specify_broadcastable` on the relevant input."
)
out_shape = pytensor.tensor.broadcast_shape(*i_shapes, arrays_are_shapes=True) def infer_shape(self, fgraph, node, i_shapes) -> List[Tuple[TensorVariable, ...]]:
from pytensor.tensor.extra_ops import broadcast_shape
# The `as_tensor_variable` should convert `ScalarType`s to `TensorType`s out_shape = broadcast_shape(*i_shapes, arrays_are_shapes=True)
return [tuple(as_tensor_variable(s) for s in out_shape)] return [tuple(as_tensor_variable(s) for s in out_shape)] * len(node.outputs)
def _c_all(self, node, nodename, inames, onames, sub): def _c_all(self, node, nodename, inames, onames, sub):
# Some `Op`s directly call `Elemwise._c_all` or `Elemwise.c_code` # Some `Op`s directly call `Elemwise._c_all` or `Elemwise.c_code`
...@@ -1193,7 +1199,7 @@ class Elemwise(OpenMPOp): ...@@ -1193,7 +1199,7 @@ class Elemwise(OpenMPOp):
return support_code return support_code
def c_code_cache_version_apply(self, node): def c_code_cache_version_apply(self, node):
version = [14] # the version corresponding to the c code in this Op version = [15] # the version corresponding to the c code in this Op
# now we insert versions for the ops on which we depend... # now we insert versions for the ops on which we depend...
scalar_node = Apply( scalar_node = Apply(
......
...@@ -66,12 +66,10 @@ def make_checks(loop_orders, dtypes, sub): ...@@ -66,12 +66,10 @@ def make_checks(loop_orders, dtypes, sub):
if index != "x": if index != "x":
# Initialize the variables associated to the jth loop # Initialize the variables associated to the jth loop
# jump = stride - adjust # jump = stride - adjust
# If the variable has size 1 in that dim, we set the stride to zero to
# emulate broadcasting
jump = f"({var}_stride{index}) - ({adjust})" jump = f"({var}_stride{index}) - ({adjust})"
init += f""" init += f"""
{var}_n{index} = PyArray_DIMS({var})[{index}]; {var}_n{index} = PyArray_DIMS({var})[{index}];
{var}_stride{index} = ({var}_n{index} == 1)? 0 : PyArray_STRIDES({var})[{index}] / sizeof({dtype}); {var}_stride{index} = PyArray_STRIDES({var})[{index}] / sizeof({dtype});
{var}_jump{index}_{j} = {jump}; {var}_jump{index}_{j} = {jump};
""" """
adjust = f"{var}_n{index}*{var}_stride{index}" adjust = f"{var}_n{index}*{var}_stride{index}"
...@@ -86,6 +84,14 @@ def make_checks(loop_orders, dtypes, sub): ...@@ -86,6 +84,14 @@ def make_checks(loop_orders, dtypes, sub):
# This loop builds multiple if conditions to verify that the # This loop builds multiple if conditions to verify that the
# dimensions of the inputs match, and the first one that is true # dimensions of the inputs match, and the first one that is true
# raises an informative error message # raises an informative error message
runtime_broadcast_error_msg = (
"Runtime broadcasting not allowed. "
"One input had a distinct dimension length of 1, but was not marked as broadcastable: "
"(input[%%i].shape[%%i] = %%lld, input[%%i].shape[%%i] = %%lld). "
"If broadcasting was intended, use `specify_broadcastable` on the relevant input."
)
for matches in zip(*loop_orders): for matches in zip(*loop_orders):
to_compare = [(j, x) for j, x in enumerate(matches) if x != "x"] to_compare = [(j, x) for j, x in enumerate(matches) if x != "x"]
...@@ -93,81 +99,58 @@ def make_checks(loop_orders, dtypes, sub): ...@@ -93,81 +99,58 @@ def make_checks(loop_orders, dtypes, sub):
if len(to_compare) < 2: if len(to_compare) < 2:
continue continue
# Find first dimension size that is != 1 j0, x0 = to_compare[0]
jl, xl = to_compare[-1] for j, x in to_compare[1:]:
non1size_dim_check = f"""
npy_intp non1size_dim{xl};
non1size_dim{xl} = """
for j, x in to_compare[:-1]:
non1size_dim_check += f"(%(lv{j})s_n{x} != 1) ? %(lv{j})s_n{x} : "
non1size_dim_check += f"%(lv{jl})s_n{xl};"
check += non1size_dim_check
# Check the nonsize1 dims match
# TODO: This is a bit inefficient because we are comparing one dimension against itself
check += f"""
if (non1size_dim{xl} != 1)
{{
"""
for j, x in to_compare:
check += f""" check += f"""
if ((%(lv{j})s_n{x} != non1size_dim{x}) && (%(lv{j})s_n{x} != 1)) if (%(lv{j0})s_n{x0} != %(lv{j})s_n{x})
{{
if (%(lv{j0})s_n{x0} == 1 || %(lv{j})s_n{x} == 1)
{{ {{
PyErr_Format(PyExc_ValueError, "Input dimension mismatch. One other input has shape[%%i] = %%lld, but input[%%i].shape[%%i] = %%lld.", PyErr_Format(PyExc_ValueError, "{runtime_broadcast_error_msg}",
{x}, {j0},
(long long int) non1size_dim{x}, {x0},
(long long int) %(lv{j0})s_n{x0},
{j},
{x},
(long long int) %(lv{j})s_n{x}
);
}} else {{
PyErr_Format(PyExc_ValueError, "Input dimension mismatch: (input[%%i].shape[%%i] = %%lld, input[%%i].shape[%%i] = %%lld)",
{j0},
{x0},
(long long int) %(lv{j0})s_n{x0},
{j}, {j},
{x}, {x},
(long long int) %(lv{j})s_n{x} (long long int) %(lv{j})s_n{x}
); );
%(fail)s
}} }}
""" %(fail)s
check += """ }}
}
""" """
return init % sub + check % sub return init % sub + check % sub
def compute_broadcast_dimensions(array_name: str, loop_orders, sub) -> str: def compute_output_dims_lengths(array_name: str, loop_orders, sub) -> str:
"""Create c_code to compute broadcasted dimensions of multiple arrays, arising from """Create c_code to compute the output dimensions of an Elemwise operation.
Elemwise operations.
The code returned by this function populates the array `array_name`, but does not The code returned by this function populates the array `array_name`, but does not
initialize it. initialize it.
TODO: We can decide to either specialize C code even further given the input types Note: We could specialize C code even further with the known static output shapes
or make it general, regardless of whether static broadcastable information is given
""" """
dims_c_code = "" dims_c_code = ""
for i, candidates in enumerate(zip(*loop_orders)): for i, candidates in enumerate(zip(*loop_orders)):
# TODO: Are candidates always either "x" or "i"? If that's the case we can # Borrow the length of the first non-broadcastable input dimension
# simplify some logic here (e.g., we don't need to track the `idx`). for j, candidate in enumerate(candidates):
nonx_candidates = tuple( if candidate != "x":
(idx, c) for idx, c in enumerate(candidates) if c != "x" var = sub[f"lv{int(j)}"]
) dims_c_code += f"{array_name}[{i}] = {var}_n{candidate};\n"
break
# All inputs are known to be broadcastable # If none is non-broadcastable, the output dimension has a length of 1
if not nonx_candidates: else: # no-break
dims_c_code += f"{array_name}[{i}] = 1;\n" dims_c_code += f"{array_name}[{i}] = 1;\n"
continue
# There is only one informative source of size
if len(nonx_candidates) == 1:
idx, candidate = nonx_candidates[0]
var = sub[f"lv{int(idx)}"]
dims_c_code += f"{array_name}[{i}] = {var}_n{candidate};\n"
continue
# In this case any non-size 1 variable will define the right size
dims_c_code += f"{array_name}[{i}] = "
for idx, candidate in nonx_candidates[:-1]:
var = sub[f"lv{int(idx)}"]
dims_c_code += f"({var}_n{candidate} != 1)? {var}_n{candidate}: "
idx, candidate = nonx_candidates[-1]
var = sub[f"lv{idx}"]
dims_c_code += f"{var}_n{candidate};\n"
return dims_c_code return dims_c_code
...@@ -186,7 +169,7 @@ def make_alloc(loop_orders, dtype, sub, fortran="0"): ...@@ -186,7 +169,7 @@ def make_alloc(loop_orders, dtype, sub, fortran="0"):
if type.startswith("PYTENSOR_COMPLEX"): if type.startswith("PYTENSOR_COMPLEX"):
type = type.replace("PYTENSOR_COMPLEX", "NPY_COMPLEX") type = type.replace("PYTENSOR_COMPLEX", "NPY_COMPLEX")
nd = len(loop_orders[0]) nd = len(loop_orders[0])
init_dims = compute_broadcast_dimensions("dims", loop_orders, sub) init_dims = compute_output_dims_lengths("dims", loop_orders, sub)
# TODO: it would be interesting to allocate the output in such a # TODO: it would be interesting to allocate the output in such a
# way that its contiguous dimensions match one of the input's # way that its contiguous dimensions match one of the input's
...@@ -359,7 +342,7 @@ def make_reordered_loop( ...@@ -359,7 +342,7 @@ def make_reordered_loop(
# Get the (sorted) total number of iterations of each loop # Get the (sorted) total number of iterations of each loop
declare_totals = f"int init_totals[{nnested}];\n" declare_totals = f"int init_totals[{nnested}];\n"
declare_totals += compute_broadcast_dimensions("init_totals", init_loop_orders, sub) declare_totals += compute_output_dims_lengths("init_totals", init_loop_orders, sub)
# Sort totals to match the new order that was computed by sorting # Sort totals to match the new order that was computed by sorting
# the loop vector. One integer variable per loop is declared. # the loop vector. One integer variable per loop is declared.
......
...@@ -1439,7 +1439,7 @@ def ravel_multi_index(multi_index, dims, mode="raise", order="C"): ...@@ -1439,7 +1439,7 @@ def ravel_multi_index(multi_index, dims, mode="raise", order="C"):
_broadcast_assert = Assert( _broadcast_assert = Assert(
"Could not broadcast dimensions. Broadcasting is only allowed along " "Could not broadcast dimensions. Broadcasting is only allowed along "
"axes that have a statically known length 1. Use `specify_shape` to " "axes that have a statically known length 1. Use `specify_broadcastable` to "
"inform PyTensor of a known shape." "inform PyTensor of a known shape."
) )
......
...@@ -4,6 +4,7 @@ import scipy.special ...@@ -4,6 +4,7 @@ import scipy.special
import pytensor import pytensor
import pytensor.tensor as at import pytensor.tensor as at
from pytensor.compile import get_mode
from pytensor.configdefaults import config from pytensor.configdefaults import config
from pytensor.graph.fg import FunctionGraph from pytensor.graph.fg import FunctionGraph
from pytensor.graph.op import get_test_value from pytensor.graph.op import get_test_value
...@@ -14,6 +15,11 @@ from pytensor.tensor.math import sum as at_sum ...@@ -14,6 +15,11 @@ from pytensor.tensor.math import sum as at_sum
from pytensor.tensor.special import SoftmaxGrad, log_softmax, softmax from pytensor.tensor.special import SoftmaxGrad, log_softmax, softmax
from pytensor.tensor.type import matrix, tensor, vector from pytensor.tensor.type import matrix, tensor, vector
from tests.link.jax.test_basic import compare_jax_and_py from tests.link.jax.test_basic import compare_jax_and_py
from tests.tensor.test_elemwise import TestElemwise
def test_elemwise_runtime_shape_error():
TestElemwise.check_runtime_shapes_error(get_mode("JAX"))
def test_jax_Dimshuffle(): def test_jax_Dimshuffle():
......
...@@ -9,6 +9,7 @@ import pytensor.tensor as at ...@@ -9,6 +9,7 @@ import pytensor.tensor as at
import pytensor.tensor.inplace as ati import pytensor.tensor.inplace as ati
import pytensor.tensor.math as aem import pytensor.tensor.math as aem
from pytensor import config, function from pytensor import config, function
from pytensor.compile import get_mode
from pytensor.compile.ops import deep_copy_op from pytensor.compile.ops import deep_copy_op
from pytensor.compile.sharedvalue import SharedVariable from pytensor.compile.sharedvalue import SharedVariable
from pytensor.gradient import grad from pytensor.gradient import grad
...@@ -22,6 +23,7 @@ from tests.link.numba.test_basic import ( ...@@ -22,6 +23,7 @@ from tests.link.numba.test_basic import (
scalar_my_multi_out, scalar_my_multi_out,
set_test_value, set_test_value,
) )
from tests.tensor.test_elemwise import TestElemwise
rng = np.random.default_rng(42849) rng = np.random.default_rng(42849)
...@@ -119,6 +121,10 @@ def test_Elemwise(inputs, input_vals, output_fn, exc): ...@@ -119,6 +121,10 @@ def test_Elemwise(inputs, input_vals, output_fn, exc):
compare_numba_and_py(out_fg, input_vals) compare_numba_and_py(out_fg, input_vals)
def test_elemwise_runtime_shape_error():
TestElemwise.check_runtime_shapes_error(get_mode("NUMBA"))
def test_elemwise_speed(benchmark): def test_elemwise_speed(benchmark):
x = at.dmatrix("y") x = at.dmatrix("y")
y = at.dvector("z") y = at.dvector("z")
......
...@@ -1671,12 +1671,7 @@ class TestLocalElemwiseAlloc: ...@@ -1671,12 +1671,7 @@ class TestLocalElemwiseAlloc:
(), (),
(), (),
), ),
pytest.param( (lambda x, y: at.mul(y, at.alloc(1, x)), (), ()),
lambda x, y: at.mul(y, at.alloc(1, x)),
(),
(),
marks=pytest.mark.xfail(reason="Not implemented"),
),
(lambda x, y: at.mul(at.alloc(x, 15, 1), y), (15, 1), (15, 1)), (lambda x, y: at.mul(at.alloc(x, 15, 1), y), (15, 1), (15, 1)),
(lambda x, y: at.mul(at.alloc(x, 15, 2), y), (15, 2), (15, 2)), (lambda x, y: at.mul(at.alloc(x, 15, 2), y), (15, 2), (15, 2)),
( (
......
...@@ -607,8 +607,7 @@ class TestAlgebraicCanonizer: ...@@ -607,8 +607,7 @@ class TestAlgebraicCanonizer:
((fx / fy) / fx, [fx, fy], [fxv, fyv], 1, "float32"), ((fx / fy) / fx, [fx, fy], [fxv, fyv], 1, "float32"),
((dv / dy) / dv, [dv, dy], [dvv, dyv], 1, "float64"), ((dv / dy) / dv, [dv, dy], [dvv, dyv], 1, "float64"),
((fv / fy) / fv, [fv, fy], [fvv, fyv], 1, "float32"), ((fv / fy) / fv, [fv, fy], [fvv, fyv], 1, "float32"),
# must broadcast as their is a dimshuffle in the computation # must broadcast as there is a dimshuffle in the computation
# The broadcast leads to an extra elemwise to check compatibility
((dx / dv) / dx, [dx, dv], [dxv, dvv], 2, "float64"), ((dx / dv) / dx, [dx, dv], [dxv, dvv], 2, "float64"),
# topo: [Shape_i, Shape_i, Elemwise{reciprocal,no_inplace}(<TensorType(float64, row)>), Alloc] # topo: [Shape_i, Shape_i, Elemwise{reciprocal,no_inplace}(<TensorType(float64, row)>), Alloc]
((fx / fv) / fx, [fx, fv], [fxv, fvv], 2, "float32"), ((fx / fv) / fx, [fx, fv], [fxv, fvv], 2, "float32"),
......
...@@ -428,12 +428,11 @@ class TestSameShape: ...@@ -428,12 +428,11 @@ class TestSameShape:
# could have shapes `(1,)` and/or `(n,)`, where `n != 1`, or any # could have shapes `(1,)` and/or `(n,)`, where `n != 1`, or any
# combination of the two. # combination of the two.
assert not shape_feature.same_shape(x, o) assert not shape_feature.same_shape(x, o)
# The following case isn't implemented
assert not shape_feature.same_shape(y, o) assert not shape_feature.same_shape(y, o)
@pytest.mark.parametrize( @pytest.mark.parametrize(
"y_dim_0", "y_dim_0",
[2, pytest.param(None, marks=pytest.mark.xfail(reason="Not implemented"))], [2, None],
) )
def test_vector_dim(self, y_dim_0): def test_vector_dim(self, y_dim_0):
x = at.tensor(dtype="floatX", shape=(2, None)) x = at.tensor(dtype="floatX", shape=(2, None))
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论