提交 bd561974 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Add missing info to Rebroadcast str representation

上级 ce2c8613
......@@ -8,7 +8,6 @@ manipulation of tensors.
import builtins
import logging
import warnings
from collections import OrderedDict
from collections.abc import Sequence
from functools import partial
from numbers import Number
......@@ -697,7 +696,7 @@ class Rebroadcast(COp):
def __init__(self, *axis):
# Sort them to make sure we merge all possible case.
items = sorted(axis)
self.axis = OrderedDict(items)
self.axis = dict(items)
for axis, broad in self.axis.items():
if not isinstance(axis, (np.integer, int)):
raise TypeError(f"Rebroadcast needs integer axes. Got {axis}")
......@@ -714,13 +713,7 @@ class Rebroadcast(COp):
return hash((type(self), tuple(items)))
def __str__(self):
if len(self.axis) == 0:
broadcast_pattern = []
else:
broadcast_pattern = ["?" for i in range(1 + max(self.axis.keys()))]
for k, v in self.axis.items():
broadcast_pattern[k] = str(int(v))
return f"{self.__class__.__name__}{{{','.join(broadcast_pattern)}}}"
return f"{self.__class__.__name__}{{{','.join(str(i) for i in self.axis.items())}}}"
def make_node(self, x):
if self.axis.keys() and (x.ndim <= max(self.axis.keys())):
......
......@@ -33,7 +33,7 @@ def test_scan_debugprint1():
| | | | | |k [id D]
| | | | | |Subtensor{int64} [id H] ''
| | | | | |Shape [id I] ''
| | | | | | |Rebroadcast{0} [id J] ''
| | | | | | |Rebroadcast{(0, False)} [id J] ''
| | | | | | |InplaceDimShuffle{x,0} [id K] ''
| | | | | | |Elemwise{second,no_inplace} [id L] ''
| | | | | | |A [id M]
......@@ -42,9 +42,9 @@ def test_scan_debugprint1():
| | | | | |ScalarConstant{0} [id P]
| | | | |Subtensor{int64} [id Q] ''
| | | | |Shape [id R] ''
| | | | | |Rebroadcast{0} [id J] ''
| | | | | |Rebroadcast{(0, False)} [id J] ''
| | | | |ScalarConstant{1} [id S]
| | | |Rebroadcast{0} [id J] ''
| | | |Rebroadcast{(0, False)} [id J] ''
| | | |ScalarFromTensor [id T] ''
| | | |Subtensor{int64} [id H] ''
| | |A [id M]
......@@ -208,7 +208,7 @@ def test_scan_debugprint3():
> | | | | | | |k_copy [id BF] -> [id X]
> | | | | | | |Subtensor{int64} [id BJ] ''
> | | | | | | |Shape [id BK] ''
> | | | | | | | |Rebroadcast{0} [id BL] ''
> | | | | | | | |Rebroadcast{(0, False)} [id BL] ''
> | | | | | | | |InplaceDimShuffle{x,0} [id BM] ''
> | | | | | | | |Elemwise{second,no_inplace} [id BN] ''
> | | | | | | | |A_copy [id BO] -> [id W]
......@@ -217,9 +217,9 @@ def test_scan_debugprint3():
> | | | | | | |ScalarConstant{0} [id BR]
> | | | | | |Subtensor{int64} [id BS] ''
> | | | | | |Shape [id BT] ''
> | | | | | | |Rebroadcast{0} [id BL] ''
> | | | | | | |Rebroadcast{(0, False)} [id BL] ''
> | | | | | |ScalarConstant{1} [id BU]
> | | | | |Rebroadcast{0} [id BL] ''
> | | | | |Rebroadcast{(0, False)} [id BL] ''
> | | | | |ScalarFromTensor [id BV] ''
> | | | | |Subtensor{int64} [id BJ] ''
> | | | |A_copy [id BO] -> [id W]
......@@ -341,7 +341,7 @@ def test_scan_debugprint5():
| | | | | | | |k [id G]
| | | | | | | |Subtensor{int64} [id K] ''
| | | | | | | |Shape [id L] ''
| | | | | | | | |Rebroadcast{0} [id M] ''
| | | | | | | | |Rebroadcast{(0, False)} [id M] ''
| | | | | | | | |InplaceDimShuffle{x,0} [id N] ''
| | | | | | | | |Elemwise{second,no_inplace} [id O] ''
| | | | | | | | |A [id P]
......@@ -350,9 +350,9 @@ def test_scan_debugprint5():
| | | | | | | |ScalarConstant{0} [id S]
| | | | | | |Subtensor{int64} [id T] ''
| | | | | | |Shape [id U] ''
| | | | | | | |Rebroadcast{0} [id M] ''
| | | | | | | |Rebroadcast{(0, False)} [id M] ''
| | | | | | |ScalarConstant{1} [id V]
| | | | | |Rebroadcast{0} [id M] ''
| | | | | |Rebroadcast{(0, False)} [id M] ''
| | | | | |ScalarFromTensor [id W] ''
| | | | | |Subtensor{int64} [id K] ''
| | | | |A [id P]
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论