提交 5841c30e authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Specialize string representation of Dimshuffle

上级 5710f950
...@@ -215,10 +215,18 @@ class DimShuffle(ExternalCOp): ...@@ -215,10 +215,18 @@ class DimShuffle(ExternalCOp):
return Apply(self, [input], [output]) return Apply(self, [input], [output])
def __str__(self): def __str__(self):
if self.inplace: shuffle = sorted(self.shuffle) != self.shuffle
return "InplaceDimShuffle{%s}" % ",".join(str(x) for x in self.new_order) if self.augment and not (shuffle or self.drop):
else: if len(self.augment) == 1:
return "DimShuffle{%s}" % ",".join(str(x) for x in self.new_order) return f"ExpandDims{{axis={self.augment[0]}}}"
return f"ExpandDims{{axes={self.augment}}}"
if self.drop and not (self.augment or shuffle):
if len(self.drop) == 1:
return f"DropDims{{axis={self.drop[0]}}}"
return f"DropDims{{axes={self.drop}}}"
if shuffle and not (self.augment or self.drop):
return f"Transpose{{axes={self.shuffle}}}"
return f"DimShuffle{{order=[{','.join(map(str, self.new_order))}]}}"
def perform(self, node, inp, out, params): def perform(self, node, inp, out, params):
(res,) = inp (res,) = inp
......
...@@ -37,10 +37,10 @@ def test_debugprint_sitsot(): ...@@ -37,10 +37,10 @@ def test_debugprint_sitsot():
│ │ │ │ │ └─ Subtensor{int64} [id H] │ │ │ │ │ └─ Subtensor{int64} [id H]
│ │ │ │ │ ├─ Shape [id I] │ │ │ │ │ ├─ Shape [id I]
│ │ │ │ │ │ └─ Unbroadcast{0} [id J] │ │ │ │ │ │ └─ Unbroadcast{0} [id J]
│ │ │ │ │ │ └─ InplaceDimShuffle{x,0} [id K] │ │ │ │ │ │ └─ ExpandDims{axis=0} [id K]
│ │ │ │ │ │ └─ Elemwise{second,no_inplace} [id L] │ │ │ │ │ │ └─ Elemwise{second,no_inplace} [id L]
│ │ │ │ │ │ ├─ A [id M] │ │ │ │ │ │ ├─ A [id M]
│ │ │ │ │ │ └─ InplaceDimShuffle{x} [id N] │ │ │ │ │ │ └─ ExpandDims{axis=0} [id N]
│ │ │ │ │ │ └─ TensorConstant{1.0} [id O] │ │ │ │ │ │ └─ TensorConstant{1.0} [id O]
│ │ │ │ │ └─ ScalarConstant{0} [id P] │ │ │ │ │ └─ ScalarConstant{0} [id P]
│ │ │ │ └─ Subtensor{int64} [id Q] │ │ │ │ └─ Subtensor{int64} [id Q]
...@@ -95,10 +95,10 @@ def test_debugprint_sitsot_no_extra_info(): ...@@ -95,10 +95,10 @@ def test_debugprint_sitsot_no_extra_info():
│ │ │ │ │ └─ Subtensor{int64} [id H] │ │ │ │ │ └─ Subtensor{int64} [id H]
│ │ │ │ │ ├─ Shape [id I] │ │ │ │ │ ├─ Shape [id I]
│ │ │ │ │ │ └─ Unbroadcast{0} [id J] │ │ │ │ │ │ └─ Unbroadcast{0} [id J]
│ │ │ │ │ │ └─ InplaceDimShuffle{x,0} [id K] │ │ │ │ │ │ └─ ExpandDims{axis=0} [id K]
│ │ │ │ │ │ └─ Elemwise{second,no_inplace} [id L] │ │ │ │ │ │ └─ Elemwise{second,no_inplace} [id L]
│ │ │ │ │ │ ├─ A [id M] │ │ │ │ │ │ ├─ A [id M]
│ │ │ │ │ │ └─ InplaceDimShuffle{x} [id N] │ │ │ │ │ │ └─ ExpandDims{axis=0} [id N]
│ │ │ │ │ │ └─ TensorConstant{1.0} [id O] │ │ │ │ │ │ └─ TensorConstant{1.0} [id O]
│ │ │ │ │ └─ ScalarConstant{0} [id P] │ │ │ │ │ └─ ScalarConstant{0} [id P]
│ │ │ │ └─ Subtensor{int64} [id Q] │ │ │ │ └─ Subtensor{int64} [id Q]
...@@ -264,7 +264,7 @@ def test_debugprint_nested_scans(): ...@@ -264,7 +264,7 @@ def test_debugprint_nested_scans():
for{cpu,scan_fn} [id B] for{cpu,scan_fn} [id B]
← Elemwise{mul,no_inplace} [id Y] (inner_out_nit_sot-0) ← Elemwise{mul,no_inplace} [id Y] (inner_out_nit_sot-0)
├─ InplaceDimShuffle{x} [id Z] ├─ ExpandDims{axis=0} [id Z]
│ └─ *0-<TensorType(float64, ())> [id BA] -> [id S] (inner_in_seqs-0) │ └─ *0-<TensorType(float64, ())> [id BA] -> [id S] (inner_in_seqs-0)
└─ Elemwise{pow,no_inplace} [id BB] └─ Elemwise{pow,no_inplace} [id BB]
├─ Subtensor{int64} [id BC] ├─ Subtensor{int64} [id BC]
...@@ -278,10 +278,10 @@ def test_debugprint_nested_scans(): ...@@ -278,10 +278,10 @@ def test_debugprint_nested_scans():
│ │ │ │ │ │ └─ Subtensor{int64} [id BJ] │ │ │ │ │ │ └─ Subtensor{int64} [id BJ]
│ │ │ │ │ │ ├─ Shape [id BK] │ │ │ │ │ │ ├─ Shape [id BK]
│ │ │ │ │ │ │ └─ Unbroadcast{0} [id BL] │ │ │ │ │ │ │ └─ Unbroadcast{0} [id BL]
│ │ │ │ │ │ │ └─ InplaceDimShuffle{x,0} [id BM] │ │ │ │ │ │ │ └─ ExpandDims{axis=0} [id BM]
│ │ │ │ │ │ │ └─ Elemwise{second,no_inplace} [id BN] │ │ │ │ │ │ │ └─ Elemwise{second,no_inplace} [id BN]
│ │ │ │ │ │ │ ├─ *2-<TensorType(float64, (?,))> [id BO] -> [id W] (inner_in_non_seqs-0) │ │ │ │ │ │ │ ├─ *2-<TensorType(float64, (?,))> [id BO] -> [id W] (inner_in_non_seqs-0)
│ │ │ │ │ │ │ └─ InplaceDimShuffle{x} [id BP] │ │ │ │ │ │ │ └─ ExpandDims{axis=0} [id BP]
│ │ │ │ │ │ │ └─ TensorConstant{1.0} [id BQ] │ │ │ │ │ │ │ └─ TensorConstant{1.0} [id BQ]
│ │ │ │ │ │ └─ ScalarConstant{0} [id BR] │ │ │ │ │ │ └─ ScalarConstant{0} [id BR]
│ │ │ │ │ └─ Subtensor{int64} [id BS] │ │ │ │ │ └─ Subtensor{int64} [id BS]
...@@ -297,7 +297,7 @@ def test_debugprint_nested_scans(): ...@@ -297,7 +297,7 @@ def test_debugprint_nested_scans():
│ │ │ └─ *2-<TensorType(float64, (?,))> [id BO] -> [id W] (inner_in_non_seqs-0) (outer_in_non_seqs-0) │ │ │ └─ *2-<TensorType(float64, (?,))> [id BO] -> [id W] (inner_in_non_seqs-0) (outer_in_non_seqs-0)
│ │ └─ ScalarConstant{1} [id BW] │ │ └─ ScalarConstant{1} [id BW]
│ └─ ScalarConstant{-1} [id BX] │ └─ ScalarConstant{-1} [id BX]
└─ InplaceDimShuffle{x} [id BY] └─ ExpandDims{axis=0} [id BY]
└─ *1-<TensorType(int64, ())> [id BZ] -> [id U] (inner_in_seqs-1) └─ *1-<TensorType(int64, ())> [id BZ] -> [id U] (inner_in_seqs-1)
for{cpu,scan_fn} [id BE] for{cpu,scan_fn} [id BE]
...@@ -361,7 +361,7 @@ def test_debugprint_nested_scans(): ...@@ -361,7 +361,7 @@ def test_debugprint_nested_scans():
→ *2-<TensorType(float64, (?,))> [id BA] -> [id C] (inner_in_non_seqs-0) → *2-<TensorType(float64, (?,))> [id BA] -> [id C] (inner_in_non_seqs-0)
→ *3-<TensorType(int32, ())> [id BB] -> [id B] (inner_in_non_seqs-1) → *3-<TensorType(int32, ())> [id BB] -> [id B] (inner_in_non_seqs-1)
← Elemwise{mul,no_inplace} [id BC] (inner_out_nit_sot-0) ← Elemwise{mul,no_inplace} [id BC] (inner_out_nit_sot-0)
├─ InplaceDimShuffle{x} [id BD] ├─ ExpandDims{axis=0} [id BD]
│ └─ *0-<TensorType(float64, ())> [id Y] (inner_in_seqs-0) │ └─ *0-<TensorType(float64, ())> [id Y] (inner_in_seqs-0)
└─ Elemwise{pow,no_inplace} [id BE] └─ Elemwise{pow,no_inplace} [id BE]
├─ Subtensor{int64} [id BF] ├─ Subtensor{int64} [id BF]
...@@ -375,10 +375,10 @@ def test_debugprint_nested_scans(): ...@@ -375,10 +375,10 @@ def test_debugprint_nested_scans():
│ │ │ │ │ │ └─ Subtensor{int64} [id BL] │ │ │ │ │ │ └─ Subtensor{int64} [id BL]
│ │ │ │ │ │ ├─ Shape [id BM] │ │ │ │ │ │ ├─ Shape [id BM]
│ │ │ │ │ │ │ └─ Unbroadcast{0} [id BN] │ │ │ │ │ │ │ └─ Unbroadcast{0} [id BN]
│ │ │ │ │ │ │ └─ InplaceDimShuffle{x,0} [id BO] │ │ │ │ │ │ │ └─ ExpandDims{axis=0} [id BO]
│ │ │ │ │ │ │ └─ Elemwise{second,no_inplace} [id BP] │ │ │ │ │ │ │ └─ Elemwise{second,no_inplace} [id BP]
│ │ │ │ │ │ │ ├─ *2-<TensorType(float64, (?,))> [id BA] (inner_in_non_seqs-0) │ │ │ │ │ │ │ ├─ *2-<TensorType(float64, (?,))> [id BA] (inner_in_non_seqs-0)
│ │ │ │ │ │ │ └─ InplaceDimShuffle{x} [id BQ] │ │ │ │ │ │ │ └─ ExpandDims{axis=0} [id BQ]
│ │ │ │ │ │ │ └─ TensorConstant{1.0} [id BR] │ │ │ │ │ │ │ └─ TensorConstant{1.0} [id BR]
│ │ │ │ │ │ └─ ScalarConstant{0} [id BS] │ │ │ │ │ │ └─ ScalarConstant{0} [id BS]
│ │ │ │ │ └─ Subtensor{int64} [id BT] │ │ │ │ │ └─ Subtensor{int64} [id BT]
...@@ -394,7 +394,7 @@ def test_debugprint_nested_scans(): ...@@ -394,7 +394,7 @@ def test_debugprint_nested_scans():
│ │ │ └─ *2-<TensorType(float64, (?,))> [id BA] (inner_in_non_seqs-0) (outer_in_non_seqs-0) │ │ │ └─ *2-<TensorType(float64, (?,))> [id BA] (inner_in_non_seqs-0) (outer_in_non_seqs-0)
│ │ └─ ScalarConstant{1} [id BX] │ │ └─ ScalarConstant{1} [id BX]
│ └─ ScalarConstant{-1} [id BY] │ └─ ScalarConstant{-1} [id BY]
└─ InplaceDimShuffle{x} [id BZ] └─ ExpandDims{axis=0} [id BZ]
└─ *1-<TensorType(int64, ())> [id Z] (inner_in_seqs-1) └─ *1-<TensorType(int64, ())> [id Z] (inner_in_seqs-1)
for{cpu,scan_fn} [id BH] for{cpu,scan_fn} [id BH]
...@@ -515,10 +515,10 @@ def test_debugprint_mitmot(): ...@@ -515,10 +515,10 @@ def test_debugprint_mitmot():
│ │ │ │ │ │ │ └─ Subtensor{int64} [id K] │ │ │ │ │ │ │ └─ Subtensor{int64} [id K]
│ │ │ │ │ │ │ ├─ Shape [id L] │ │ │ │ │ │ │ ├─ Shape [id L]
│ │ │ │ │ │ │ │ └─ Unbroadcast{0} [id M] │ │ │ │ │ │ │ │ └─ Unbroadcast{0} [id M]
│ │ │ │ │ │ │ │ └─ InplaceDimShuffle{x,0} [id N] │ │ │ │ │ │ │ │ └─ ExpandDims{axis=0} [id N]
│ │ │ │ │ │ │ │ └─ Elemwise{second,no_inplace} [id O] │ │ │ │ │ │ │ │ └─ Elemwise{second,no_inplace} [id O]
│ │ │ │ │ │ │ │ ├─ A [id P] │ │ │ │ │ │ │ │ ├─ A [id P]
│ │ │ │ │ │ │ │ └─ InplaceDimShuffle{x} [id Q] │ │ │ │ │ │ │ │ └─ ExpandDims{axis=0} [id Q]
│ │ │ │ │ │ │ │ └─ TensorConstant{1.0} [id R] │ │ │ │ │ │ │ │ └─ TensorConstant{1.0} [id R]
│ │ │ │ │ │ │ └─ ScalarConstant{0} [id S] │ │ │ │ │ │ │ └─ ScalarConstant{0} [id S]
│ │ │ │ │ │ └─ Subtensor{int64} [id T] │ │ │ │ │ │ └─ Subtensor{int64} [id T]
...@@ -559,7 +559,7 @@ def test_debugprint_mitmot(): ...@@ -559,7 +559,7 @@ def test_debugprint_mitmot():
│ │ │ ├─ Elemwise{second,no_inplace} [id BN] │ │ │ ├─ Elemwise{second,no_inplace} [id BN]
│ │ │ │ ├─ for{cpu,scan_fn} [id F] (outer_out_sit_sot-0) │ │ │ │ ├─ for{cpu,scan_fn} [id F] (outer_out_sit_sot-0)
│ │ │ │ │ └─ ··· │ │ │ │ │ └─ ···
│ │ │ │ └─ InplaceDimShuffle{x,x} [id BO] │ │ │ │ └─ ExpandDims{axes=[0, 1]} [id BO]
│ │ │ │ └─ TensorConstant{0.0} [id BP] │ │ │ │ └─ TensorConstant{0.0} [id BP]
│ │ │ ├─ IncSubtensor{Inc;int64} [id BQ] │ │ │ ├─ IncSubtensor{Inc;int64} [id BQ]
│ │ │ │ ├─ Elemwise{second,no_inplace} [id BR] │ │ │ │ ├─ Elemwise{second,no_inplace} [id BR]
...@@ -567,14 +567,14 @@ def test_debugprint_mitmot(): ...@@ -567,14 +567,14 @@ def test_debugprint_mitmot():
│ │ │ │ │ │ ├─ for{cpu,scan_fn} [id F] (outer_out_sit_sot-0) │ │ │ │ │ │ ├─ for{cpu,scan_fn} [id F] (outer_out_sit_sot-0)
│ │ │ │ │ │ │ └─ ··· │ │ │ │ │ │ │ └─ ···
│ │ │ │ │ │ └─ ScalarConstant{1} [id BT] │ │ │ │ │ │ └─ ScalarConstant{1} [id BT]
│ │ │ │ │ └─ InplaceDimShuffle{x,x} [id BU] │ │ │ │ │ └─ ExpandDims{axes=[0, 1]} [id BU]
│ │ │ │ │ └─ TensorConstant{0.0} [id BV] │ │ │ │ │ └─ TensorConstant{0.0} [id BV]
│ │ │ │ ├─ Elemwise{second} [id BW] │ │ │ │ ├─ Elemwise{second} [id BW]
│ │ │ │ │ ├─ Subtensor{int64} [id BX] │ │ │ │ │ ├─ Subtensor{int64} [id BX]
│ │ │ │ │ │ ├─ Subtensor{int64::} [id BS] │ │ │ │ │ │ ├─ Subtensor{int64::} [id BS]
│ │ │ │ │ │ │ └─ ··· │ │ │ │ │ │ │ └─ ···
│ │ │ │ │ │ └─ ScalarConstant{-1} [id BY] │ │ │ │ │ │ └─ ScalarConstant{-1} [id BY]
│ │ │ │ │ └─ InplaceDimShuffle{x} [id BZ] │ │ │ │ │ └─ ExpandDims{axis=0} [id BZ]
│ │ │ │ │ └─ Elemwise{second,no_inplace} [id CA] │ │ │ │ │ └─ Elemwise{second,no_inplace} [id CA]
│ │ │ │ │ ├─ Sum{acc_dtype=float64} [id CB] │ │ │ │ │ ├─ Sum{acc_dtype=float64} [id CB]
│ │ │ │ │ │ └─ Subtensor{int64} [id BX] │ │ │ │ │ │ └─ Subtensor{int64} [id BX]
......
...@@ -275,7 +275,7 @@ def test_debugprint(): ...@@ -275,7 +275,7 @@ def test_debugprint():
exp_res = dedent( exp_res = dedent(
r""" r"""
Elemwise{Composite{(i2 + (i0 - i1))}} 4 Elemwise{Composite{(i2 + (i0 - i1))}} 4
├─ InplaceDimShuffle{x,0} v={0: [0]} 3 ├─ ExpandDims{axis=0} v={0: [0]} 3
│ └─ CGemv{inplace} d={0: [0]} 2 │ └─ CGemv{inplace} d={0: [0]} 2
│ ├─ AllocEmpty{dtype='float64'} 1 │ ├─ AllocEmpty{dtype='float64'} 1
│ │ └─ Shape_i{0} 0 │ │ └─ Shape_i{0} 0
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论