提交 aec01eed authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Clarify printer interface with Printer ABC

上级 2ca41998
...@@ -8,6 +8,7 @@ import hashlib ...@@ -8,6 +8,7 @@ import hashlib
import logging import logging
import os import os
import sys import sys
from abc import ABC, abstractmethod
from copy import copy from copy import copy
from functools import reduce from functools import reduce
from io import IOBase, StringIO from io import IOBase, StringIO
...@@ -658,7 +659,13 @@ class PrinterState(Scratchpad): ...@@ -658,7 +659,13 @@ class PrinterState(Scratchpad):
self.memo = {} self.memo = {}
class OperatorPrinter: class Printer(ABC):
@abstractmethod
def process(self, var: Variable, pstate: PrinterState) -> str:
"""Construct a string representation for a `Variable`."""
class OperatorPrinter(Printer):
def __init__(self, operator, precedence, assoc="left"): def __init__(self, operator, precedence, assoc="left"):
self.operator = operator self.operator = operator
self.precedence = precedence self.precedence = precedence
...@@ -711,7 +718,7 @@ class OperatorPrinter: ...@@ -711,7 +718,7 @@ class OperatorPrinter:
return r return r
class PatternPrinter: class PatternPrinter(Printer):
def __init__(self, *patterns): def __init__(self, *patterns):
self.patterns = [] self.patterns = []
for pattern in patterns: for pattern in patterns:
...@@ -756,7 +763,7 @@ class PatternPrinter: ...@@ -756,7 +763,7 @@ class PatternPrinter:
return r return r
class FunctionPrinter: class FunctionPrinter(Printer):
def __init__(self, *names): def __init__(self, *names):
self.names = names self.names = names
...@@ -787,7 +794,7 @@ class FunctionPrinter: ...@@ -787,7 +794,7 @@ class FunctionPrinter:
return r return r
class IgnorePrinter: class IgnorePrinter(Printer):
def process(self, output, pstate): def process(self, output, pstate):
if output in pstate.memo: if output in pstate.memo:
return pstate.memo[output] return pstate.memo[output]
...@@ -804,7 +811,7 @@ class IgnorePrinter: ...@@ -804,7 +811,7 @@ class IgnorePrinter:
return r return r
class LeafPrinter: class LeafPrinter(Printer):
def process(self, output, pstate): def process(self, output, pstate):
if output in pstate.memo: if output in pstate.memo:
return pstate.memo[output] return pstate.memo[output]
...@@ -819,7 +826,7 @@ class LeafPrinter: ...@@ -819,7 +826,7 @@ class LeafPrinter:
leaf_printer = LeafPrinter() leaf_printer = LeafPrinter()
class DefaultPrinter: class DefaultPrinter(Printer):
def process(self, output, pstate): def process(self, output, pstate):
if output in pstate.memo: if output in pstate.memo:
return pstate.memo[output] return pstate.memo[output]
...@@ -845,7 +852,7 @@ class DefaultPrinter: ...@@ -845,7 +852,7 @@ class DefaultPrinter:
default_printer = DefaultPrinter() default_printer = DefaultPrinter()
class PPrinter: class PPrinter(Printer):
def __init__(self): def __init__(self):
self.printers = [] self.printers = []
self.printers_dict = {} self.printers_dict = {}
......
...@@ -38,7 +38,7 @@ from aesara.graph.utils import ( ...@@ -38,7 +38,7 @@ from aesara.graph.utils import (
TestValueError, TestValueError,
get_variable_trace_string, get_variable_trace_string,
) )
from aesara.printing import pprint from aesara.printing import Printer, pprint
from aesara.raise_op import Assert, CheckAndRaise, assert_op from aesara.raise_op import Assert, CheckAndRaise, assert_op
from aesara.tensor.basic import ( from aesara.tensor.basic import (
Alloc, Alloc,
...@@ -718,7 +718,7 @@ def local_scalar_tensor_scalar(fgraph, node): ...@@ -718,7 +718,7 @@ def local_scalar_tensor_scalar(fgraph, node):
return [s] return [s]
class MakeVectorPrinter: class MakeVectorPrinter(Printer):
def process(self, r, pstate): def process(self, r, pstate):
if r.owner is None: if r.owner is None:
raise TypeError("Can only print make_vector.") raise TypeError("Can only print make_vector.")
......
...@@ -14,7 +14,7 @@ from aesara.graph.utils import MethodNotDefined ...@@ -14,7 +14,7 @@ from aesara.graph.utils import MethodNotDefined
from aesara.link.c.basic import failure_code from aesara.link.c.basic import failure_code
from aesara.misc.frozendict import frozendict from aesara.misc.frozendict import frozendict
from aesara.misc.safe_asarray import _asarray from aesara.misc.safe_asarray import _asarray
from aesara.printing import FunctionPrinter, pprint from aesara.printing import FunctionPrinter, Printer, pprint
from aesara.scalar import get_scalar_type from aesara.scalar import get_scalar_type
from aesara.scalar.basic import Scalar from aesara.scalar.basic import Scalar
from aesara.scalar.basic import bool as scalar_bool from aesara.scalar.basic import bool as scalar_bool
...@@ -300,7 +300,7 @@ class DimShuffle(ExternalCOp): ...@@ -300,7 +300,7 @@ class DimShuffle(ExternalCOp):
] ]
class DimShufflePrinter: class DimShufflePrinter(Printer):
def __p(self, new_order, pstate, r): def __p(self, new_order, pstate, r):
if new_order != () and new_order[0] == "x": if new_order != () and new_order[0] == "x":
return f"{self.__p(new_order[1:], pstate, r)}" return f"{self.__p(new_order[1:], pstate, r)}"
......
...@@ -16,7 +16,7 @@ from aesara.graph.params_type import ParamsType ...@@ -16,7 +16,7 @@ from aesara.graph.params_type import ParamsType
from aesara.graph.type import Type from aesara.graph.type import Type
from aesara.graph.utils import MethodNotDefined from aesara.graph.utils import MethodNotDefined
from aesara.misc.safe_asarray import _asarray from aesara.misc.safe_asarray import _asarray
from aesara.printing import pprint from aesara.printing import Printer, pprint
from aesara.scalar.basic import ScalarConstant from aesara.scalar.basic import ScalarConstant
from aesara.tensor import _get_vector_length, get_vector_length from aesara.tensor import _get_vector_length, get_vector_length
from aesara.tensor.basic import addbroadcast, alloc, get_scalar_constant_value from aesara.tensor.basic import addbroadcast, alloc, get_scalar_constant_value
...@@ -1188,7 +1188,7 @@ class Subtensor(COp): ...@@ -1188,7 +1188,7 @@ class Subtensor(COp):
return self(eval_points[0], *inputs[1:], return_list=True) return self(eval_points[0], *inputs[1:], return_list=True)
class SubtensorPrinter: class SubtensorPrinter(Printer):
def process(self, r, pstate): def process(self, r, pstate):
if r.owner is None: if r.owner is None:
raise TypeError("Can only print Subtensor.") raise TypeError("Can only print Subtensor.")
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论