提交 584112f3 authored 作者: David Warde-Farley's avatar David Warde-Farley

Be a bit more liberal in what we accept for a join operation. No longer force…

Be a bit more liberal in what we accept for a join operation. No longer force all broadcastable flags for non-join axes to be the same, and the output's broadcastable flags are the OR of the inputs' flags for each axis.
上级 98a543d9
......@@ -2995,32 +2995,50 @@ class Join(Op):
def _make_node_internal(self, axis, tensors,
as_tensor_variable_args, output_maker):
orig = as_tensor_variable_args
if not all(targs.type.ndim for targs in as_tensor_variable_args):
raise TypeError('Join cannot handle arguments of dimension 0. For joining scalar values, see @stack');
# When the axis may vary, no dimension can be guaranteed to be
# broadcastable.
bcastable = [False] * len(as_tensor_variable_args[0].type.broadcastable)
# When the axis is fixed, the broadcastable dimensions remain, except
# for the axis dimension.
# All concatenated elements must also have the same broadcastable
# dimensions.
orig = as_tensor_variable_args
if isinstance(axis, int):
bcasts = [x.type.broadcastable[0:axis] + \
x.type.broadcastable[axis + 1:] for x in as_tensor_variable_args]
if not all([bcasts[0] == bc for bc in bcasts[1:]]):
raise ValueError('Dimensions other than the given axis must'
' have the same broadcast behavior', tensors)
bcastable[:] = as_tensor_variable_args[0].type.broadcastable
try:
bcastable[axis] = False
except IndexError, e:
raise ValueError('Join argument "axis" is out of range (given input dimensions)')
as_tensor_variable_args = [unbroadcast(x, axis) for x in as_tensor_variable_args]
# Handle single-tensor joins immediately.
if len(as_tensor_variable_args) == 1:
bcastable = list(as_tensor_variable_args[0].type.broadcastable)
else:
as_tensor_variable_args = [unbroadcast(x, *range(x.type.ndim)) for x in as_tensor_variable_args]
# When the axis is fixed, the broadcastable dimensions remain, except
# for the axis dimension.
# All concatenated elements must also have the same broadcastable
# dimensions.
# initialize bcastable all false, and then fill in some trues with
# the loops -- a dimension should be broadcastable if at least one
# of the inputs is broadcastable on that dimension (see
# justification below)
bcastable = [False] * len(as_tensor_variable_args[0].type.broadcastable)
ndim = len(bcastable)
if isinstance(axis, int):
# Basically, broadcastable -> length 1, but the converse does not
# hold. So we permit e.g. T/F/T joins, and if they fail at runtime
# they fail, but if they don't then it means that the argument
# where that broadcastable flag was False had length 1 along this
# dimension, and therefore this dimension should be broadcastable
# for the output.
for x in as_tensor_variable_args:
for current_axis, bflag in enumerate(x.type.broadcastable):
# Not sure if this Op supports/supported/will support
# negative indices, but just to be sure...
if current_axis == axis % ndim:
continue
if bflag:
bcastable[current_axis] = True
try:
bcastable[axis] = False
except IndexError, e:
raise ValueError('Join argument "axis" is out of range (given input dimensions)')
as_tensor_variable_args = [unbroadcast(x, axis) for x in as_tensor_variable_args]
else:
# These unbroadcasts are for the gradient... not sure exactly
# why...
as_tensor_variable_args = [unbroadcast(x, *range(x.type.ndim)) for x in as_tensor_variable_args]
# When the axis may vary, no dimension can be guaranteed to be
# broadcastable.
bcastable = [False] * len(as_tensor_variable_args[0].type.broadcastable)
inputs = [as_tensor_variable(axis)] + as_tensor_variable_args
if inputs[0].type not in int_types:
......
......@@ -1641,6 +1641,70 @@ class T_Join_and_Split(unittest.TestCase):
f = function([x,y], [b,c,a])
assert numpy.allclose(f(4, 5), [5, 9, 4])
def test_broadcastable_flag_assignment_mixed_otheraxes(self):
    """
    A non-join axis of the output must be broadcastable whenever at
    least one input is broadcastable along that axis.
    """
    # Inputs disagree on axis 0 (one broadcastable, one not) and agree
    # on axes 1 and 2; the join runs along axis 1.
    first = TensorType(dtype='int8', broadcastable=[0, 0, 1])()
    second = TensorType(dtype='int8', broadcastable=[1, 0, 1])()
    joined = join(1, first, second)
    flags = joined.type.broadcastable
    # Non-join axes take the OR of the inputs' flags; the join axis
    # itself (axis 1) stays non-broadcastable.
    assert flags[0]
    assert flags[2]
    assert not flags[1]
def test_broadcastable_flag_assignment_mixed_thisaxes(self):
    """
    Joining along an axis on which some (but not all) inputs are
    broadcastable must yield an output that is non-broadcastable on
    that axis.
    """
    long_input = TensorType(dtype='int8', broadcastable=[0, 0, 1])()
    short_input = TensorType(dtype='int8', broadcastable=[1, 0, 1])()
    # Axis 0 is the join axis, so the result cannot be length 1 there.
    result = join(0, long_input, short_input)
    assert not result.type.broadcastable[0]
def test_broadcastable_flags_all_broadcastable_on_joinaxis(self):
    """
    Even when every input is broadcastable (i.e. length 1) along the
    join axis, the concatenated output is non-broadcastable there.
    """
    lhs = TensorType(dtype='int8', broadcastable=[1, 0, 1])()
    rhs = TensorType(dtype='int8', broadcastable=[1, 0, 1])()
    out = join(0, lhs, rhs)
    # Two length-1 pieces concatenate to length 2 along axis 0.
    assert out.type.broadcastable[0] is not True or False
    assert not out.type.broadcastable[0]
def test_broadcastable_single_input_broadcastable_dimension(self):
    """
    A join with a single input must preserve every broadcastable flag
    of that input unchanged.
    """
    only = TensorType(dtype='int8', broadcastable=[1, 0, 1])()
    flags = join(0, only).type.broadcastable
    # Pattern [1, 0, 1] passes straight through the single-input join.
    assert flags[0]
    assert not flags[1]
    assert flags[2]
def test_broadcastable_flags_many_dims_and_inputs(self):
    """
    Broadcastable flags for a join of many inputs with many dimensions:
    the join axis is always non-broadcastable, and every other axis is
    the OR of the inputs' flags on that axis.
    """
    patterns = [
        [1, 0, 1, 0, 0, 0],
        [1, 1, 1, 0, 0, 0],
        [1, 0, 0, 0, 0, 0],
        [1, 0, 1, 1, 0, 1],
        [1, 0, 1, 0, 0, 1],
    ]
    inputs = [TensorType(dtype='int8', broadcastable=p)()
              for p in patterns]
    # (join axis, expected output broadcastable pattern)
    cases = [
        (0, (False, True, True, True, False, True)),
        (1, (True, False, True, True, False, True)),
        (4, (True, True, True, True, False, True)),
    ]
    for axis, expected in cases:
        flags = join(axis, *inputs).type.broadcastable
        assert tuple(flags) == expected
class test_comparison(unittest.TestCase):
def test_gt(self):
x, y = fvector(), fvector()
......
Markdown 格式
0%
您将添加 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论