提交 fad99811 authored 作者: Reyhane Askari's avatar Reyhane Askari

minor changes in test_join_oneInput and comments of def join

上级 00a374a7
...@@ -4170,6 +4170,10 @@ def join(axis, *tensors_list): ...@@ -4170,6 +4170,10 @@ def join(axis, *tensors_list):
""" """
Convenience function to concatenate `TensorType`s along the given axis. Convenience function to concatenate `TensorType`s along the given axis.
This function will not add the op in the graph when it is not useful.
For example, in the case that the list of tensors to be concatenated
is one, it will just return the tensor.
Parameters Parameters
---------- ----------
tensors : list of tensors (or list-like) tensors : list of tensors (or list-like)
......
...@@ -4386,13 +4386,9 @@ def test_join_oneInput(): ...@@ -4386,13 +4386,9 @@ def test_join_oneInput():
join_0 = theano.tensor.concatenate([x_0], axis=1) join_0 = theano.tensor.concatenate([x_0], axis=1)
join_1 = theano.tensor.concatenate([x_0, x_1, theano.tensor.shape_padright(x_2)], join_1 = theano.tensor.concatenate([x_0, x_1, theano.tensor.shape_padright(x_2)],
axis=1) axis=1)
f = theano.gof.FunctionGraph([x_0], [join_0])
g = theano.gof.FunctionGraph([x_0, x_1, x_2], [join_1])
assert isinstance(g.toposort()[1].op, Join) assert join_0 is x_0
assert not f.toposort() assert join_1 is not x_0
assert join_0.ndim is 2
assert join_1.ndim is 2
class test_comparison(unittest.TestCase): class test_comparison(unittest.TestCase):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论