Unverified 提交 602eb04c authored 作者: Jesse Grabowski's avatar Jesse Grabowski 提交者: GitHub

Move shared HStack and VStack methods to Stack class (#1662)

上级 f772066a
......@@ -2794,7 +2794,7 @@ le = __ComparisonSwitch(less_equal_s_s, less_equal_s_d, greater_equal_s_d)
ge = __ComparisonSwitch(greater_equal_s_s, greater_equal_s_d, less_equal_s_d)
class HStack(Op):
class Stack(Op):
__props__ = ("format", "dtype")
def __init__(self, format=None, dtype=None):
......@@ -2819,6 +2819,11 @@ class HStack(Op):
self, var, [SparseTensorType(dtype=self.dtype, format=self.format)()]
)
def __str__(self):
return f"{self.__class__.__name__}({self.format},{self.dtype})"
class HStack(Stack):
def perform(self, node, block, outputs):
(out,) = outputs
for b in block:
......@@ -2853,15 +2858,9 @@ class HStack(Op):
return [choose(c, d) for c, d in zip(is_continuous, derivative, strict=True)]
def infer_shape(self, fgraph, node, ins_shapes):
def _get(l):
return l[1]
d = sum(map(_get, ins_shapes))
d = sum(shape[1] for shape in ins_shapes)
return [(ins_shapes[0][0], d)]
def __str__(self):
return f"{self.__class__.__name__}({self.format},{self.dtype})"
def hstack(blocks, format=None, dtype=None):
"""
......@@ -2897,7 +2896,7 @@ def hstack(blocks, format=None, dtype=None):
return HStack(format=format, dtype=dtype)(*blocks)
class VStack(HStack):
class VStack(Stack):
def perform(self, node, block, outputs):
(out,) = outputs
for b in block:
......@@ -2932,10 +2931,7 @@ class VStack(HStack):
return [choose(c, d) for c, d in zip(is_continuous, derivative, strict=True)]
def infer_shape(self, fgraph, node, ins_shapes):
def _get(l):
return l[0]
d = sum(map(_get, ins_shapes))
d = sum(shape[0] for shape in ins_shapes)
return [(d, ins_shapes[0][1])]
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论