提交 5bbfc964 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Forbid runtime broadcasting in Elemwise

上级 e8bd0d7d
...@@ -7,9 +7,17 @@ from pytensor.tensor.special import LogSoftmax, Softmax, SoftmaxGrad ...@@ -7,9 +7,17 @@ from pytensor.tensor.special import LogSoftmax, Softmax, SoftmaxGrad
@jax_funcify.register(Elemwise) @jax_funcify.register(Elemwise)
def jax_funcify_Elemwise(op, **kwargs): def jax_funcify_Elemwise(op, node, **kwargs):
scalar_op = op.scalar_op scalar_op = op.scalar_op
return jax_funcify(scalar_op, **kwargs) base_fn = jax_funcify(scalar_op, node=node, **kwargs)
def elemwise_fn(*inputs):
# ScalarVariables in JAX are passed as int/float.
# We wrap them in arrays just for the broadcast check
Elemwise._check_runtime_broadcast(node, tuple(map(jnp.asarray, inputs)))
return base_fn(*inputs)
return elemwise_fn
@jax_funcify.register(CAReduce) @jax_funcify.register(CAReduce)
......
...@@ -35,15 +35,20 @@ def compute_itershape( ...@@ -35,15 +35,20 @@ def compute_itershape(
with builder.if_then( with builder.if_then(
builder.icmp_unsigned("!=", length, shape[i]), likely=False builder.icmp_unsigned("!=", length, shape[i]), likely=False
): ):
with builder.if_else(builder.icmp_unsigned("==", length, one)) as ( with builder.if_else(
builder.or_(
builder.icmp_unsigned("==", length, one),
builder.icmp_unsigned("==", shape[i], one),
)
) as (
then, then,
otherwise, otherwise,
): ):
with then: with then:
msg = ( msg = (
f"Incompatible shapes for input {j} and axis {i} of " "Runtime broadcasting not allowed. "
f"elemwise. Input {j} has shape 1, but is not statically " "One input had a distinct dimension length of 1, but was not marked as broadcastable.\n"
"known to have shape 1, and thus not broadcastable." "If broadcasting was intended, use `specify_broadcastable` on the relevant input."
) )
ctx.call_conv.return_user_exc(builder, ValueError, (msg,)) ctx.call_conv.return_user_exc(builder, ValueError, (msg,))
with otherwise: with otherwise:
......
...@@ -132,6 +132,7 @@ from pytensor.tensor.shape import ( # noqa ...@@ -132,6 +132,7 @@ from pytensor.tensor.shape import ( # noqa
shape_padaxis, shape_padaxis,
shape_padleft, shape_padleft,
shape_padright, shape_padright,
specify_broadcastable,
specify_shape, specify_shape,
) )
......
...@@ -19,9 +19,9 @@ from pytensor.scalar import get_scalar_type ...@@ -19,9 +19,9 @@ from pytensor.scalar import get_scalar_type
from pytensor.scalar.basic import bool as scalar_bool from pytensor.scalar.basic import bool as scalar_bool
from pytensor.scalar.basic import identity as scalar_identity from pytensor.scalar.basic import identity as scalar_identity
from pytensor.scalar.basic import transfer_type, upcast from pytensor.scalar.basic import transfer_type, upcast
from pytensor.tensor import _get_vector_length, as_tensor_variable
from pytensor.tensor import elemwise_cgen as cgen from pytensor.tensor import elemwise_cgen as cgen
from pytensor.tensor import get_vector_length from pytensor.tensor import get_vector_length
from pytensor.tensor.basic import _get_vector_length, as_tensor_variable
from pytensor.tensor.type import ( from pytensor.tensor.type import (
TensorType, TensorType,
continuous_dtypes, continuous_dtypes,
...@@ -740,9 +740,7 @@ class Elemwise(OpenMPOp): ...@@ -740,9 +740,7 @@ class Elemwise(OpenMPOp):
# FIXME: This no longer calls the C implementation! # FIXME: This no longer calls the C implementation!
super().perform(node, inputs, output_storage) super().perform(node, inputs, output_storage)
for d, dim_shapes in enumerate(zip(*(i.shape for i in inputs))): self._check_runtime_broadcast(node, inputs)
if len(set(dim_shapes) - {1}) > 1:
raise ValueError(f"Shapes on dimension {d} do not match: {dim_shapes}")
ufunc_args = inputs ufunc_args = inputs
ufunc_kwargs = {} ufunc_kwargs = {}
...@@ -818,18 +816,26 @@ class Elemwise(OpenMPOp): ...@@ -818,18 +816,26 @@ class Elemwise(OpenMPOp):
else: else:
storage[0] = variable storage[0] = variable
def infer_shape(self, fgraph, node, i_shapes) -> List[Tuple[TensorVariable, ...]]: @staticmethod
if len(node.outputs) > 1: def _check_runtime_broadcast(node, inputs):
from pytensor.tensor.exceptions import ShapeError for dims_and_bcast in zip(
*[
raise ShapeError( zip(input.shape, sinput.type.broadcastable)
"Multiple outputs are not supported by the default `Elemwise.infer_shape`" for input, sinput in zip(inputs, node.inputs)
) ]
):
if any(d != 1 for d, _ in dims_and_bcast) and (1, False) in dims_and_bcast:
raise ValueError(
"Runtime broadcasting not allowed. "
"At least one input has a distinct dimension length of 1, but was not marked as broadcastable.\n"
"If broadcasting was intended, use `specify_broadcastable` on the relevant input."
)
out_shape = pytensor.tensor.broadcast_shape(*i_shapes, arrays_are_shapes=True) def infer_shape(self, fgraph, node, i_shapes) -> List[Tuple[TensorVariable, ...]]:
from pytensor.tensor.extra_ops import broadcast_shape
# The `as_tensor_variable` should convert `ScalarType`s to `TensorType`s out_shape = broadcast_shape(*i_shapes, arrays_are_shapes=True)
return [tuple(as_tensor_variable(s) for s in out_shape)] return [tuple(as_tensor_variable(s) for s in out_shape)] * len(node.outputs)
def _c_all(self, node, nodename, inames, onames, sub): def _c_all(self, node, nodename, inames, onames, sub):
# Some `Op`s directly call `Elemwise._c_all` or `Elemwise.c_code` # Some `Op`s directly call `Elemwise._c_all` or `Elemwise.c_code`
...@@ -1193,7 +1199,7 @@ class Elemwise(OpenMPOp): ...@@ -1193,7 +1199,7 @@ class Elemwise(OpenMPOp):
return support_code return support_code
def c_code_cache_version_apply(self, node): def c_code_cache_version_apply(self, node):
version = [14] # the version corresponding to the c code in this Op version = [15] # the version corresponding to the c code in this Op
# now we insert versions for the ops on which we depend... # now we insert versions for the ops on which we depend...
scalar_node = Apply( scalar_node = Apply(
......
...@@ -66,12 +66,10 @@ def make_checks(loop_orders, dtypes, sub): ...@@ -66,12 +66,10 @@ def make_checks(loop_orders, dtypes, sub):
if index != "x": if index != "x":
# Initialize the variables associated to the jth loop # Initialize the variables associated to the jth loop
# jump = stride - adjust # jump = stride - adjust
# If the variable has size 1 in that dim, we set the stride to zero to
# emulate broadcasting
jump = f"({var}_stride{index}) - ({adjust})" jump = f"({var}_stride{index}) - ({adjust})"
init += f""" init += f"""
{var}_n{index} = PyArray_DIMS({var})[{index}]; {var}_n{index} = PyArray_DIMS({var})[{index}];
{var}_stride{index} = ({var}_n{index} == 1)? 0 : PyArray_STRIDES({var})[{index}] / sizeof({dtype}); {var}_stride{index} = PyArray_STRIDES({var})[{index}] / sizeof({dtype});
{var}_jump{index}_{j} = {jump}; {var}_jump{index}_{j} = {jump};
""" """
adjust = f"{var}_n{index}*{var}_stride{index}" adjust = f"{var}_n{index}*{var}_stride{index}"
...@@ -86,6 +84,14 @@ def make_checks(loop_orders, dtypes, sub): ...@@ -86,6 +84,14 @@ def make_checks(loop_orders, dtypes, sub):
# This loop builds multiple if conditions to verify that the # This loop builds multiple if conditions to verify that the
# dimensions of the inputs match, and the first one that is true # dimensions of the inputs match, and the first one that is true
# raises an informative error message # raises an informative error message
runtime_broadcast_error_msg = (
"Runtime broadcasting not allowed. "
"One input had a distinct dimension length of 1, but was not marked as broadcastable: "
"(input[%%i].shape[%%i] = %%lld, input[%%i].shape[%%i] = %%lld). "
"If broadcasting was intended, use `specify_broadcastable` on the relevant input."
)
for matches in zip(*loop_orders): for matches in zip(*loop_orders):
to_compare = [(j, x) for j, x in enumerate(matches) if x != "x"] to_compare = [(j, x) for j, x in enumerate(matches) if x != "x"]
...@@ -93,81 +99,58 @@ def make_checks(loop_orders, dtypes, sub): ...@@ -93,81 +99,58 @@ def make_checks(loop_orders, dtypes, sub):
if len(to_compare) < 2: if len(to_compare) < 2:
continue continue
# Find first dimension size that is != 1 j0, x0 = to_compare[0]
jl, xl = to_compare[-1] for j, x in to_compare[1:]:
non1size_dim_check = f"""
npy_intp non1size_dim{xl};
non1size_dim{xl} = """
for j, x in to_compare[:-1]:
non1size_dim_check += f"(%(lv{j})s_n{x} != 1) ? %(lv{j})s_n{x} : "
non1size_dim_check += f"%(lv{jl})s_n{xl};"
check += non1size_dim_check
# Check the nonsize1 dims match
# TODO: This is a bit inefficient because we are comparing one dimension against itself
check += f"""
if (non1size_dim{xl} != 1)
{{
"""
for j, x in to_compare:
check += f""" check += f"""
if ((%(lv{j})s_n{x} != non1size_dim{x}) && (%(lv{j})s_n{x} != 1)) if (%(lv{j0})s_n{x0} != %(lv{j})s_n{x})
{{
if (%(lv{j0})s_n{x0} == 1 || %(lv{j})s_n{x} == 1)
{{ {{
PyErr_Format(PyExc_ValueError, "Input dimension mismatch. One other input has shape[%%i] = %%lld, but input[%%i].shape[%%i] = %%lld.", PyErr_Format(PyExc_ValueError, "{runtime_broadcast_error_msg}",
{x}, {j0},
(long long int) non1size_dim{x}, {x0},
(long long int) %(lv{j0})s_n{x0},
{j},
{x},
(long long int) %(lv{j})s_n{x}
);
}} else {{
PyErr_Format(PyExc_ValueError, "Input dimension mismatch: (input[%%i].shape[%%i] = %%lld, input[%%i].shape[%%i] = %%lld)",
{j0},
{x0},
(long long int) %(lv{j0})s_n{x0},
{j}, {j},
{x}, {x},
(long long int) %(lv{j})s_n{x} (long long int) %(lv{j})s_n{x}
); );
%(fail)s
}} }}
""" %(fail)s
check += """ }}
}
""" """
return init % sub + check % sub return init % sub + check % sub
def compute_broadcast_dimensions(array_name: str, loop_orders, sub) -> str: def compute_output_dims_lengths(array_name: str, loop_orders, sub) -> str:
"""Create c_code to compute broadcasted dimensions of multiple arrays, arising from """Create c_code to compute the output dimensions of an Elemwise operation.
Elemwise operations.
The code returned by this function populates the array `array_name`, but does not The code returned by this function populates the array `array_name`, but does not
initialize it. initialize it.
TODO: We can decide to either specialize C code even further given the input types Note: We could specialize C code even further with the known static output shapes
or make it general, regardless of whether static broadcastable information is given
""" """
dims_c_code = "" dims_c_code = ""
for i, candidates in enumerate(zip(*loop_orders)): for i, candidates in enumerate(zip(*loop_orders)):
# TODO: Are candidates always either "x" or "i"? If that's the case we can # Borrow the length of the first non-broadcastable input dimension
# simplify some logic here (e.g., we don't need to track the `idx`). for j, candidate in enumerate(candidates):
nonx_candidates = tuple( if candidate != "x":
(idx, c) for idx, c in enumerate(candidates) if c != "x" var = sub[f"lv{int(j)}"]
) dims_c_code += f"{array_name}[{i}] = {var}_n{candidate};\n"
break
# All inputs are known to be broadcastable # If none is non-broadcastable, the output dimension has a length of 1
if not nonx_candidates: else: # no-break
dims_c_code += f"{array_name}[{i}] = 1;\n" dims_c_code += f"{array_name}[{i}] = 1;\n"
continue
# There is only one informative source of size
if len(nonx_candidates) == 1:
idx, candidate = nonx_candidates[0]
var = sub[f"lv{int(idx)}"]
dims_c_code += f"{array_name}[{i}] = {var}_n{candidate};\n"
continue
# In this case any non-size 1 variable will define the right size
dims_c_code += f"{array_name}[{i}] = "
for idx, candidate in nonx_candidates[:-1]:
var = sub[f"lv{int(idx)}"]
dims_c_code += f"({var}_n{candidate} != 1)? {var}_n{candidate}: "
idx, candidate = nonx_candidates[-1]
var = sub[f"lv{idx}"]
dims_c_code += f"{var}_n{candidate};\n"
return dims_c_code return dims_c_code
...@@ -186,7 +169,7 @@ def make_alloc(loop_orders, dtype, sub, fortran="0"): ...@@ -186,7 +169,7 @@ def make_alloc(loop_orders, dtype, sub, fortran="0"):
if type.startswith("PYTENSOR_COMPLEX"): if type.startswith("PYTENSOR_COMPLEX"):
type = type.replace("PYTENSOR_COMPLEX", "NPY_COMPLEX") type = type.replace("PYTENSOR_COMPLEX", "NPY_COMPLEX")
nd = len(loop_orders[0]) nd = len(loop_orders[0])
init_dims = compute_broadcast_dimensions("dims", loop_orders, sub) init_dims = compute_output_dims_lengths("dims", loop_orders, sub)
# TODO: it would be interesting to allocate the output in such a # TODO: it would be interesting to allocate the output in such a
# way that its contiguous dimensions match one of the input's # way that its contiguous dimensions match one of the input's
...@@ -359,7 +342,7 @@ def make_reordered_loop( ...@@ -359,7 +342,7 @@ def make_reordered_loop(
# Get the (sorted) total number of iterations of each loop # Get the (sorted) total number of iterations of each loop
declare_totals = f"int init_totals[{nnested}];\n" declare_totals = f"int init_totals[{nnested}];\n"
declare_totals += compute_broadcast_dimensions("init_totals", init_loop_orders, sub) declare_totals += compute_output_dims_lengths("init_totals", init_loop_orders, sub)
# Sort totals to match the new order that was computed by sorting # Sort totals to match the new order that was computed by sorting
# the loop vector. One integer variable per loop is declared. # the loop vector. One integer variable per loop is declared.
......
...@@ -1439,7 +1439,7 @@ def ravel_multi_index(multi_index, dims, mode="raise", order="C"): ...@@ -1439,7 +1439,7 @@ def ravel_multi_index(multi_index, dims, mode="raise", order="C"):
_broadcast_assert = Assert( _broadcast_assert = Assert(
"Could not broadcast dimensions. Broadcasting is only allowed along " "Could not broadcast dimensions. Broadcasting is only allowed along "
"axes that have a statically known length 1. Use `specify_shape` to " "axes that have a statically known length 1. Use `specify_broadcastable` to "
"inform PyTensor of a known shape." "inform PyTensor of a known shape."
) )
......
...@@ -4,6 +4,7 @@ import scipy.special ...@@ -4,6 +4,7 @@ import scipy.special
import pytensor import pytensor
import pytensor.tensor as at import pytensor.tensor as at
from pytensor.compile import get_mode
from pytensor.configdefaults import config from pytensor.configdefaults import config
from pytensor.graph.fg import FunctionGraph from pytensor.graph.fg import FunctionGraph
from pytensor.graph.op import get_test_value from pytensor.graph.op import get_test_value
...@@ -14,6 +15,11 @@ from pytensor.tensor.math import sum as at_sum ...@@ -14,6 +15,11 @@ from pytensor.tensor.math import sum as at_sum
from pytensor.tensor.special import SoftmaxGrad, log_softmax, softmax from pytensor.tensor.special import SoftmaxGrad, log_softmax, softmax
from pytensor.tensor.type import matrix, tensor, vector from pytensor.tensor.type import matrix, tensor, vector
from tests.link.jax.test_basic import compare_jax_and_py from tests.link.jax.test_basic import compare_jax_and_py
from tests.tensor.test_elemwise import TestElemwise
def test_elemwise_runtime_shape_error():
TestElemwise.check_runtime_shapes_error(get_mode("JAX"))
def test_jax_Dimshuffle(): def test_jax_Dimshuffle():
......
...@@ -9,6 +9,7 @@ import pytensor.tensor as at ...@@ -9,6 +9,7 @@ import pytensor.tensor as at
import pytensor.tensor.inplace as ati import pytensor.tensor.inplace as ati
import pytensor.tensor.math as aem import pytensor.tensor.math as aem
from pytensor import config, function from pytensor import config, function
from pytensor.compile import get_mode
from pytensor.compile.ops import deep_copy_op from pytensor.compile.ops import deep_copy_op
from pytensor.compile.sharedvalue import SharedVariable from pytensor.compile.sharedvalue import SharedVariable
from pytensor.gradient import grad from pytensor.gradient import grad
...@@ -22,6 +23,7 @@ from tests.link.numba.test_basic import ( ...@@ -22,6 +23,7 @@ from tests.link.numba.test_basic import (
scalar_my_multi_out, scalar_my_multi_out,
set_test_value, set_test_value,
) )
from tests.tensor.test_elemwise import TestElemwise
rng = np.random.default_rng(42849) rng = np.random.default_rng(42849)
...@@ -119,6 +121,10 @@ def test_Elemwise(inputs, input_vals, output_fn, exc): ...@@ -119,6 +121,10 @@ def test_Elemwise(inputs, input_vals, output_fn, exc):
compare_numba_and_py(out_fg, input_vals) compare_numba_and_py(out_fg, input_vals)
def test_elemwise_runtime_shape_error():
TestElemwise.check_runtime_shapes_error(get_mode("NUMBA"))
def test_elemwise_speed(benchmark): def test_elemwise_speed(benchmark):
x = at.dmatrix("y") x = at.dmatrix("y")
y = at.dvector("z") y = at.dvector("z")
......
...@@ -1671,12 +1671,7 @@ class TestLocalElemwiseAlloc: ...@@ -1671,12 +1671,7 @@ class TestLocalElemwiseAlloc:
(), (),
(), (),
), ),
pytest.param( (lambda x, y: at.mul(y, at.alloc(1, x)), (), ()),
lambda x, y: at.mul(y, at.alloc(1, x)),
(),
(),
marks=pytest.mark.xfail(reason="Not implemented"),
),
(lambda x, y: at.mul(at.alloc(x, 15, 1), y), (15, 1), (15, 1)), (lambda x, y: at.mul(at.alloc(x, 15, 1), y), (15, 1), (15, 1)),
(lambda x, y: at.mul(at.alloc(x, 15, 2), y), (15, 2), (15, 2)), (lambda x, y: at.mul(at.alloc(x, 15, 2), y), (15, 2), (15, 2)),
( (
......
...@@ -607,8 +607,7 @@ class TestAlgebraicCanonizer: ...@@ -607,8 +607,7 @@ class TestAlgebraicCanonizer:
((fx / fy) / fx, [fx, fy], [fxv, fyv], 1, "float32"), ((fx / fy) / fx, [fx, fy], [fxv, fyv], 1, "float32"),
((dv / dy) / dv, [dv, dy], [dvv, dyv], 1, "float64"), ((dv / dy) / dv, [dv, dy], [dvv, dyv], 1, "float64"),
((fv / fy) / fv, [fv, fy], [fvv, fyv], 1, "float32"), ((fv / fy) / fv, [fv, fy], [fvv, fyv], 1, "float32"),
# must broadcast as their is a dimshuffle in the computation # must broadcast as there is a dimshuffle in the computation
# The broadcast leads to an extra elemwise to check compatibility
((dx / dv) / dx, [dx, dv], [dxv, dvv], 2, "float64"), ((dx / dv) / dx, [dx, dv], [dxv, dvv], 2, "float64"),
# topo: [Shape_i, Shape_i, Elemwise{reciprocal,no_inplace}(<TensorType(float64, row)>), Alloc] # topo: [Shape_i, Shape_i, Elemwise{reciprocal,no_inplace}(<TensorType(float64, row)>), Alloc]
((fx / fv) / fx, [fx, fv], [fxv, fvv], 2, "float32"), ((fx / fv) / fx, [fx, fv], [fxv, fvv], 2, "float32"),
......
...@@ -428,12 +428,11 @@ class TestSameShape: ...@@ -428,12 +428,11 @@ class TestSameShape:
# could have shapes `(1,)` and/or `(n,)`, where `n != 1`, or any # could have shapes `(1,)` and/or `(n,)`, where `n != 1`, or any
# combination of the two. # combination of the two.
assert not shape_feature.same_shape(x, o) assert not shape_feature.same_shape(x, o)
# The following case isn't implemented
assert not shape_feature.same_shape(y, o) assert not shape_feature.same_shape(y, o)
@pytest.mark.parametrize( @pytest.mark.parametrize(
"y_dim_0", "y_dim_0",
[2, pytest.param(None, marks=pytest.mark.xfail(reason="Not implemented"))], [2, None],
) )
def test_vector_dim(self, y_dim_0): def test_vector_dim(self, y_dim_0):
x = at.tensor(dtype="floatX", shape=(2, None)) x = at.tensor(dtype="floatX", shape=(2, None))
......
...@@ -18,7 +18,6 @@ from pytensor.link.c.basic import CLinker, OpWiseCLinker ...@@ -18,7 +18,6 @@ from pytensor.link.c.basic import CLinker, OpWiseCLinker
from pytensor.tensor import as_tensor_variable from pytensor.tensor import as_tensor_variable
from pytensor.tensor.basic import second from pytensor.tensor.basic import second
from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise from pytensor.tensor.elemwise import CAReduce, DimShuffle, Elemwise
from pytensor.tensor.exceptions import ShapeError
from pytensor.tensor.math import all as at_all from pytensor.tensor.math import all as at_all
from pytensor.tensor.math import any as at_any from pytensor.tensor.math import any as at_any
from pytensor.tensor.math import exp from pytensor.tensor.math import exp
...@@ -216,117 +215,93 @@ class TestBroadcast: ...@@ -216,117 +215,93 @@ class TestBroadcast:
return np.asarray(np.random.random(shp), dtype=pytensor.config.floatX) return np.asarray(np.random.random(shp), dtype=pytensor.config.floatX)
def with_linker(self, linker, op, type, rand_val): def with_linker(self, linker, op, type, rand_val):
for shape_info in ("complete", "only_broadcastable", "none"): for xsh, ysh in [
for xsh, ysh in [ ((3, 5), (3, 5)),
((3, 5), (3, 5)), ((3, 5), (1, 5)),
((3, 5), (1, 5)), ((3, 5), (3, 1)),
((3, 5), (3, 1)), ((1, 5), (5, 1)),
((1, 5), (5, 1)), ((1, 1), (1, 1)),
((1, 1), (1, 1)), ((self.openmp_minsize,), (self.openmp_minsize,)),
((self.openmp_minsize,), (self.openmp_minsize,)), (
( (self.openmp_minsize_sqrt, self.openmp_minsize_sqrt),
(self.openmp_minsize_sqrt, self.openmp_minsize_sqrt), (self.openmp_minsize_sqrt, self.openmp_minsize_sqrt),
(self.openmp_minsize_sqrt, self.openmp_minsize_sqrt), ),
), ((2, 3, 4, 5), (2, 3, 4, 5)),
((2, 3, 4, 5), (2, 3, 4, 5)), ((2, 3, 4, 5), (1, 3, 1, 5)),
((2, 3, 4, 5), (1, 3, 1, 5)), ((2, 3, 4, 5), (1, 1, 1, 1)),
((2, 3, 4, 5), (1, 1, 1, 1)), ((), ()),
((), ()), ]:
]: x_type = type(
if shape_info == "complete": pytensor.config.floatX,
x_type = type(pytensor.config.floatX, shape=xsh) shape=tuple(s if s == 1 else None for s in xsh),
y_type = type(pytensor.config.floatX, shape=ysh) )
elif shape_info == "only_broadcastable": y_type = type(
# This condition is here for backwards compatibility, when the only pytensor.config.floatX,
# type shape provided by PyTensor was broadcastable/non-broadcastable shape=tuple(s if s == 1 else None for s in ysh),
x_type = type( )
pytensor.config.floatX,
shape=tuple(s if s == 1 else None for s in xsh), x = x_type("x")
) y = y_type("y")
y_type = type( e = op(aes.add)(x, y)
pytensor.config.floatX, f = make_function(copy(linker).accept(FunctionGraph([x, y], [e])))
shape=tuple(s if s == 1 else None for s in ysh), xv = rand_val(xsh)
) yv = rand_val(ysh)
else: zv = xv + yv
x_type = type(pytensor.config.floatX, shape=[None for _ in xsh])
y_type = type(pytensor.config.floatX, shape=[None for _ in ysh]) unittest_tools.assert_allclose(f(xv, yv), zv)
# test Elemwise.infer_shape
# the Shape op don't implement c_code!
if isinstance(linker, PerformLinker):
x = x_type("x") x = x_type("x")
y = y_type("y") y = y_type("y")
e = op(aes.add)(x, y) e = op(aes.add)(x, y)
f = make_function(copy(linker).accept(FunctionGraph([x, y], [e]))) f = make_function(copy(linker).accept(FunctionGraph([x, y], [e.shape])))
xv = rand_val(xsh) assert tuple(f(xv, yv)) == tuple(zv.shape)
yv = rand_val(ysh)
zv = xv + yv
unittest_tools.assert_allclose(f(xv, yv), zv) def with_linker_inplace(self, linker, op, type, rand_val):
for xsh, ysh in [
((5, 5), (5, 5)),
((5, 5), (1, 5)),
((5, 5), (5, 1)),
((1, 1), (1, 1)),
((2, 3, 4, 5), (2, 3, 4, 5)),
((2, 3, 4, 5), (1, 3, 1, 5)),
((2, 3, 4, 5), (1, 1, 1, 1)),
((), ()),
]:
x_type = type(
pytensor.config.floatX,
shape=tuple(s if s == 1 else None for s in xsh),
)
y_type = type(
pytensor.config.floatX,
shape=tuple(s if s == 1 else None for s in ysh),
)
# test Elemwise.infer_shape x = x_type("x")
# the Shape op don't implement c_code! y = y_type("y")
if isinstance(linker, PerformLinker): e = op(aes.Add(aes.transfer_type(0)), {0: 0})(x, y)
x = x_type("x") f = make_function(copy(linker).accept(FunctionGraph([x, y], [e])))
y = y_type("y") xv = rand_val(xsh)
e = op(aes.add)(x, y) yv = rand_val(ysh)
f = make_function( zv = xv + yv
copy(linker).accept(FunctionGraph([x, y], [e.shape]))
)
assert tuple(f(xv, yv)) == tuple(zv.shape)
def with_linker_inplace(self, linker, op, type, rand_val): f(xv, yv)
for shape_info in ("complete", "only_broadcastable", "none"):
for xsh, ysh in [
((5, 5), (5, 5)),
((5, 5), (1, 5)),
((5, 5), (5, 1)),
((1, 1), (1, 1)),
((2, 3, 4, 5), (2, 3, 4, 5)),
((2, 3, 4, 5), (1, 3, 1, 5)),
((2, 3, 4, 5), (1, 1, 1, 1)),
((), ()),
]:
if shape_info == "complete":
x_type = type(pytensor.config.floatX, shape=xsh)
y_type = type(pytensor.config.floatX, shape=ysh)
elif shape_info == "only_broadcastable":
# This condition is here for backwards compatibility, when the only
# type shape provided by PyTensor was broadcastable/non-broadcastable
x_type = type(
pytensor.config.floatX,
shape=tuple(s if s == 1 else None for s in xsh),
)
y_type = type(
pytensor.config.floatX,
shape=tuple(s if s == 1 else None for s in ysh),
)
else:
x_type = type(pytensor.config.floatX, shape=[None for _ in xsh])
y_type = type(pytensor.config.floatX, shape=[None for _ in ysh])
assert (xv == zv).all()
# test Elemwise.infer_shape
# the Shape op don't implement c_code!
if isinstance(linker, PerformLinker):
x = x_type("x") x = x_type("x")
y = y_type("y") y = y_type("y")
e = op(aes.Add(aes.transfer_type(0)), {0: 0})(x, y) e = op(aes.Add(aes.transfer_type(0)), {0: 0})(x, y)
f = make_function(copy(linker).accept(FunctionGraph([x, y], [e]))) f = make_function(copy(linker).accept(FunctionGraph([x, y], [e.shape])))
xv = rand_val(xsh) xv = rand_val(xsh)
yv = rand_val(ysh) yv = rand_val(ysh)
zv = xv + yv zv = xv + yv
assert xv.shape == zv.shape
f(xv, yv) assert tuple(f(xv, yv)) == zv.shape
assert (xv == zv).all()
# test Elemwise.infer_shape
# the Shape op don't implement c_code!
if isinstance(linker, PerformLinker):
x = x_type("x")
y = y_type("y")
e = op(aes.Add(aes.transfer_type(0)), {0: 0})(x, y)
f = make_function(
copy(linker).accept(FunctionGraph([x, y], [e.shape]))
)
xv = rand_val(xsh)
yv = rand_val(ysh)
zv = xv + yv
assert xv.shape == zv.shape
assert tuple(f(xv, yv)) == zv.shape
def test_perform(self): def test_perform(self):
self.with_linker(PerformLinker(), self.op, self.type, self.rand_val) self.with_linker(PerformLinker(), self.op, self.type, self.rand_val)
...@@ -775,32 +750,42 @@ class TestElemwise(unittest_tools.InferShapeTester): ...@@ -775,32 +750,42 @@ class TestElemwise(unittest_tools.InferShapeTester):
g = pytensor.function([a, b, c, d, e, f], s, mode=Mode(linker="py")) g = pytensor.function([a, b, c, d, e, f], s, mode=Mode(linker="py"))
g(*[np.zeros(2**11, config.floatX) for i in range(6)]) g(*[np.zeros(2**11, config.floatX) for i in range(6)])
def check_input_dimensions_match(self, mode): @staticmethod
"""Make sure that our input validation works correctly and doesn't def check_runtime_shapes_error(mode):
throw erroneous broadcast-based errors. """Check we emit a clear error when runtime broadcasting would occur according to Numpy rules."""
"""
x_v = matrix("x") x_v = matrix("x")
m_v = vector("m") m_v = vector("m")
x = np.array([[-1.32720483], [0.23442016]]).astype(config.floatX)
m = np.array([0.0, 0.0]).astype(config.floatX)
z_v = x_v - m_v z_v = x_v - m_v
f = pytensor.function([x_v, m_v], z_v, mode=mode) f = pytensor.function([x_v, m_v], z_v, mode=mode)
res = f(x, m) # Test invalid broadcasting by either x or m
for x_sh, m_sh in [((2, 1), (3,)), ((2, 3), (1,))]:
x = np.ones(x_sh).astype(config.floatX)
m = np.zeros(m_sh).astype(config.floatX)
# This error is introduced by PyTensor, so it's the same across different backends
with pytest.raises(ValueError, match="Runtime broadcasting not allowed"):
f(x, m)
x = np.ones((2, 3)).astype(config.floatX)
m = np.zeros((1,)).astype(config.floatX)
assert np.array_equal(res, x - m) x = np.ones((2, 4)).astype(config.floatX)
m = np.zeros((3,)).astype(config.floatX)
# This error is backend specific, and may have different types
with pytest.raises((ValueError, TypeError)):
f(x, m)
def test_input_dimensions_match_python(self): def test_runtime_shapes_error_python(self):
self.check_input_dimensions_match(Mode(linker="py")) self.check_runtime_shapes_error(Mode(linker="py"))
@pytest.mark.skipif( @pytest.mark.skipif(
not pytensor.config.cxx, not pytensor.config.cxx,
reason="G++ not available, so we need to skip this test.", reason="G++ not available, so we need to skip this test.",
) )
def test_input_dimensions_match_c(self): def test_runtime_shapes_error_c(self):
self.check_input_dimensions_match(Mode(linker="c")) self.check_runtime_shapes_error(Mode(linker="c"))
def test_str(self): def test_str(self):
op = Elemwise(aes.add, inplace_pattern={0: 0}, name=None) op = Elemwise(aes.add, inplace_pattern={0: 0}, name=None)
...@@ -825,7 +810,7 @@ class TestElemwise(unittest_tools.InferShapeTester): ...@@ -825,7 +810,7 @@ class TestElemwise(unittest_tools.InferShapeTester):
assert pytensor.get_underlying_scalar_constant(res_shape[0][0]) == 1 assert pytensor.get_underlying_scalar_constant(res_shape[0][0]) == 1
assert pytensor.get_underlying_scalar_constant(res_shape[0][1]) == 1 assert pytensor.get_underlying_scalar_constant(res_shape[0][1]) == 1
def test_multi_output(self): def test_infer_shape_multi_output(self):
class CustomElemwise(Elemwise): class CustomElemwise(Elemwise):
def make_node(self, *args): def make_node(self, *args):
res = super().make_node(*args) res = super().make_node(*args)
...@@ -839,14 +824,26 @@ class TestElemwise(unittest_tools.InferShapeTester): ...@@ -839,14 +824,26 @@ class TestElemwise(unittest_tools.InferShapeTester):
], ],
) )
z_1, z_2 = CustomElemwise(aes.add)( custom_elemwise = CustomElemwise(aes.add)
as_tensor_variable(np.eye(1)), as_tensor_variable(np.eye(1))
)
z_1, z_2 = custom_elemwise(
as_tensor_variable(np.eye(1)),
as_tensor_variable(np.eye(1)),
)
in_1_shape = (aes.constant(1), aes.constant(1)) in_1_shape = (aes.constant(1), aes.constant(1))
outs = z_1.owner.op.infer_shape(None, z_1.owner, [in_1_shape, in_1_shape])
for out in outs:
assert out[0].eval() == 1
assert out[1].eval() == 1
with pytest.raises(ShapeError): z_1, z_2 = custom_elemwise(
z_1.owner.op.infer_shape(None, z_1.owner, [in_1_shape, in_1_shape]) as_tensor_variable(np.eye(1)), as_tensor_variable(np.eye(3))
)
in_2_shape = (aes.constant(3), aes.constant(3))
outs = z_1.owner.op.infer_shape(None, z_1.owner, [in_1_shape, in_2_shape])
for out in outs:
assert out[0].eval() == 3
assert out[1].eval() == 3
def test_shape_types(self): def test_shape_types(self):
x = tensor(dtype=np.float64, shape=(None, 1)) x = tensor(dtype=np.float64, shape=(None, 1))
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论