提交 fe8804fa authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Cache sub-type of DimShuffle

上级 b5a64c77
...@@ -166,15 +166,20 @@ class DimShuffle(ExternalCOp): ...@@ -166,15 +166,20 @@ class DimShuffle(ExternalCOp):
self.transposition = self.shuffle + drop self.transposition = self.shuffle + drop
# List of dimensions of the output that are broadcastable and were not # List of dimensions of the output that are broadcastable and were not
# in the original input # in the original input
self.augment = sorted(i for i, x in enumerate(new_order) if x == "x") self.augment = augment = sorted(i for i, x in enumerate(new_order) if x == "x")
self.drop = drop self.drop = drop
self.is_left_expand_dims = self.augment and ( dims_are_shuffled = sorted(self.shuffle) != self.shuffle
self.is_transpose = dims_are_shuffled and not augment and not drop
self.is_squeeze = drop and not dims_are_shuffled and not augment
self.is_expand_dims = augment and not dims_are_shuffled and not drop
self.is_left_expand_dims = self.is_expand_dims and (
input_ndim == 0 or new_order[-input_ndim:] == list(range(input_ndim)) input_ndim == 0 or new_order[-input_ndim:] == list(range(input_ndim))
) )
self.is_right_expand_dims = self.augment and new_order[:input_ndim] == list( self.is_right_expand_dims = self.is_expand_dims and new_order[
range(input_ndim) :input_ndim
) ] == list(range(input_ndim))
if self.inplace: if self.inplace:
self.view_map = {0: [0]} self.view_map = {0: [0]}
...@@ -215,16 +220,15 @@ class DimShuffle(ExternalCOp): ...@@ -215,16 +220,15 @@ class DimShuffle(ExternalCOp):
return Apply(self, [input], [output]) return Apply(self, [input], [output])
def __str__(self): def __str__(self):
shuffle = sorted(self.shuffle) != self.shuffle if self.is_expand_dims:
if self.augment and not (shuffle or self.drop):
if len(self.augment) == 1: if len(self.augment) == 1:
return f"ExpandDims{{axis={self.augment[0]}}}" return f"ExpandDims{{axis={self.augment[0]}}}"
return f"ExpandDims{{axes={self.augment}}}" return f"ExpandDims{{axes={self.augment}}}"
if self.drop and not (self.augment or shuffle): if self.is_squeeze:
if len(self.drop) == 1: if len(self.drop) == 1:
return f"DropDims{{axis={self.drop[0]}}}" return f"Squeeze{{axis={self.drop[0]}}}"
return f"DropDims{{axes={self.drop}}}" return f"Squeeze{{axes={self.drop}}}"
if shuffle and not (self.augment or self.drop): if self.is_transpose:
return f"Transpose{{axes={self.shuffle}}}" return f"Transpose{{axes={self.shuffle}}}"
return f"DimShuffle{{order=[{','.join(map(str, self.new_order))}]}}" return f"DimShuffle{{order=[{','.join(map(str, self.new_order))}]}}"
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论