提交 4e73c5ea authored 作者: Reyhane Askari's avatar Reyhane Askari

added join function to avoid concat with only one input

上级 c0b24762
...@@ -4162,7 +4162,11 @@ class Join(Op): ...@@ -4162,7 +4162,11 @@ class Join(Op):
return [tuple(out_shapes)] return [tuple(out_shapes)]
""" join_ = Join()
def join(axis, *tensors_list):
"""
Convenience function to concatenate `TensorType`s along the given axis. Convenience function to concatenate `TensorType`s along the given axis.
Parameters Parameters
...@@ -4181,12 +4185,11 @@ class Join(Op): ...@@ -4181,12 +4185,11 @@ class Join(Op):
former case, the axis is fixed at construction, while in the former case, the axis is fixed at construction, while in the
latter it may vary over time depending on the value of the latter it may vary over time depending on the value of the
`axis` variable. `axis` variable.
"""
""" if len(tensors_list) == 1:
return tensors_list[0]
join = Join() else:
return join_(axis, *tensors_list)
pprint.assign(Join, printing.FunctionPrinter('join'))
def roll(x, shift, axis=None): def roll(x, shift, axis=None):
......
...@@ -4372,6 +4372,29 @@ def test_join_inplace(): ...@@ -4372,6 +4372,29 @@ def test_join_inplace():
assert numpy.allclose(f(data, 0), [3, 4, 5]) assert numpy.allclose(f(data, 0), [3, 4, 5])
def test_join_oneInput():
"""Test join when only 1 input is given.
This functions tests the case when concatenate is called
on an array of tensors but the array has only one element.
In this case, we would like to avoid the computational
overhead of concatenation of one element.
"""
x_0 = theano.tensor.fmatrix()
x_1 = theano.tensor.fmatrix()
x_2 = theano.tensor.fvector()
join_0 = theano.tensor.concatenate([x_0], axis=1)
join_1 = theano.tensor.concatenate([x_0, x_1, theano.tensor.shape_padright(x_2)],
axis=1)
f = theano.gof.FunctionGraph([x_0], [join_0])
g = theano.gof.FunctionGraph([x_0, x_1, x_2], [join_1])
assert isinstance(g.toposort()[1].op, Join)
assert not f.toposort()
assert join_0.ndim is 2
assert join_1.ndim is 2
class test_comparison(unittest.TestCase): class test_comparison(unittest.TestCase):
"""Test <, >, <=, >=, == and != """Test <, >, <=, >=, == and !=
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论