提交 5573bd40 authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Remove custom __call__ function for tensor.Split Op

上级 cd15ea08
...@@ -3185,14 +3185,6 @@ class Split(Op): ...@@ -3185,14 +3185,6 @@ class Split(Op):
def __hash__(self): def __hash__(self):
return hash(Split) ^ self.len_splits return hash(Split) ^ self.len_splits
def __call__(self, *inputs, **kwargs):
"""Override Op.__call__ to suppress unpacking of output list
"""
node = self.make_node(*inputs, **kwargs)
node.tag.trace = traceback.extract_stack()[:-1]
return node.outputs
def make_node(self, x, axis, splits): def make_node(self, x, axis, splits):
"""WRITEME""" """WRITEME"""
x = as_tensor_variable(x) x = as_tensor_variable(x)
...@@ -3467,6 +3459,9 @@ class Join(Op): ...@@ -3467,6 +3459,9 @@ class Join(Op):
# assume that this is differentiable # assume that this is differentiable
split = Split(len(tensors)) split = Split(len(tensors))
split_gz = split(gz, axis, stack(*[shape(x)[axis] for x in tensors])) split_gz = split(gz, axis, stack(*[shape(x)[axis] for x in tensors]))
# If there is only one split, it might not be in a list.
if not isinstance(split_gz, list):
split_gz = [split_gz]
return [None] + split_gz return [None] + split_gz
else: else:
# assume that this isn't differentiable # assume that this isn't differentiable
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论