Unverified commit df4183d5 · Author: Adrian Seyboldt · Committer: GitHub

Use static-only broadcasting rules to compute shape of broadcasting (#345)

Parent commit: b9c4f20d
from collections.abc import Collection
from functools import reduce
from typing import Iterable, Set, Tuple, Union
import numpy as np
import numpy.core.numeric
from numpy.core.multiarray import normalize_axis_index
import pytensor
......@@ -14,7 +12,7 @@ from pytensor.gradient import (
disconnected_type,
grad_undefined,
)
from pytensor.graph.basic import Apply, Constant, Variable, equal_computations
from pytensor.graph.basic import Apply, Constant, Variable
from pytensor.graph.op import Op
from pytensor.link.c.op import COp
from pytensor.link.c.params_type import ParamsType
......@@ -23,12 +21,12 @@ from pytensor.misc.safe_asarray import _asarray
from pytensor.raise_op import Assert
from pytensor.scalar import int32 as int_t
from pytensor.scalar import upcast
from pytensor.scalar.basic import Composite
from pytensor.tensor import basic as at
from pytensor.tensor import get_vector_length
from pytensor.tensor.exceptions import NotScalarConstantError
from pytensor.tensor.math import abs as at_abs
from pytensor.tensor.math import all as at_all
from pytensor.tensor.math import all as pt_all
from pytensor.tensor.math import eq as pt_eq
from pytensor.tensor.math import ge, lt, maximum, minimum, prod
from pytensor.tensor.math import sum as at_sum
from pytensor.tensor.subtensor import advanced_inc_subtensor1, set_subtensor
......@@ -536,7 +534,7 @@ def bincount(x, weights=None, minlength=None, assert_nonneg=False):
if assert_nonneg:
assert_op = Assert("Input to bincount has negative values!")
x = assert_op(x, at_all(x >= 0))
x = assert_op(x, pt_all(x >= 0))
max_value = at.cast(x.max() + 1, "int64")
......@@ -1436,6 +1434,13 @@ def ravel_multi_index(multi_index, dims, mode="raise", order="C"):
return RavelMultiIndex(mode=mode, order=order)(*args)
# Shared `Assert` op carrying the error message reported when a broadcast
# compatibility check fails. It is applied as
# `_broadcast_assert(first_length, condition)`, where `condition` checks that
# the remaining dimension lengths are all equal to `first_length`.
_broadcast_assert = Assert(
"Could not broadcast dimensions. Broadcasting is only allowed along "
"axes that have a statically known length 1. Use `specify_shape` to "
"inform PyTensor of a known shape."
)
def broadcast_shape(*arrays, **kwargs) -> Tuple[aes.ScalarVariable, ...]:
"""Compute the shape resulting from broadcasting arrays.
......@@ -1510,119 +1515,45 @@ def broadcast_shape_iter(
result_dims = []
for dim_shapes in zip(*array_shapes):
# Get the shapes in this dimension that are not definitively
# broadcastable (i.e. not symbolically known to be broadcastable)
maybe_non_bcast_shapes = [shape for shape in dim_shapes if shape != one_at]
# Get the shapes in this dimension that are not broadcastable
# (i.e. not symbolically known to be broadcastable)
non_bcast_shapes = [shape for shape in dim_shapes if shape != one_at]
if len(maybe_non_bcast_shapes) == 0:
if len(non_bcast_shapes) == 0:
# Every shape was broadcastable in this dimension
result_dims.append(one_at)
elif len(maybe_non_bcast_shapes) == 1:
elif len(non_bcast_shapes) == 1:
# Only one shape might not be broadcastable in this dimension
result_dims.extend(maybe_non_bcast_shapes)
result_dims.extend(non_bcast_shapes)
else:
# More than one shape might not be broadcastable in this dimension
nonconst_nb_shapes: Set[int] = set()
const_nb_shapes: Set[Variable] = set()
for shape in maybe_non_bcast_shapes:
for shape in non_bcast_shapes:
if isinstance(shape, Constant):
const_nb_shapes.add(shape.value.item())
else:
nonconst_nb_shapes.add(shape)
if len(const_nb_shapes) > 1:
raise ValueError("Could not broadcast dimensions")
elif len(const_nb_shapes) == 1:
(const_nb_shape,) = const_nb_shapes
assert const_nb_shape != 1
const_nt_shape_var = pytensor.scalar.ScalarConstant(
pytensor.scalar.int64, const_nb_shape
)
if len(nonconst_nb_shapes) > 0:
# All the potential non-broadcast shapes need to either
# be broadcastable or equal to the one non-broadcastable
# constant `const_nt_shape_var`.
assert_dim = Assert("Could not broadcast dimensions")
scalar_nonconst_nb_shapes = [
at.scalar_from_tensor(s)
if isinstance(s.type, TensorType)
else s
for s in nonconst_nb_shapes
]
dummy_nonconst_nb_shapes = [
aes.get_scalar_type(dtype=v.dtype)()
for v in scalar_nonconst_nb_shapes
]
assert_cond = reduce(
aes.and_,
(
aes.or_(
aes.eq(nbv, one_at), aes.eq(nbv, const_nt_shape_var)
)
for nbv in dummy_nonconst_nb_shapes
),
raise ValueError(
f"Could not broadcast dimensions. Incompatible shapes were {array_shapes}."
)
assert_cond_op = Composite(dummy_nonconst_nb_shapes, [assert_cond])
bcast_dim = assert_dim(
const_nt_shape_var, assert_cond_op(*scalar_nonconst_nb_shapes)
)
else:
bcast_dim = const_nt_shape_var
if len(const_nb_shapes) == 1:
(first_length,) = const_nb_shapes
other_lengths = nonconst_nb_shapes
first_length = aes.as_scalar(first_length)
else:
# There are no constant, non-broadcastable shapes in this
# dimension.
all_dims_equal = all(
# TODO FIXME: This is a largely deficient, and expensive, means
# of comparing graphs (and especially shapes)
equal_computations([maybe_non_bcast_shapes[0]], [dim])
for dim in maybe_non_bcast_shapes[1:]
)
first_length, *other_lengths = nonconst_nb_shapes
if all_dims_equal:
result_dims.append(maybe_non_bcast_shapes[0])
if len(other_lengths) == 0:
result_dims.append(first_length)
continue
scalar_maybe_non_bcast_shapes = [
at.scalar_from_tensor(s) if isinstance(s.type, TensorType) else s
for s in maybe_non_bcast_shapes
]
dummy_maybe_non_bcast_shapes = [
aes.get_scalar_type(dtype=v.dtype)()
for v in scalar_maybe_non_bcast_shapes
]
non_bcast_vec = [
aes.switch(aes.eq(nbv, 1), -one_at, nbv)
for nbv in dummy_maybe_non_bcast_shapes
]
dim_max = aes.abs(reduce(aes.scalar_maximum, non_bcast_vec))
dim_max_op = Composite(dummy_maybe_non_bcast_shapes, [dim_max])
dummy_dim_max = dim_max_op(*dummy_maybe_non_bcast_shapes)
assert_dim = Assert("Could not broadcast dimensions")
assert_cond = reduce(
aes.and_,
(
aes.or_(aes.eq(nbv, -one_at), aes.eq(nbv, dummy_dim_max))
for nbv in non_bcast_vec
),
)
assert_cond_op = Composite(dummy_maybe_non_bcast_shapes, [assert_cond])
bcast_dim = assert_dim(
dim_max_op(*scalar_maybe_non_bcast_shapes),
assert_cond_op(*scalar_maybe_non_bcast_shapes),
)
result_dims.append(bcast_dim)
# Add assert that all remaining shapes are equal
condition = pt_all([pt_eq(first_length, other) for other in other_lengths])
result_dims.append(_broadcast_assert(first_length, condition))
return tuple(result_dims)
......
......@@ -1703,8 +1703,12 @@ class TestLocalElemwiseAlloc:
],
)
def test_basic(self, expr, x_shape, y_shape):
x = at.tensor(dtype="int64", shape=(None,) * len(x_shape), name="x")
y = at.tensor(dtype="int64", shape=(None,) * len(y_shape), name="y")
x = at.tensor(
dtype="int64", shape=(1 if val == 1 else None for val in x_shape), name="x"
)
y = at.tensor(
dtype="int64", shape=(1 if val == 1 else None for val in y_shape), name="y"
)
z = expr(x, y)
z_opt = pytensor.function(
......@@ -1878,7 +1882,8 @@ class TestLocalElemwiseAlloc:
mode=self.fast_run_mode,
)
self.verify_op_count(func, 0, Alloc)
self.verify_op_count(func, 1, Assert)
# The second assert is from the shape check...
self.verify_op_count(func, 2, Assert)
def test_misc(self):
x = row(dtype=self.dtype)
......
......@@ -608,9 +608,10 @@ class TestAlgebraicCanonizer:
((dv / dy) / dv, [dv, dy], [dvv, dyv], 1, "float64"),
((fv / fy) / fv, [fv, fy], [fvv, fyv], 1, "float32"),
# must broadcast as there is a dimshuffle in the computation
((dx / dv) / dx, [dx, dv], [dxv, dvv], 1, "float64"),
# The broadcast leads to an extra elemwise to check compatibility
((dx / dv) / dx, [dx, dv], [dxv, dvv], 2, "float64"),
# topo: [Shape_i, Shape_i, Elemwise{reciprocal,no_inplace}(<TensorType(float64, row)>), Alloc]
((fx / fv) / fx, [fx, fv], [fxv, fvv], 1, "float32"),
((fx / fv) / fx, [fx, fv], [fxv, fvv], 2, "float32"),
# topo: [Shape_i, Shape_i, Elemwise{reciprocal,no_inplace}(<TensorType(float32, row)>), Alloc]
]
):
......@@ -621,10 +622,13 @@ class TestAlgebraicCanonizer:
elem = [t for t in topo if isinstance(t.op, Elemwise)]
assert len(elem) == nb_elemwise
assert isinstance(elem[0].op, (Elemwise,))
assert isinstance(
elem[0].op.scalar_op,
assert any(
isinstance(
el.op.scalar_op,
(aes.basic.Reciprocal, aes.basic.TrueDiv),
)
for el in elem
)
assert out_dtype == out.dtype
# test (a / b) * (b / c) * (c / d) -> a / d
......
......@@ -1086,6 +1086,8 @@ def test_broadcast_shape_basic():
assert any(
isinstance(node.op, Assert) for node in applys_between([x_at, y_at], b_at)
)
# This should fail because it would need dynamic broadcasting
with pytest.raises(AssertionError):
assert np.array_equal([z.eval() for z in b_at], b.shape)
b_at = broadcast_shape(shape_tuple(x_at), shape_tuple(y_at), arrays_are_shapes=True)
assert np.array_equal([z.eval() for z in b_at], b.shape)
......@@ -1183,8 +1185,8 @@ def test_broadcast_shape_constants():
@pytest.mark.parametrize(
("s1_vals", "s2_vals", "exp_res"),
[
((2, 2), (1, 2), (2, 2)),
((0, 2), (1, 2), (0, 2)),
((2, 2), (1, 2), AssertionError),
((0, 2), (1, 2), AssertionError),
((1, 2, 1), (2, 1, 2, 1), (2, 1, 2, 1)),
],
)
......@@ -1203,6 +1205,10 @@ def test_broadcast_shape_symbolic(s1_vals, s2_vals, exp_res):
res = broadcast_shape(s1s, s2s, arrays_are_shapes=True)
res = at.as_tensor(res)
if exp_res is AssertionError:
with pytest.raises(AssertionError):
res.eval(eval_point)
else:
assert tuple(res.eval(eval_point)) == exp_res
......@@ -1395,7 +1401,7 @@ class TestBroadcastTo(utt.InferShapeTester):
def test_broadcast_arrays():
x, y = at.dvector(), at.dmatrix()
x, y = at.tensor(shape=(1,), dtype="float64"), at.dmatrix()
x_bcast, y_bcast = broadcast_arrays(x, y)
py_mode = Mode("py", None)
......
......@@ -255,7 +255,7 @@ class InferShapeTester:
# Check that the Op is removed from the compiled function.
if check_topo:
topo_shape = shapes_function.maker.fgraph.toposort()
assert not any(isinstance(t.op, cls) for t in topo_shape)
assert not any(t in outputs for t in topo_shape)
topo_out = outputs_function.maker.fgraph.toposort()
assert any(isinstance(t.op, cls) for t in topo_out)
# Check that the shape produced agrees with the actual shape.
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论