提交 7b58b711 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Add a contextmanager for setting the printer precedence

上级 8d3a512a
...@@ -9,6 +9,7 @@ import logging ...@@ -9,6 +9,7 @@ import logging
import os import os
import sys import sys
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from contextlib import contextmanager
from copy import copy from copy import copy
from functools import reduce from functools import reduce
from io import IOBase, StringIO from io import IOBase, StringIO
...@@ -665,6 +666,17 @@ class Printer(ABC): ...@@ -665,6 +666,17 @@ class Printer(ABC):
"""Construct a string representation for a `Variable`.""" """Construct a string representation for a `Variable`."""
@contextmanager
def set_precedence(pstate: PrinterState, precedence: int = -1000):
"""Temporarily set the precedence of a `PrinterState`."""
old_precedence = getattr(pstate, "precedence", None)
pstate.precedence = precedence
try:
yield
finally:
pstate.precedence = old_precedence
class OperatorPrinter(Printer): class OperatorPrinter(Printer):
def __init__(self, operator, precedence, assoc="left"): def __init__(self, operator, precedence, assoc="left"):
self.operator = operator self.operator = operator
...@@ -699,12 +711,10 @@ class OperatorPrinter(Printer): ...@@ -699,12 +711,10 @@ class OperatorPrinter(Printer):
new_precedence = self.precedence new_precedence = self.precedence
if self.assoc == "left" and i != 0 or self.assoc == "right" and i != max_i: if self.assoc == "left" and i != 0 or self.assoc == "right" and i != max_i:
new_precedence += 1e-6 new_precedence += 1e-6
try:
old_precedence = getattr(pstate, "precedence", None) with set_precedence(pstate, new_precedence):
pstate.precedence = new_precedence
s = pprinter.process(input, pstate) s = pprinter.process(input, pstate)
finally:
pstate.precedence = old_precedence
input_strings.append(s) input_strings.append(s)
if len(input_strings) == 1: if len(input_strings) == 1:
s = self.operator + input_strings[0] s = self.operator + input_strings[0]
...@@ -742,13 +752,8 @@ class PatternPrinter(Printer): ...@@ -742,13 +752,8 @@ class PatternPrinter(Printer):
precedences += (1000,) * len(node.inputs) precedences += (1000,) * len(node.inputs)
def pp_process(input, new_precedence): def pp_process(input, new_precedence):
try: with set_precedence(pstate, new_precedence):
old_precedence = getattr(pstate, "precedence", None)
pstate.precedence = new_precedence
r = pprinter.process(input, pstate) r = pprinter.process(input, pstate)
finally:
pstate.precedence = old_precedence
return r return r
d = { d = {
...@@ -792,10 +797,7 @@ class FunctionPrinter(Printer): ...@@ -792,10 +797,7 @@ class FunctionPrinter(Printer):
) )
idx = node.outputs.index(output) idx = node.outputs.index(output)
name = self.names[idx] name = self.names[idx]
new_precedence = -1000 with set_precedence(pstate):
try:
old_precedence = getattr(pstate, "precedence", None)
pstate.precedence = new_precedence
inputs_str = ", ".join( inputs_str = ", ".join(
[pprinter.process(input, pstate) for input in node.inputs] [pprinter.process(input, pstate) for input in node.inputs]
) )
...@@ -807,8 +809,6 @@ class FunctionPrinter(Printer): ...@@ -807,8 +809,6 @@ class FunctionPrinter(Printer):
keywords_str = f", {keywords_str}" keywords_str = f", {keywords_str}"
r = f"{name}({inputs_str}{keywords_str})" r = f"{name}({inputs_str}{keywords_str})"
finally:
pstate.precedence = old_precedence
pstate.memo[output] = r pstate.memo[output] = r
return r return r
...@@ -866,16 +866,11 @@ class DefaultPrinter(Printer): ...@@ -866,16 +866,11 @@ class DefaultPrinter(Printer):
node = output.owner node = output.owner
if node is None: if node is None:
return leaf_printer.process(output, pstate) return leaf_printer.process(output, pstate)
new_precedence = -1000 with set_precedence(pstate):
try:
old_precedence = getattr(pstate, "precedence", None)
pstate.precedence = new_precedence
r = "{}({})".format( r = "{}({})".format(
str(node.op), str(node.op),
", ".join([pprinter.process(input, pstate) for input in node.inputs]), ", ".join([pprinter.process(input, pstate) for input in node.inputs]),
) )
finally:
pstate.precedence = old_precedence
pstate.memo[output] = r pstate.memo[output] = r
return r return r
......
...@@ -38,7 +38,7 @@ from aesara.graph.utils import ( ...@@ -38,7 +38,7 @@ from aesara.graph.utils import (
TestValueError, TestValueError,
get_variable_trace_string, get_variable_trace_string,
) )
from aesara.printing import Printer, pprint from aesara.printing import Printer, pprint, set_precedence
from aesara.raise_op import Assert, CheckAndRaise, assert_op from aesara.raise_op import Assert, CheckAndRaise, assert_op
from aesara.tensor.basic import ( from aesara.tensor.basic import (
Alloc, Alloc,
...@@ -723,12 +723,8 @@ class MakeVectorPrinter(Printer): ...@@ -723,12 +723,8 @@ class MakeVectorPrinter(Printer):
if r.owner is None: if r.owner is None:
raise TypeError("Can only print make_vector.") raise TypeError("Can only print make_vector.")
elif isinstance(r.owner.op, MakeVector): elif isinstance(r.owner.op, MakeVector):
old_precedence = getattr(pstate, "precedence", None) with set_precedence(pstate):
try:
pstate.precedence = 1000
s = [pstate.pprinter.process(inp) for inp in r.owner.inputs] s = [pstate.pprinter.process(inp) for inp in r.owner.inputs]
finally:
pstate.precedence = old_precedence
return f"[{', '.join(s)}]" return f"[{', '.join(s)}]"
else: else:
raise TypeError("Can only print make_vector.") raise TypeError("Can only print make_vector.")
......
...@@ -16,7 +16,7 @@ from aesara.graph.params_type import ParamsType ...@@ -16,7 +16,7 @@ from aesara.graph.params_type import ParamsType
from aesara.graph.type import Type from aesara.graph.type import Type
from aesara.graph.utils import MethodNotDefined from aesara.graph.utils import MethodNotDefined
from aesara.misc.safe_asarray import _asarray from aesara.misc.safe_asarray import _asarray
from aesara.printing import Printer, pprint from aesara.printing import Printer, pprint, set_precedence
from aesara.scalar.basic import ScalarConstant from aesara.scalar.basic import ScalarConstant
from aesara.tensor import _get_vector_length, get_vector_length from aesara.tensor import _get_vector_length, get_vector_length
from aesara.tensor.basic import addbroadcast, alloc, get_scalar_constant_value from aesara.tensor.basic import addbroadcast, alloc, get_scalar_constant_value
...@@ -1196,14 +1196,11 @@ class SubtensorPrinter(Printer): ...@@ -1196,14 +1196,11 @@ class SubtensorPrinter(Printer):
inputs = list(op_inputs) inputs = list(op_inputs)
input = inputs.pop(0) input = inputs.pop(0)
sidxs = [] sidxs = []
old_precedence = getattr(pstate, "precedence", None) getattr(pstate, "precedence", None)
for entry in idxs: for entry in idxs:
if isinstance(entry, aes.Scalar): if isinstance(entry, aes.Scalar):
pstate.precedence = -1000 with set_precedence(pstate):
try:
sidxs.append(pstate.pprinter.process(inputs.pop())) sidxs.append(pstate.pprinter.process(inputs.pop()))
finally:
pstate.precedence = old_precedence
elif isinstance(entry, slice): elif isinstance(entry, slice):
if entry.start is None or entry.start == 0: if entry.start is None or entry.start == 0:
msg1 = "" msg1 = ""
...@@ -1222,11 +1219,9 @@ class SubtensorPrinter(Printer): ...@@ -1222,11 +1219,9 @@ class SubtensorPrinter(Printer):
sidxs.append(f"{msg1}:{msg2}{msg3}") sidxs.append(f"{msg1}:{msg2}{msg3}")
try: with set_precedence(pstate, 1000):
pstate.precedence = 1000
sub = pstate.pprinter.process(input, pstate) sub = pstate.pprinter.process(input, pstate)
finally:
pstate.precedence = old_precedence
return f"{sub}[{', '.join(sidxs)}]" return f"{sub}[{', '.join(sidxs)}]"
...@@ -1839,12 +1834,8 @@ class IncSubtensorPrinter(SubtensorPrinter): ...@@ -1839,12 +1834,8 @@ class IncSubtensorPrinter(SubtensorPrinter):
res = self._process(r.owner.op.idx_list, [x] + idx_args, pstate) res = self._process(r.owner.op.idx_list, [x] + idx_args, pstate)
old_precedence = pstate.precedence with set_precedence(pstate, 1000):
try:
pstate.precedence = 1000
y_str = pstate.pprinter.process(r.owner.inputs[1], pstate) y_str = pstate.pprinter.process(r.owner.inputs[1], pstate)
finally:
pstate.precedence = old_precedence
if r.owner.op.set_instead_of_inc: if r.owner.op.set_instead_of_inc:
res = f"set_subtensor({res}, {y_str})" res = f"set_subtensor({res}, {y_str})"
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论