提交 6052f90f authored 作者: Frédéric Bastien's avatar Frédéric Bastien 提交者: GitHub

Merge pull request #5417 from ReyhaneAskari/4897

make tensor.join() return the input when there is only 1 variable to join.
......@@ -591,7 +591,7 @@ def local_gpua_alloc2(node):
return
if (isinstance(node.op, tensor.Alloc) and
all(c != 'output' and
c.op == tensor.join and
isinstance(c.op, tensor.Join) and
all(i.owner and
i.owner.op in [host_from_gpu, tensor.alloc]
for i in c.inputs[1:])
......
......@@ -378,7 +378,7 @@ def test_gpujoin_gpualloc():
T.ones_like(b)) + 4,
mode=mode_with_gpu)
assert sum([node.op == T.alloc for node in f.maker.fgraph.toposort()]) == 2
assert sum([node.op == T.join for node in f.maker.fgraph.toposort()]) == 1
assert sum([node.op == T.join_ for node in f.maker.fgraph.toposort()]) == 1
assert sum([isinstance(node.op, GpuAlloc)
for node in f_gpu.maker.fgraph.toposort()]) == 2
assert sum([node.op == gpu_join
......
......@@ -2207,7 +2207,7 @@ def local_gpualloc(node):
# if all clients are on gpu
replace = True
elif all([c != 'output' and
c.op == tensor.join and
isinstance(c.op, tensor.Join) and
all(i.owner and
i.owner.op in [host_from_gpu, tensor.alloc]
for i in c.inputs[1:])
......
......@@ -945,7 +945,7 @@ def test_gpujoin_gpualloc():
mode=mode_with_gpu)
assert sum([node.op == T.alloc for node in f.maker.fgraph.toposort()]) == 2
assert sum([node.op == T.join for node in f.maker.fgraph.toposort()]) == 1
assert sum([isinstance(node.op, T.Join) for node in f.maker.fgraph.toposort()]) == 1
assert sum([isinstance(node.op, B.GpuAlloc)
for node in f_gpu.maker.fgraph.toposort()]) == 2
assert sum([node.op == B.gpu_join
......
......@@ -4174,9 +4174,18 @@ class Join(Op):
return [tuple(out_shapes)]
"""
join_ = Join()
pprint.assign(Join, printing.FunctionPrinter('join'))
def join(axis, *tensors_list):
"""
Convenience function to concatenate `TensorType`s along the given axis.
This function will not add the op in the graph when it is not useful.
For example, in the case that the list of tensors to be concatenated
is one, it will just return the tensor.
Parameters
----------
tensors : list of tensors (or list-like)
......@@ -4193,12 +4202,11 @@ class Join(Op):
former case, the axis is fixed at construction, while in the
latter it may vary over time depending on the value of the
`axis` variable.
"""
join = Join()
pprint.assign(Join, printing.FunctionPrinter('join'))
"""
if len(tensors_list) == 1:
return tensors_list[0]
else:
return join_(axis, *tensors_list)
def roll(x, shift, axis=None):
......
......@@ -4372,6 +4372,25 @@ def test_join_inplace():
assert numpy.allclose(f(data, 0), [3, 4, 5])
def test_join_oneInput():
    """Test that joining a single input returns the input itself.

    When `concatenate` is called on a list containing exactly one
    tensor, no Join op should be added to the graph: the tensor is
    returned unchanged, avoiding the computational overhead of a
    one-element concatenation. With more than one input, a real
    Join node must still be created.
    """
    x_0 = theano.tensor.fmatrix()
    x_1 = theano.tensor.fmatrix()
    x_2 = theano.tensor.fvector()
    # Single-element list: concatenate should be a no-op for x_0.
    join_0 = theano.tensor.concatenate([x_0], axis=1)
    # Multi-element list: a genuine concatenation is required.
    join_1 = theano.tensor.concatenate([x_0, x_1, theano.tensor.shape_padright(x_2)],
                                       axis=1)
    # Identity check (`is`, not `==`): the very same variable object
    # must be returned in the single-input case, and must NOT be in
    # the multi-input case.
    assert join_0 is x_0
    assert join_1 is not x_0
class test_comparison(unittest.TestCase):
"""Test <, >, <=, >=, == and !=
......
Markdown 格式
0%
您添加了 0 人到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论