提交 1c07fecd authored 作者: Reyhane Askari's avatar Reyhane Askari

checking for empty tensors

上级 a506cbb5
...@@ -4004,10 +4004,13 @@ class Join(Op): ...@@ -4004,10 +4004,13 @@ class Join(Op):
def perform(self, node, axis_and_tensors, out_): def perform(self, node, axis_and_tensors, out_):
out, = out_ out, = out_
axis, tensors = axis_and_tensors[0], axis_and_tensors[1:] axis, tensors = axis_and_tensors[0], axis_and_tensors[1:]
nonempty_tensors = [0 if tensor.size == 0 else 1 for tensor in tensors] # tailing tensors are all tensors except the first one
num_nonempty = numpy.sum(nonempty_tensors) tailing_tensors_are_empty = numpy.all(
# checking if more than one non-empty tensors are joined. [tensor.shape[axis] == 0 for tensor in axis_and_tensors[2:]])
if num_nonempty > 1: if tailing_tensors_are_empty:
out[0] = tensors[0]
else:
ndim = tensors[0].ndim ndim = tensors[0].ndim
if axis < -ndim: if axis < -ndim:
raise IndexError("Join axis %d out of bounds [0, %d)" % raise IndexError("Join axis %d out of bounds [0, %d)" %
...@@ -4015,10 +4018,6 @@ class Join(Op): ...@@ -4015,10 +4018,6 @@ class Join(Op):
out[0] = theano._asarray(numpy.concatenate(tensors, axis=axis), out[0] = theano._asarray(numpy.concatenate(tensors, axis=axis),
dtype=node.outputs[0].type.dtype) dtype=node.outputs[0].type.dtype)
elif num_nonempty == 1:
nonempty_tensor_index = numpy.argmax(nonempty_tensors)
self.view_map = {0: [nonempty_tensor_index]}
out[0] = tensors[nonempty_tensor_index]
def c_code_cache_version(self): def c_code_cache_version(self):
return return
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论