提交 fe8804fa authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Cache sub-type of DimShuffle

上级 b5a64c77
......@@ -166,15 +166,20 @@ class DimShuffle(ExternalCOp):
self.transposition = self.shuffle + drop
# List of dimensions of the output that are broadcastable and were not
# in the original input
self.augment = sorted(i for i, x in enumerate(new_order) if x == "x")
self.augment = augment = sorted(i for i, x in enumerate(new_order) if x == "x")
self.drop = drop
self.is_left_expand_dims = self.augment and (
dims_are_shuffled = sorted(self.shuffle) != self.shuffle
self.is_transpose = dims_are_shuffled and not augment and not drop
self.is_squeeze = drop and not dims_are_shuffled and not augment
self.is_expand_dims = augment and not dims_are_shuffled and not drop
self.is_left_expand_dims = self.is_expand_dims and (
input_ndim == 0 or new_order[-input_ndim:] == list(range(input_ndim))
)
self.is_right_expand_dims = self.augment and new_order[:input_ndim] == list(
range(input_ndim)
)
self.is_right_expand_dims = self.is_expand_dims and new_order[
:input_ndim
] == list(range(input_ndim))
if self.inplace:
self.view_map = {0: [0]}
......@@ -215,16 +220,15 @@ class DimShuffle(ExternalCOp):
return Apply(self, [input], [output])
def __str__(self):
shuffle = sorted(self.shuffle) != self.shuffle
if self.augment and not (shuffle or self.drop):
if self.is_expand_dims:
if len(self.augment) == 1:
return f"ExpandDims{{axis={self.augment[0]}}}"
return f"ExpandDims{{axes={self.augment}}}"
if self.drop and not (self.augment or shuffle):
if self.is_squeeze:
if len(self.drop) == 1:
return f"DropDims{{axis={self.drop[0]}}}"
return f"DropDims{{axes={self.drop}}}"
if shuffle and not (self.augment or self.drop):
return f"Squeeze{{axis={self.drop[0]}}}"
return f"Squeeze{{axes={self.drop}}}"
if self.is_transpose:
return f"Transpose{{axes={self.shuffle}}}"
return f"DimShuffle{{order=[{','.join(map(str, self.new_order))}]}}"
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论