提交 2d8ea781 authored 作者: Michael Osthege's avatar Michael Osthege 提交者: Thomas Wiecki

Cleanup `Join.make_node` to infer static shapes

Closes #163
上级 9e4c0e48
...@@ -2217,8 +2217,6 @@ class Join(COp): ...@@ -2217,8 +2217,6 @@ class Join(COp):
# except for the axis dimension. # except for the axis dimension.
# Initialize bcastable all false, and then fill in some trues with # Initialize bcastable all false, and then fill in some trues with
# the loops. # the loops.
ndim = tensors[0].type.ndim
out_shape = [None] * ndim
if not isinstance(axis, int): if not isinstance(axis, int):
try: try:
...@@ -2226,6 +2224,7 @@ class Join(COp): ...@@ -2226,6 +2224,7 @@ class Join(COp):
except NotScalarConstantError: except NotScalarConstantError:
pass pass
ndim = tensors[0].type.ndim
if isinstance(axis, int): if isinstance(axis, int):
# Basically, broadcastable -> length 1, but the # Basically, broadcastable -> length 1, but the
# converse does not hold. So we permit e.g. T/F/T # converse does not hold. So we permit e.g. T/F/T
...@@ -2241,30 +2240,55 @@ class Join(COp): ...@@ -2241,30 +2240,55 @@ class Join(COp):
) )
if axis < 0: if axis < 0:
axis += ndim axis += ndim
if axis > ndim - 1:
for x in tensors:
for current_axis, s in enumerate(x.type.shape):
# Constant negative axis can no longer be negative at
# this point. It safe to compare this way.
if current_axis == axis:
continue
if s == 1:
out_shape[current_axis] = 1
try:
out_shape[axis] = None
except IndexError:
raise ValueError( raise ValueError(
f"Axis value {axis} is out of range for the given input dimensions" f"Axis value {axis} is out of range for the given input dimensions"
) )
# NOTE: Constant negative axis can no longer be negative at this point.
in_shapes = [x.type.shape for x in tensors]
in_ndims = [len(s) for s in in_shapes]
if set(in_ndims) != {ndim}:
raise TypeError(
"Only tensors with the same number of dimensions can be joined."
f" Input ndims were: {in_ndims}."
)
# Determine output shapes from a matrix of input shapes
in_shapes = np.array(in_shapes)
out_shape = [None] * ndim
for d in range(ndim):
ins = in_shapes[:, d]
if d == axis:
# Any unknown size along the axis means we can't sum
if None in ins:
out_shape[d] = None
else:
out_shape[d] = sum(ins)
else:
inset = set(in_shapes[:, d])
# Other dims must match exactly,
# or if a mix of None and ? the output will be ?
# otherwise the input shapes are incompatible.
if len(inset) == 1:
(out_shape[d],) = inset
elif len(inset - {None}) == 1:
(out_shape[d],) = inset - {None}
else:
raise ValueError(
f"all input array dimensions other than the specified `axis` ({axis})"
" must match exactly, or be unknown (None),"
f" but along dimension {d}, the inputs shapes are incompatible: {ins}"
)
else: else:
# When the axis may vary, no dimension can be guaranteed to be # When the axis may vary, no dimension can be guaranteed to be
# broadcastable. # broadcastable.
out_shape = [None] * tensors[0].type.ndim out_shape = [None] * tensors[0].type.ndim
if not builtins.all(x.ndim == len(out_shape) for x in tensors): if not builtins.all(x.ndim == len(out_shape) for x in tensors):
raise TypeError( raise TypeError(
"Only tensors with the same number of dimensions can be joined" "Only tensors with the same number of dimensions can be joined"
) )
inputs = [as_tensor_variable(axis)] + list(tensors) inputs = [as_tensor_variable(axis)] + list(tensors)
......
...@@ -1909,6 +1909,21 @@ class TestJoinAndSplit: ...@@ -1909,6 +1909,21 @@ class TestJoinAndSplit:
with pytest.raises(TypeError, match="same number of dimensions"): with pytest.raises(TypeError, match="same number of dimensions"):
self.join_op(0, v, m) self.join_op(0, v, m)
def test_static_shape_inference(self):
a = at.tensor(dtype="int8", shape=(2, 3))
b = at.tensor(dtype="int8", shape=(2, 5))
assert at.join(1, a, b).type.shape == (2, 8)
assert at.join(-1, a, b).type.shape == (2, 8)
# Check early informative errors from static shape info
with pytest.raises(ValueError, match="must match exactly"):
at.join(0, at.ones((2, 3)), at.ones((2, 5)))
# Check partial inference
d = at.tensor(dtype="int8", shape=(2, None))
assert at.join(1, a, b, d).type.shape == (2, None)
return
def test_split_0elem(self): def test_split_0elem(self):
rng = np.random.default_rng(seed=utt.fetch_seed()) rng = np.random.default_rng(seed=utt.fetch_seed())
m = self.shared(rng.random((4, 6)).astype(self.floatX)) m = self.shared(rng.random((4, 6)).astype(self.floatX))
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论