提交 1c07fecd authored 作者: Reyhane Askari's avatar Reyhane Askari

checking for empty tensors

上级 a506cbb5
......@@ -4004,10 +4004,13 @@ class Join(Op):
def perform(self, node, axis_and_tensors, out_):
out, = out_
axis, tensors = axis_and_tensors[0], axis_and_tensors[1:]
nonempty_tensors = [0 if tensor.size == 0 else 1 for tensor in tensors]
num_nonempty = numpy.sum(nonempty_tensors)
# checking if more than one non-empty tensors are joined.
if num_nonempty > 1:
# tailing tensors are all tensors except the first one
tailing_tensors_are_empty = numpy.all(
[tensor.shape[axis] == 0 for tensor in axis_and_tensors[2:]])
if tailing_tensors_are_empty:
out[0] = tensors[0]
else:
ndim = tensors[0].ndim
if axis < -ndim:
raise IndexError("Join axis %d out of bounds [0, %d)" %
......@@ -4015,10 +4018,6 @@ class Join(Op):
out[0] = theano._asarray(numpy.concatenate(tensors, axis=axis),
dtype=node.outputs[0].type.dtype)
elif num_nonempty == 1:
nonempty_tensor_index = numpy.argmax(nonempty_tensors)
self.view_map = {0: [nonempty_tensor_index]}
out[0] = tensors[nonempty_tensor_index]
def c_code_cache_version(self):
return
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论