提交 bd561974 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Add missing info to Rebroadcast str representation

上级 ce2c8613
...@@ -8,7 +8,6 @@ manipulation of tensors. ...@@ -8,7 +8,6 @@ manipulation of tensors.
import builtins import builtins
import logging import logging
import warnings import warnings
from collections import OrderedDict
from collections.abc import Sequence from collections.abc import Sequence
from functools import partial from functools import partial
from numbers import Number from numbers import Number
...@@ -697,7 +696,7 @@ class Rebroadcast(COp): ...@@ -697,7 +696,7 @@ class Rebroadcast(COp):
def __init__(self, *axis): def __init__(self, *axis):
# Sort them to make sure we merge all possible case. # Sort them to make sure we merge all possible case.
items = sorted(axis) items = sorted(axis)
self.axis = OrderedDict(items) self.axis = dict(items)
for axis, broad in self.axis.items(): for axis, broad in self.axis.items():
if not isinstance(axis, (np.integer, int)): if not isinstance(axis, (np.integer, int)):
raise TypeError(f"Rebroadcast needs integer axes. Got {axis}") raise TypeError(f"Rebroadcast needs integer axes. Got {axis}")
...@@ -714,13 +713,7 @@ class Rebroadcast(COp): ...@@ -714,13 +713,7 @@ class Rebroadcast(COp):
return hash((type(self), tuple(items))) return hash((type(self), tuple(items)))
def __str__(self): def __str__(self):
if len(self.axis) == 0: return f"{self.__class__.__name__}{{{','.join(str(i) for i in self.axis.items())}}}"
broadcast_pattern = []
else:
broadcast_pattern = ["?" for i in range(1 + max(self.axis.keys()))]
for k, v in self.axis.items():
broadcast_pattern[k] = str(int(v))
return f"{self.__class__.__name__}{{{','.join(broadcast_pattern)}}}"
def make_node(self, x): def make_node(self, x):
if self.axis.keys() and (x.ndim <= max(self.axis.keys())): if self.axis.keys() and (x.ndim <= max(self.axis.keys())):
......
...@@ -33,7 +33,7 @@ def test_scan_debugprint1(): ...@@ -33,7 +33,7 @@ def test_scan_debugprint1():
| | | | | |k [id D] | | | | | |k [id D]
| | | | | |Subtensor{int64} [id H] '' | | | | | |Subtensor{int64} [id H] ''
| | | | | |Shape [id I] '' | | | | | |Shape [id I] ''
| | | | | | |Rebroadcast{0} [id J] '' | | | | | | |Rebroadcast{(0, False)} [id J] ''
| | | | | | |InplaceDimShuffle{x,0} [id K] '' | | | | | | |InplaceDimShuffle{x,0} [id K] ''
| | | | | | |Elemwise{second,no_inplace} [id L] '' | | | | | | |Elemwise{second,no_inplace} [id L] ''
| | | | | | |A [id M] | | | | | | |A [id M]
...@@ -42,9 +42,9 @@ def test_scan_debugprint1(): ...@@ -42,9 +42,9 @@ def test_scan_debugprint1():
| | | | | |ScalarConstant{0} [id P] | | | | | |ScalarConstant{0} [id P]
| | | | |Subtensor{int64} [id Q] '' | | | | |Subtensor{int64} [id Q] ''
| | | | |Shape [id R] '' | | | | |Shape [id R] ''
| | | | | |Rebroadcast{0} [id J] '' | | | | | |Rebroadcast{(0, False)} [id J] ''
| | | | |ScalarConstant{1} [id S] | | | | |ScalarConstant{1} [id S]
| | | |Rebroadcast{0} [id J] '' | | | |Rebroadcast{(0, False)} [id J] ''
| | | |ScalarFromTensor [id T] '' | | | |ScalarFromTensor [id T] ''
| | | |Subtensor{int64} [id H] '' | | | |Subtensor{int64} [id H] ''
| | |A [id M] | | |A [id M]
...@@ -208,7 +208,7 @@ def test_scan_debugprint3(): ...@@ -208,7 +208,7 @@ def test_scan_debugprint3():
> | | | | | | |k_copy [id BF] -> [id X] > | | | | | | |k_copy [id BF] -> [id X]
> | | | | | | |Subtensor{int64} [id BJ] '' > | | | | | | |Subtensor{int64} [id BJ] ''
> | | | | | | |Shape [id BK] '' > | | | | | | |Shape [id BK] ''
> | | | | | | | |Rebroadcast{0} [id BL] '' > | | | | | | | |Rebroadcast{(0, False)} [id BL] ''
> | | | | | | | |InplaceDimShuffle{x,0} [id BM] '' > | | | | | | | |InplaceDimShuffle{x,0} [id BM] ''
> | | | | | | | |Elemwise{second,no_inplace} [id BN] '' > | | | | | | | |Elemwise{second,no_inplace} [id BN] ''
> | | | | | | | |A_copy [id BO] -> [id W] > | | | | | | | |A_copy [id BO] -> [id W]
...@@ -217,9 +217,9 @@ def test_scan_debugprint3(): ...@@ -217,9 +217,9 @@ def test_scan_debugprint3():
> | | | | | | |ScalarConstant{0} [id BR] > | | | | | | |ScalarConstant{0} [id BR]
> | | | | | |Subtensor{int64} [id BS] '' > | | | | | |Subtensor{int64} [id BS] ''
> | | | | | |Shape [id BT] '' > | | | | | |Shape [id BT] ''
> | | | | | | |Rebroadcast{0} [id BL] '' > | | | | | | |Rebroadcast{(0, False)} [id BL] ''
> | | | | | |ScalarConstant{1} [id BU] > | | | | | |ScalarConstant{1} [id BU]
> | | | | |Rebroadcast{0} [id BL] '' > | | | | |Rebroadcast{(0, False)} [id BL] ''
> | | | | |ScalarFromTensor [id BV] '' > | | | | |ScalarFromTensor [id BV] ''
> | | | | |Subtensor{int64} [id BJ] '' > | | | | |Subtensor{int64} [id BJ] ''
> | | | |A_copy [id BO] -> [id W] > | | | |A_copy [id BO] -> [id W]
...@@ -341,7 +341,7 @@ def test_scan_debugprint5(): ...@@ -341,7 +341,7 @@ def test_scan_debugprint5():
| | | | | | | |k [id G] | | | | | | | |k [id G]
| | | | | | | |Subtensor{int64} [id K] '' | | | | | | | |Subtensor{int64} [id K] ''
| | | | | | | |Shape [id L] '' | | | | | | | |Shape [id L] ''
| | | | | | | | |Rebroadcast{0} [id M] '' | | | | | | | | |Rebroadcast{(0, False)} [id M] ''
| | | | | | | | |InplaceDimShuffle{x,0} [id N] '' | | | | | | | | |InplaceDimShuffle{x,0} [id N] ''
| | | | | | | | |Elemwise{second,no_inplace} [id O] '' | | | | | | | | |Elemwise{second,no_inplace} [id O] ''
| | | | | | | | |A [id P] | | | | | | | | |A [id P]
...@@ -350,9 +350,9 @@ def test_scan_debugprint5(): ...@@ -350,9 +350,9 @@ def test_scan_debugprint5():
| | | | | | | |ScalarConstant{0} [id S] | | | | | | | |ScalarConstant{0} [id S]
| | | | | | |Subtensor{int64} [id T] '' | | | | | | |Subtensor{int64} [id T] ''
| | | | | | |Shape [id U] '' | | | | | | |Shape [id U] ''
| | | | | | | |Rebroadcast{0} [id M] '' | | | | | | | |Rebroadcast{(0, False)} [id M] ''
| | | | | | |ScalarConstant{1} [id V] | | | | | | |ScalarConstant{1} [id V]
| | | | | |Rebroadcast{0} [id M] '' | | | | | |Rebroadcast{(0, False)} [id M] ''
| | | | | |ScalarFromTensor [id W] '' | | | | | |ScalarFromTensor [id W] ''
| | | | | |Subtensor{int64} [id K] '' | | | | | |Subtensor{int64} [id K] ''
| | | | |A [id P] | | | | |A [id P]
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论