提交 a506cbb5 authored 作者: Reyhane Askari's avatar Reyhane Askari

temporary commit

上级 b9c3556f
...@@ -3886,7 +3886,12 @@ class Join(Op): ...@@ -3886,7 +3886,12 @@ class Join(Op):
""" """
check_input = False check_input = False
__props__ = () __props__ = ("view",)
def __init__(self, view=False):
self.view = view
if view:
self.view_map = {0: [0]}
def make_node(self, *axis_and_tensors): def make_node(self, *axis_and_tensors):
""" """
...@@ -3999,6 +4004,10 @@ class Join(Op): ...@@ -3999,6 +4004,10 @@ class Join(Op):
def perform(self, node, axis_and_tensors, out_): def perform(self, node, axis_and_tensors, out_):
out, = out_ out, = out_
axis, tensors = axis_and_tensors[0], axis_and_tensors[1:] axis, tensors = axis_and_tensors[0], axis_and_tensors[1:]
nonempty_tensors = [0 if tensor.size == 0 else 1 for tensor in tensors]
num_nonempty = numpy.sum(nonempty_tensors)
# checking if more than one non-empty tensors are joined.
if num_nonempty > 1:
ndim = tensors[0].ndim ndim = tensors[0].ndim
if axis < -ndim: if axis < -ndim:
raise IndexError("Join axis %d out of bounds [0, %d)" % raise IndexError("Join axis %d out of bounds [0, %d)" %
...@@ -4006,11 +4015,16 @@ class Join(Op): ...@@ -4006,11 +4015,16 @@ class Join(Op):
out[0] = theano._asarray(numpy.concatenate(tensors, axis=axis), out[0] = theano._asarray(numpy.concatenate(tensors, axis=axis),
dtype=node.outputs[0].type.dtype) dtype=node.outputs[0].type.dtype)
elif num_nonempty == 1:
nonempty_tensor_index = numpy.argmax(nonempty_tensors)
self.view_map = {0: [nonempty_tensor_index]}
out[0] = tensors[nonempty_tensor_index]
def c_code_cache_version(self): def c_code_cache_version(self):
return
return (3,) return (3,)
def c_code(self, node, name, inputs, outputs, sub): def c_code_(self, node, name, inputs, outputs, sub):
axis, tensors = inputs[0], inputs[1:] axis, tensors = inputs[0], inputs[1:]
input_1 = tensors[0] input_1 = tensors[0]
l = len(tensors) l = len(tensors)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论