Commit b210efbc authored by Brandon T. Willard

Add a broadcast_shape function

Parent adfeaaf4
...@@ -31,6 +31,7 @@ from theano.tensor.extra_ops import (
    UnravelIndex,
    ravel_multi_index,
    RavelMultiIndex,
    broadcast_shape,
)
from theano import tensor as tt
from theano import config, function
...@@ -1189,3 +1190,121 @@ class TestRavelMultiIndex(utt.InferShapeTester):
        # dims must be a 1D sequence
        with pytest.raises(TypeError):
            ravel_multi_index(((3, 4),), ((3, 4),))
def test_broadcast_shape():
    """Check `broadcast_shape` against NumPy's broadcasting behavior on
    constant, partially symbolic, and fully symbolic shape combinations.
    """

    def shape_tuple(x, use_bcast=True):
        # Return `x`'s shape as a tuple of scalar variables; when
        # `use_bcast` is true, broadcastable dimensions are replaced by
        # the Python constant 1 so `broadcast_shape` can identify them.
        if use_bcast:
            return tuple(
                s if not bcast else 1
                for s, bcast in zip(tuple(x.shape), x.broadcastable)
            )
        else:
            return tuple(x.shape)

    x = np.array([[1], [2], [3]])
    y = np.array([4, 5, 6])
    b = np.broadcast(x, y)
    x_tt = tt.as_tensor_variable(x)
    y_tt = tt.as_tensor_variable(y)
    b_tt = broadcast_shape(x_tt, y_tt)
    assert np.array_equal([z.eval() for z in b_tt], b.shape)

    # Now, we try again using shapes as the inputs
    #
    # This case also confirms that a broadcast dimension will
    # broadcast against a non-broadcast dimension when they're
    # both symbolic (i.e. we couldn't obtain constant values).
    b_tt = broadcast_shape(
        shape_tuple(x_tt, use_bcast=False),
        shape_tuple(y_tt, use_bcast=False),
        arrays_are_shapes=True,
    )
    # Without constant shape information, dimension equality can only be
    # verified at evaluation time, so an `Assert` must be in the graph.
    assert any(
        isinstance(node.op, tt.opt.Assert)
        for node in tt.gof.graph.ops([x_tt, y_tt], b_tt)
    )
    assert np.array_equal([z.eval() for z in b_tt], b.shape)

    b_tt = broadcast_shape(
        shape_tuple(x_tt), shape_tuple(y_tt), arrays_are_shapes=True
    )
    assert np.array_equal([z.eval() for z in b_tt], b.shape)
    # These are all constants, so there shouldn't be any asserts in the
    # resulting graph.
    assert not any(
        isinstance(node.op, tt.opt.Assert)
        for node in tt.gof.graph.ops([x_tt, y_tt], b_tt)
    )

    x = np.array([1, 2, 3])
    y = np.array([4, 5, 6])
    b = np.broadcast(x, y)
    x_tt = tt.as_tensor_variable(x)
    y_tt = tt.as_tensor_variable(y)
    b_tt = broadcast_shape(x_tt, y_tt)
    assert np.array_equal([z.eval() for z in b_tt], b.shape)
    b_tt = broadcast_shape(
        shape_tuple(x_tt), shape_tuple(y_tt), arrays_are_shapes=True
    )
    assert np.array_equal([z.eval() for z in b_tt], b.shape)
    # TODO: This will work when/if we use a more sophisticated
    # `is_same_graph` implementation.
    # assert not any(
    #     isinstance(node.op, tt.opt.Assert)
    #     for node in tt.gof.graph.ops([x_tt, y_tt], b_tt)
    # )

    x = np.empty((1, 2, 3))
    y = np.array(1)
    b = np.broadcast(x, y)
    x_tt = tt.as_tensor_variable(x)
    y_tt = tt.as_tensor_variable(y)
    b_tt = broadcast_shape(x_tt, y_tt)
    # The known-broadcastable leading dimension must come back as the
    # constant 1.
    assert b_tt[0].value == 1
    assert np.array_equal([z.eval() for z in b_tt], b.shape)
    assert not any(
        isinstance(node.op, tt.opt.Assert)
        for node in tt.gof.graph.ops([x_tt, y_tt], b_tt)
    )
    b_tt = broadcast_shape(
        shape_tuple(x_tt), shape_tuple(y_tt), arrays_are_shapes=True
    )
    assert np.array_equal([z.eval() for z in b_tt], b.shape)

    x = np.empty((2, 1, 3))
    y = np.empty((2, 1, 1))
    b = np.broadcast(x, y)
    x_tt = tt.as_tensor_variable(x)
    y_tt = tt.as_tensor_variable(y)
    b_tt = broadcast_shape(x_tt, y_tt)
    assert b_tt[1].value == 1
    assert np.array_equal([z.eval() for z in b_tt], b.shape)
    # TODO: This will work when/if we use a more sophisticated
    # `is_same_graph` implementation.
    # assert not any(
    #     isinstance(node.op, tt.opt.Assert)
    #     for node in tt.gof.graph.ops([x_tt, y_tt], b_tt)
    # )
    b_tt = broadcast_shape(
        shape_tuple(x_tt), shape_tuple(y_tt), arrays_are_shapes=True
    )
    assert np.array_equal([z.eval() for z in b_tt], b.shape)

    # Fully symbolic dimensions: the symbolic sizes cannot be compared at
    # graph-construction time, so the result is checked by evaluation.
    x1_shp_tt = tt.iscalar("x1")
    x2_shp_tt = tt.iscalar("x2")
    y1_shp_tt = tt.iscalar("y1")
    x_shapes = (1, x1_shp_tt, x2_shp_tt)
    x_tt = tt.ones(x_shapes)
    y_shapes = (y1_shp_tt, 1, x2_shp_tt)
    y_tt = tt.ones(y_shapes)
    b_tt = broadcast_shape(x_tt, y_tt)
    # TODO: This will work when/if we use a more sophisticated
    # `is_same_graph` implementation.
    # assert not any(
    #     isinstance(node.op, tt.opt.Assert)
    #     for node in tt.gof.graph.ops([x_tt, y_tt], b_tt)
    # )
    res = tt.as_tensor(b_tt).eval(
        {
            x1_shp_tt: 10,
            x2_shp_tt: 4,
            y1_shp_tt: 2,
        }
    )
    assert np.array_equal(res, (2, 10, 4))

    # Two distinct symbolic dimensions meet in the last position; their
    # equality can only be asserted at evaluation time.
    y_shapes = (y1_shp_tt, 1, y1_shp_tt)
    y_tt = tt.ones(y_shapes)
    b_tt = broadcast_shape(x_tt, y_tt)
    assert isinstance(b_tt[-1].owner.op, tt.opt.Assert)
...@@ -1455,3 +1455,80 @@ def ravel_multi_index(multi_index, dims, mode="raise", order="C"):
        raise TypeError("multi_index must be a tuple or a list.")
    args = tuple(multi_index) + (dims,)
    return RavelMultiIndex(mode=mode, order=order)(*args)
def broadcast_shape(*arrays, arrays_are_shapes=False):
    """Compute the shape resulting from broadcasting arrays.

    Parameters
    ----------
    *arrays: Tuple[TensorVariable] or Tuple[Tuple[Variable]]
        A tuple of tensors, or a tuple of shapes (as tuples),
        for which the broadcast shape is computed.
    arrays_are_shapes: bool (Optional)
        Indicates whether or not the `arrays` contains shape tuples.
        If you use this approach, make sure that the broadcastable
        dimensions are (scalar) constants with the value ``1`` exactly.

    Returns
    -------
    Tuple[Variable]
        One scalar variable (or constant) per broadcasted dimension.
    """
    # Imported here (lazily) rather than at module level to avoid a
    # circular import with `theano.tensor.opt`.
    from theano.tensor.opt import Assert

    # The canonical representation of a broadcastable (length-1) dimension;
    # identity/equality with this constant marks a dimension as broadcastable.
    one = theano.scalar.ScalarConstant(theano.scalar.int64, 1)

    if arrays_are_shapes:
        max_dims = max(len(a) for a in arrays)

        # Left-pad shorter shapes with `one` (NumPy broadcasting aligns
        # shapes on the right) and normalize constant 1s to the canonical
        # `one` so they compare equal below.
        array_shapes = [
            (one,) * (max_dims - len(a))
            + tuple(one if getattr(sh, "value", sh) == 1 else sh for sh in a)
            for a in arrays
        ]
    else:
        max_dims = max(a.ndim for a in arrays)

        array_shapes = [
            (one,) * (max_dims - a.ndim)
            + tuple(
                one if bcast else sh
                for sh, bcast in zip(a.shape, a.broadcastable)
            )
            for a in arrays
        ]

    result_dims = []

    for dim_shapes in zip(*array_shapes):
        non_bcast_shapes = [shape for shape in dim_shapes if shape != one]

        if len(non_bcast_shapes) > 0:
            # Either there's only one non-broadcastable dimension--and that's
            # what determines the dimension size, or there are multiple
            # non-broadcastable dimensions that must be equal.
            i_dim = non_bcast_shapes.pop()

            potentially_unequal_dims = [
                dim
                for dim in non_bcast_shapes
                # TODO FIXME: This is a largely deficient means of comparing
                # graphs (and especially shapes)
                if not theano.gof.graph.equal_computations([i_dim], [dim])
            ]

            if potentially_unequal_dims:
                # In this case, we can't tell whether or not the dimensions
                # are equal, so we'll need to assert their equality and move
                # the error handling to evaluation time.
                assert_dim = Assert("Could not broadcast dimensions")
                eq_condition = basic.all(
                    [
                        basic.or_(basic.eq(dim, one), basic.eq(i_dim, dim))
                        for dim in potentially_unequal_dims
                    ]
                )
                eq_condition = basic.or_(basic.eq(i_dim, one), eq_condition)
                result_dims.append(assert_dim(i_dim, eq_condition))
            else:
                result_dims.append(i_dim)
        else:
            # Every array was broadcastable in this dimension.
            result_dims.append(one)

    return tuple(result_dims)
Markdown formatting supported
0%
You are adding 0 people to this discussion. Please proceed with caution.
Please finish editing this comment first!
Register or sign in to post a comment