提交 aec01eed authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Clarify printer interface with Printer ABC

上级 2ca41998
......@@ -8,6 +8,7 @@ import hashlib
import logging
import os
import sys
from abc import ABC, abstractmethod
from copy import copy
from functools import reduce
from io import IOBase, StringIO
......@@ -658,7 +659,13 @@ class PrinterState(Scratchpad):
self.memo = {}
class OperatorPrinter:
class Printer(ABC):
@abstractmethod
def process(self, var: Variable, pstate: PrinterState) -> str:
"""Construct a string representation for a `Variable`."""
class OperatorPrinter(Printer):
def __init__(self, operator, precedence, assoc="left"):
self.operator = operator
self.precedence = precedence
......@@ -711,7 +718,7 @@ class OperatorPrinter:
return r
class PatternPrinter:
class PatternPrinter(Printer):
def __init__(self, *patterns):
self.patterns = []
for pattern in patterns:
......@@ -756,7 +763,7 @@ class PatternPrinter:
return r
class FunctionPrinter:
class FunctionPrinter(Printer):
def __init__(self, *names):
self.names = names
......@@ -787,7 +794,7 @@ class FunctionPrinter:
return r
class IgnorePrinter:
class IgnorePrinter(Printer):
def process(self, output, pstate):
if output in pstate.memo:
return pstate.memo[output]
......@@ -804,7 +811,7 @@ class IgnorePrinter:
return r
class LeafPrinter:
class LeafPrinter(Printer):
def process(self, output, pstate):
if output in pstate.memo:
return pstate.memo[output]
......@@ -819,7 +826,7 @@ class LeafPrinter:
leaf_printer = LeafPrinter()
class DefaultPrinter:
class DefaultPrinter(Printer):
def process(self, output, pstate):
if output in pstate.memo:
return pstate.memo[output]
......@@ -845,7 +852,7 @@ class DefaultPrinter:
default_printer = DefaultPrinter()
class PPrinter:
class PPrinter(Printer):
def __init__(self):
self.printers = []
self.printers_dict = {}
......
......@@ -38,7 +38,7 @@ from aesara.graph.utils import (
TestValueError,
get_variable_trace_string,
)
from aesara.printing import pprint
from aesara.printing import Printer, pprint
from aesara.raise_op import Assert, CheckAndRaise, assert_op
from aesara.tensor.basic import (
Alloc,
......@@ -718,7 +718,7 @@ def local_scalar_tensor_scalar(fgraph, node):
return [s]
class MakeVectorPrinter:
class MakeVectorPrinter(Printer):
def process(self, r, pstate):
if r.owner is None:
raise TypeError("Can only print make_vector.")
......
......@@ -14,7 +14,7 @@ from aesara.graph.utils import MethodNotDefined
from aesara.link.c.basic import failure_code
from aesara.misc.frozendict import frozendict
from aesara.misc.safe_asarray import _asarray
from aesara.printing import FunctionPrinter, pprint
from aesara.printing import FunctionPrinter, Printer, pprint
from aesara.scalar import get_scalar_type
from aesara.scalar.basic import Scalar
from aesara.scalar.basic import bool as scalar_bool
......@@ -300,7 +300,7 @@ class DimShuffle(ExternalCOp):
]
class DimShufflePrinter:
class DimShufflePrinter(Printer):
def __p(self, new_order, pstate, r):
if new_order != () and new_order[0] == "x":
return f"{self.__p(new_order[1:], pstate, r)}"
......
......@@ -16,7 +16,7 @@ from aesara.graph.params_type import ParamsType
from aesara.graph.type import Type
from aesara.graph.utils import MethodNotDefined
from aesara.misc.safe_asarray import _asarray
from aesara.printing import pprint
from aesara.printing import Printer, pprint
from aesara.scalar.basic import ScalarConstant
from aesara.tensor import _get_vector_length, get_vector_length
from aesara.tensor.basic import addbroadcast, alloc, get_scalar_constant_value
......@@ -1188,7 +1188,7 @@ class Subtensor(COp):
return self(eval_points[0], *inputs[1:], return_list=True)
class SubtensorPrinter:
class SubtensorPrinter(Printer):
def process(self, r, pstate):
if r.owner is None:
raise TypeError("Can only print Subtensor.")
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论