提交 7b58b711 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Add a contextmanager for setting the printer precedence

上级 8d3a512a
......@@ -9,6 +9,7 @@ import logging
import os
import sys
from abc import ABC, abstractmethod
from contextlib import contextmanager
from copy import copy
from functools import reduce
from io import IOBase, StringIO
......@@ -665,6 +666,17 @@ class Printer(ABC):
"""Construct a string representation for a `Variable`."""
@contextmanager
def set_precedence(pstate: PrinterState, precedence: int = -1000):
"""Temporarily set the precedence of a `PrinterState`."""
old_precedence = getattr(pstate, "precedence", None)
pstate.precedence = precedence
try:
yield
finally:
pstate.precedence = old_precedence
class OperatorPrinter(Printer):
def __init__(self, operator, precedence, assoc="left"):
self.operator = operator
......@@ -699,12 +711,10 @@ class OperatorPrinter(Printer):
new_precedence = self.precedence
if self.assoc == "left" and i != 0 or self.assoc == "right" and i != max_i:
new_precedence += 1e-6
try:
old_precedence = getattr(pstate, "precedence", None)
pstate.precedence = new_precedence
with set_precedence(pstate, new_precedence):
s = pprinter.process(input, pstate)
finally:
pstate.precedence = old_precedence
input_strings.append(s)
if len(input_strings) == 1:
s = self.operator + input_strings[0]
......@@ -742,13 +752,8 @@ class PatternPrinter(Printer):
precedences += (1000,) * len(node.inputs)
def pp_process(input, new_precedence):
try:
old_precedence = getattr(pstate, "precedence", None)
pstate.precedence = new_precedence
with set_precedence(pstate, new_precedence):
r = pprinter.process(input, pstate)
finally:
pstate.precedence = old_precedence
return r
d = {
......@@ -792,10 +797,7 @@ class FunctionPrinter(Printer):
)
idx = node.outputs.index(output)
name = self.names[idx]
new_precedence = -1000
try:
old_precedence = getattr(pstate, "precedence", None)
pstate.precedence = new_precedence
with set_precedence(pstate):
inputs_str = ", ".join(
[pprinter.process(input, pstate) for input in node.inputs]
)
......@@ -807,8 +809,6 @@ class FunctionPrinter(Printer):
keywords_str = f", {keywords_str}"
r = f"{name}({inputs_str}{keywords_str})"
finally:
pstate.precedence = old_precedence
pstate.memo[output] = r
return r
......@@ -866,16 +866,11 @@ class DefaultPrinter(Printer):
node = output.owner
if node is None:
return leaf_printer.process(output, pstate)
new_precedence = -1000
try:
old_precedence = getattr(pstate, "precedence", None)
pstate.precedence = new_precedence
with set_precedence(pstate):
r = "{}({})".format(
str(node.op),
", ".join([pprinter.process(input, pstate) for input in node.inputs]),
)
finally:
pstate.precedence = old_precedence
pstate.memo[output] = r
return r
......
......@@ -38,7 +38,7 @@ from aesara.graph.utils import (
TestValueError,
get_variable_trace_string,
)
from aesara.printing import Printer, pprint
from aesara.printing import Printer, pprint, set_precedence
from aesara.raise_op import Assert, CheckAndRaise, assert_op
from aesara.tensor.basic import (
Alloc,
......@@ -723,12 +723,8 @@ class MakeVectorPrinter(Printer):
if r.owner is None:
raise TypeError("Can only print make_vector.")
elif isinstance(r.owner.op, MakeVector):
old_precedence = getattr(pstate, "precedence", None)
try:
pstate.precedence = 1000
with set_precedence(pstate):
s = [pstate.pprinter.process(inp) for inp in r.owner.inputs]
finally:
pstate.precedence = old_precedence
return f"[{', '.join(s)}]"
else:
raise TypeError("Can only print make_vector.")
......
......@@ -16,7 +16,7 @@ from aesara.graph.params_type import ParamsType
from aesara.graph.type import Type
from aesara.graph.utils import MethodNotDefined
from aesara.misc.safe_asarray import _asarray
from aesara.printing import Printer, pprint
from aesara.printing import Printer, pprint, set_precedence
from aesara.scalar.basic import ScalarConstant
from aesara.tensor import _get_vector_length, get_vector_length
from aesara.tensor.basic import addbroadcast, alloc, get_scalar_constant_value
......@@ -1196,14 +1196,11 @@ class SubtensorPrinter(Printer):
inputs = list(op_inputs)
input = inputs.pop(0)
sidxs = []
old_precedence = getattr(pstate, "precedence", None)
getattr(pstate, "precedence", None)
for entry in idxs:
if isinstance(entry, aes.Scalar):
pstate.precedence = -1000
try:
with set_precedence(pstate):
sidxs.append(pstate.pprinter.process(inputs.pop()))
finally:
pstate.precedence = old_precedence
elif isinstance(entry, slice):
if entry.start is None or entry.start == 0:
msg1 = ""
......@@ -1222,11 +1219,9 @@ class SubtensorPrinter(Printer):
sidxs.append(f"{msg1}:{msg2}{msg3}")
try:
pstate.precedence = 1000
with set_precedence(pstate, 1000):
sub = pstate.pprinter.process(input, pstate)
finally:
pstate.precedence = old_precedence
return f"{sub}[{', '.join(sidxs)}]"
......@@ -1839,12 +1834,8 @@ class IncSubtensorPrinter(SubtensorPrinter):
res = self._process(r.owner.op.idx_list, [x] + idx_args, pstate)
old_precedence = pstate.precedence
try:
pstate.precedence = 1000
with set_precedence(pstate, 1000):
y_str = pstate.pprinter.process(r.owner.inputs[1], pstate)
finally:
pstate.precedence = old_precedence
if r.owner.op.set_instead_of_inc:
res = f"set_subtensor({res}, {y_str})"
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论