Unverified 提交 df4183d5 authored 作者: Adrian Seyboldt's avatar Adrian Seyboldt 提交者: GitHub

Use static-only broadcasting rules to compute shape of broadcasting (#345)

上级 b9c4f20d
from collections.abc import Collection from collections.abc import Collection
from functools import reduce
from typing import Iterable, Set, Tuple, Union from typing import Iterable, Set, Tuple, Union
import numpy as np import numpy as np
import numpy.core.numeric
from numpy.core.multiarray import normalize_axis_index from numpy.core.multiarray import normalize_axis_index
import pytensor import pytensor
...@@ -14,7 +12,7 @@ from pytensor.gradient import ( ...@@ -14,7 +12,7 @@ from pytensor.gradient import (
disconnected_type, disconnected_type,
grad_undefined, grad_undefined,
) )
from pytensor.graph.basic import Apply, Constant, Variable, equal_computations from pytensor.graph.basic import Apply, Constant, Variable
from pytensor.graph.op import Op from pytensor.graph.op import Op
from pytensor.link.c.op import COp from pytensor.link.c.op import COp
from pytensor.link.c.params_type import ParamsType from pytensor.link.c.params_type import ParamsType
...@@ -23,12 +21,12 @@ from pytensor.misc.safe_asarray import _asarray ...@@ -23,12 +21,12 @@ from pytensor.misc.safe_asarray import _asarray
from pytensor.raise_op import Assert from pytensor.raise_op import Assert
from pytensor.scalar import int32 as int_t from pytensor.scalar import int32 as int_t
from pytensor.scalar import upcast from pytensor.scalar import upcast
from pytensor.scalar.basic import Composite
from pytensor.tensor import basic as at from pytensor.tensor import basic as at
from pytensor.tensor import get_vector_length from pytensor.tensor import get_vector_length
from pytensor.tensor.exceptions import NotScalarConstantError from pytensor.tensor.exceptions import NotScalarConstantError
from pytensor.tensor.math import abs as at_abs from pytensor.tensor.math import abs as at_abs
from pytensor.tensor.math import all as at_all from pytensor.tensor.math import all as pt_all
from pytensor.tensor.math import eq as pt_eq
from pytensor.tensor.math import ge, lt, maximum, minimum, prod from pytensor.tensor.math import ge, lt, maximum, minimum, prod
from pytensor.tensor.math import sum as at_sum from pytensor.tensor.math import sum as at_sum
from pytensor.tensor.subtensor import advanced_inc_subtensor1, set_subtensor from pytensor.tensor.subtensor import advanced_inc_subtensor1, set_subtensor
...@@ -536,7 +534,7 @@ def bincount(x, weights=None, minlength=None, assert_nonneg=False): ...@@ -536,7 +534,7 @@ def bincount(x, weights=None, minlength=None, assert_nonneg=False):
if assert_nonneg: if assert_nonneg:
assert_op = Assert("Input to bincount has negative values!") assert_op = Assert("Input to bincount has negative values!")
x = assert_op(x, at_all(x >= 0)) x = assert_op(x, pt_all(x >= 0))
max_value = at.cast(x.max() + 1, "int64") max_value = at.cast(x.max() + 1, "int64")
...@@ -1436,6 +1434,13 @@ def ravel_multi_index(multi_index, dims, mode="raise", order="C"): ...@@ -1436,6 +1434,13 @@ def ravel_multi_index(multi_index, dims, mode="raise", order="C"):
return RavelMultiIndex(mode=mode, order=order)(*args) return RavelMultiIndex(mode=mode, order=order)(*args)
_broadcast_assert = Assert(
"Could not broadcast dimensions. Broadcasting is only allowed along "
"axes that have a statically known length 1. Use `specify_shape` to "
"inform PyTensor of a known shape."
)
def broadcast_shape(*arrays, **kwargs) -> Tuple[aes.ScalarVariable, ...]: def broadcast_shape(*arrays, **kwargs) -> Tuple[aes.ScalarVariable, ...]:
"""Compute the shape resulting from broadcasting arrays. """Compute the shape resulting from broadcasting arrays.
...@@ -1510,119 +1515,45 @@ def broadcast_shape_iter( ...@@ -1510,119 +1515,45 @@ def broadcast_shape_iter(
result_dims = [] result_dims = []
for dim_shapes in zip(*array_shapes): for dim_shapes in zip(*array_shapes):
# Get the shapes in this dimension that are not definitively # Get the shapes in this dimension that are not broadcastable
# broadcastable (i.e. not symbolically known to be broadcastable) # (i.e. not symbolically known to be broadcastable)
maybe_non_bcast_shapes = [shape for shape in dim_shapes if shape != one_at] non_bcast_shapes = [shape for shape in dim_shapes if shape != one_at]
if len(maybe_non_bcast_shapes) == 0: if len(non_bcast_shapes) == 0:
# Every shape was broadcastable in this dimension # Every shape was broadcastable in this dimension
result_dims.append(one_at) result_dims.append(one_at)
elif len(maybe_non_bcast_shapes) == 1: elif len(non_bcast_shapes) == 1:
# Only one shape might not be broadcastable in this dimension # Only one shape might not be broadcastable in this dimension
result_dims.extend(maybe_non_bcast_shapes) result_dims.extend(non_bcast_shapes)
else: else:
# More than one shape might not be broadcastable in this dimension # More than one shape might not be broadcastable in this dimension
nonconst_nb_shapes: Set[int] = set() nonconst_nb_shapes: Set[int] = set()
const_nb_shapes: Set[Variable] = set() const_nb_shapes: Set[Variable] = set()
for shape in maybe_non_bcast_shapes: for shape in non_bcast_shapes:
if isinstance(shape, Constant): if isinstance(shape, Constant):
const_nb_shapes.add(shape.value.item()) const_nb_shapes.add(shape.value.item())
else: else:
nonconst_nb_shapes.add(shape) nonconst_nb_shapes.add(shape)
if len(const_nb_shapes) > 1: if len(const_nb_shapes) > 1:
raise ValueError("Could not broadcast dimensions") raise ValueError(
elif len(const_nb_shapes) == 1: f"Could not broadcast dimensions. Incompatible shapes were {array_shapes}."
(const_nb_shape,) = const_nb_shapes
assert const_nb_shape != 1
const_nt_shape_var = pytensor.scalar.ScalarConstant(
pytensor.scalar.int64, const_nb_shape
) )
if len(nonconst_nb_shapes) > 0: if len(const_nb_shapes) == 1:
# All the potential non-broadcast shapes need to either (first_length,) = const_nb_shapes
# be broadcastable or equal to the one non-broadcastable other_lengths = nonconst_nb_shapes
# constant `const_nt_shape_var`. first_length = aes.as_scalar(first_length)
assert_dim = Assert("Could not broadcast dimensions")
scalar_nonconst_nb_shapes = [
at.scalar_from_tensor(s)
if isinstance(s.type, TensorType)
else s
for s in nonconst_nb_shapes
]
dummy_nonconst_nb_shapes = [
aes.get_scalar_type(dtype=v.dtype)()
for v in scalar_nonconst_nb_shapes
]
assert_cond = reduce(
aes.and_,
(
aes.or_(
aes.eq(nbv, one_at), aes.eq(nbv, const_nt_shape_var)
)
for nbv in dummy_nonconst_nb_shapes
),
)
assert_cond_op = Composite(dummy_nonconst_nb_shapes, [assert_cond])
bcast_dim = assert_dim(
const_nt_shape_var, assert_cond_op(*scalar_nonconst_nb_shapes)
)
else:
bcast_dim = const_nt_shape_var
else: else:
# There are no constant, non-broadcastable shapes in this first_length, *other_lengths = nonconst_nb_shapes
# dimension.
all_dims_equal = all(
# TODO FIXME: This is a largely deficient, and expensive, means
# of comparing graphs (and especially shapes)
equal_computations([maybe_non_bcast_shapes[0]], [dim])
for dim in maybe_non_bcast_shapes[1:]
)
if all_dims_equal: if len(other_lengths) == 0:
result_dims.append(maybe_non_bcast_shapes[0]) result_dims.append(first_length)
continue continue
scalar_maybe_non_bcast_shapes = [
at.scalar_from_tensor(s) if isinstance(s.type, TensorType) else s
for s in maybe_non_bcast_shapes
]
dummy_maybe_non_bcast_shapes = [
aes.get_scalar_type(dtype=v.dtype)()
for v in scalar_maybe_non_bcast_shapes
]
non_bcast_vec = [
aes.switch(aes.eq(nbv, 1), -one_at, nbv)
for nbv in dummy_maybe_non_bcast_shapes
]
dim_max = aes.abs(reduce(aes.scalar_maximum, non_bcast_vec))
dim_max_op = Composite(dummy_maybe_non_bcast_shapes, [dim_max])
dummy_dim_max = dim_max_op(*dummy_maybe_non_bcast_shapes)
assert_dim = Assert("Could not broadcast dimensions")
assert_cond = reduce(
aes.and_,
(
aes.or_(aes.eq(nbv, -one_at), aes.eq(nbv, dummy_dim_max))
for nbv in non_bcast_vec
),
)
assert_cond_op = Composite(dummy_maybe_non_bcast_shapes, [assert_cond])
bcast_dim = assert_dim(
dim_max_op(*scalar_maybe_non_bcast_shapes),
assert_cond_op(*scalar_maybe_non_bcast_shapes),
)
result_dims.append(bcast_dim) # Add assert that all remaining shapes are equal
condition = pt_all([pt_eq(first_length, other) for other in other_lengths])
result_dims.append(_broadcast_assert(first_length, condition))
return tuple(result_dims) return tuple(result_dims)
......
...@@ -1703,8 +1703,12 @@ class TestLocalElemwiseAlloc: ...@@ -1703,8 +1703,12 @@ class TestLocalElemwiseAlloc:
], ],
) )
def test_basic(self, expr, x_shape, y_shape): def test_basic(self, expr, x_shape, y_shape):
x = at.tensor(dtype="int64", shape=(None,) * len(x_shape), name="x") x = at.tensor(
y = at.tensor(dtype="int64", shape=(None,) * len(y_shape), name="y") dtype="int64", shape=(1 if val == 1 else None for val in x_shape), name="x"
)
y = at.tensor(
dtype="int64", shape=(1 if val == 1 else None for val in y_shape), name="y"
)
z = expr(x, y) z = expr(x, y)
z_opt = pytensor.function( z_opt = pytensor.function(
...@@ -1878,7 +1882,8 @@ class TestLocalElemwiseAlloc: ...@@ -1878,7 +1882,8 @@ class TestLocalElemwiseAlloc:
mode=self.fast_run_mode, mode=self.fast_run_mode,
) )
self.verify_op_count(func, 0, Alloc) self.verify_op_count(func, 0, Alloc)
self.verify_op_count(func, 1, Assert) # The second assert is from the shape check...
self.verify_op_count(func, 2, Assert)
def test_misc(self): def test_misc(self):
x = row(dtype=self.dtype) x = row(dtype=self.dtype)
......
...@@ -608,9 +608,10 @@ class TestAlgebraicCanonizer: ...@@ -608,9 +608,10 @@ class TestAlgebraicCanonizer:
((dv / dy) / dv, [dv, dy], [dvv, dyv], 1, "float64"), ((dv / dy) / dv, [dv, dy], [dvv, dyv], 1, "float64"),
((fv / fy) / fv, [fv, fy], [fvv, fyv], 1, "float32"), ((fv / fy) / fv, [fv, fy], [fvv, fyv], 1, "float32"),
# must broadcast as their is a dimshuffle in the computation # must broadcast as their is a dimshuffle in the computation
((dx / dv) / dx, [dx, dv], [dxv, dvv], 1, "float64"), # The broadcast leads to an extra elemwise to check compatibility
((dx / dv) / dx, [dx, dv], [dxv, dvv], 2, "float64"),
# topo: [Shape_i, Shape_i, Elemwise{reciprocal,no_inplace}(<TensorType(float64, row)>), Alloc] # topo: [Shape_i, Shape_i, Elemwise{reciprocal,no_inplace}(<TensorType(float64, row)>), Alloc]
((fx / fv) / fx, [fx, fv], [fxv, fvv], 1, "float32"), ((fx / fv) / fx, [fx, fv], [fxv, fvv], 2, "float32"),
# topo: [Shape_i, Shape_i, Elemwise{reciprocal,no_inplace}(<TensorType(float32, row)>), Alloc] # topo: [Shape_i, Shape_i, Elemwise{reciprocal,no_inplace}(<TensorType(float32, row)>), Alloc]
] ]
): ):
...@@ -621,9 +622,12 @@ class TestAlgebraicCanonizer: ...@@ -621,9 +622,12 @@ class TestAlgebraicCanonizer:
elem = [t for t in topo if isinstance(t.op, Elemwise)] elem = [t for t in topo if isinstance(t.op, Elemwise)]
assert len(elem) == nb_elemwise assert len(elem) == nb_elemwise
assert isinstance(elem[0].op, (Elemwise,)) assert isinstance(elem[0].op, (Elemwise,))
assert isinstance( assert any(
elem[0].op.scalar_op, isinstance(
(aes.basic.Reciprocal, aes.basic.TrueDiv), el.op.scalar_op,
(aes.basic.Reciprocal, aes.basic.TrueDiv),
)
for el in elem
) )
assert out_dtype == out.dtype assert out_dtype == out.dtype
......
...@@ -1086,7 +1086,9 @@ def test_broadcast_shape_basic(): ...@@ -1086,7 +1086,9 @@ def test_broadcast_shape_basic():
assert any( assert any(
isinstance(node.op, Assert) for node in applys_between([x_at, y_at], b_at) isinstance(node.op, Assert) for node in applys_between([x_at, y_at], b_at)
) )
assert np.array_equal([z.eval() for z in b_at], b.shape) # This should fail because it would need dynamic broadcasting
with pytest.raises(AssertionError):
assert np.array_equal([z.eval() for z in b_at], b.shape)
b_at = broadcast_shape(shape_tuple(x_at), shape_tuple(y_at), arrays_are_shapes=True) b_at = broadcast_shape(shape_tuple(x_at), shape_tuple(y_at), arrays_are_shapes=True)
assert np.array_equal([z.eval() for z in b_at], b.shape) assert np.array_equal([z.eval() for z in b_at], b.shape)
...@@ -1183,8 +1185,8 @@ def test_broadcast_shape_constants(): ...@@ -1183,8 +1185,8 @@ def test_broadcast_shape_constants():
@pytest.mark.parametrize( @pytest.mark.parametrize(
("s1_vals", "s2_vals", "exp_res"), ("s1_vals", "s2_vals", "exp_res"),
[ [
((2, 2), (1, 2), (2, 2)), ((2, 2), (1, 2), AssertionError),
((0, 2), (1, 2), (0, 2)), ((0, 2), (1, 2), AssertionError),
((1, 2, 1), (2, 1, 2, 1), (2, 1, 2, 1)), ((1, 2, 1), (2, 1, 2, 1), (2, 1, 2, 1)),
], ],
) )
...@@ -1203,7 +1205,11 @@ def test_broadcast_shape_symbolic(s1_vals, s2_vals, exp_res): ...@@ -1203,7 +1205,11 @@ def test_broadcast_shape_symbolic(s1_vals, s2_vals, exp_res):
res = broadcast_shape(s1s, s2s, arrays_are_shapes=True) res = broadcast_shape(s1s, s2s, arrays_are_shapes=True)
res = at.as_tensor(res) res = at.as_tensor(res)
assert tuple(res.eval(eval_point)) == exp_res if exp_res is AssertionError:
with pytest.raises(AssertionError):
res.eval(eval_point)
else:
assert tuple(res.eval(eval_point)) == exp_res
def test_broadcast_shape_symbolic_one_symbolic(): def test_broadcast_shape_symbolic_one_symbolic():
...@@ -1395,7 +1401,7 @@ class TestBroadcastTo(utt.InferShapeTester): ...@@ -1395,7 +1401,7 @@ class TestBroadcastTo(utt.InferShapeTester):
def test_broadcast_arrays(): def test_broadcast_arrays():
x, y = at.dvector(), at.dmatrix() x, y = at.tensor(shape=(1,), dtype="float64"), at.dmatrix()
x_bcast, y_bcast = broadcast_arrays(x, y) x_bcast, y_bcast = broadcast_arrays(x, y)
py_mode = Mode("py", None) py_mode = Mode("py", None)
......
...@@ -255,7 +255,7 @@ class InferShapeTester: ...@@ -255,7 +255,7 @@ class InferShapeTester:
# Check that the Op is removed from the compiled function. # Check that the Op is removed from the compiled function.
if check_topo: if check_topo:
topo_shape = shapes_function.maker.fgraph.toposort() topo_shape = shapes_function.maker.fgraph.toposort()
assert not any(isinstance(t.op, cls) for t in topo_shape) assert not any(t in outputs for t in topo_shape)
topo_out = outputs_function.maker.fgraph.toposort() topo_out = outputs_function.maker.fgraph.toposort()
assert any(isinstance(t.op, cls) for t in topo_out) assert any(isinstance(t.op, cls) for t in topo_out)
# Check that the shape produced agrees with the actual shape. # Check that the shape produced agrees with the actual shape.
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论