提交 61b1bbb9 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Add keywords to FunctionPrinter

上级 aec01eed
...@@ -764,9 +764,22 @@ class PatternPrinter(Printer): ...@@ -764,9 +764,22 @@ class PatternPrinter(Printer):
class FunctionPrinter(Printer): class FunctionPrinter(Printer):
def __init__(self, *names): def __init__(self, names: List[str], keywords: Optional[List[str]] = None):
"""
Parameters
----------
names
The function names used for each output.
keywords
The `Op` keywords to include in the output.
"""
self.names = names self.names = names
if keywords is None:
keywords = []
self.keywords = keywords
def process(self, output, pstate): def process(self, output, pstate):
if output in pstate.memo: if output in pstate.memo:
return pstate.memo[output] return pstate.memo[output]
...@@ -783,10 +796,17 @@ class FunctionPrinter(Printer): ...@@ -783,10 +796,17 @@ class FunctionPrinter(Printer):
try: try:
old_precedence = getattr(pstate, "precedence", None) old_precedence = getattr(pstate, "precedence", None)
pstate.precedence = new_precedence pstate.precedence = new_precedence
r = "{}({})".format( inputs_str = ", ".join(
name, [pprinter.process(input, pstate) for input in node.inputs]
", ".join([pprinter.process(input, pstate) for input in node.inputs]),
) )
keywords_str = ", ".join(
[f"{kw}={getattr(node.op, kw)}" for kw in self.keywords]
)
if keywords_str and inputs_str:
keywords_str = f", {keywords_str}"
r = f"{name}({inputs_str}{keywords_str})"
finally: finally:
pstate.precedence = old_precedence pstate.precedence = old_precedence
......
...@@ -815,7 +815,7 @@ register_rebroadcast_c_code( ...@@ -815,7 +815,7 @@ register_rebroadcast_c_code(
def _conversion(real_value, name): def _conversion(real_value, name):
__oplist_tag(real_value, "casting") __oplist_tag(real_value, "casting")
real_value.__module__ = "tensor.basic" real_value.__module__ = "tensor.basic"
pprint.assign(real_value, printing.FunctionPrinter(name)) pprint.assign(real_value, printing.FunctionPrinter([name]))
return real_value return real_value
...@@ -927,7 +927,7 @@ def second(a, b): ...@@ -927,7 +927,7 @@ def second(a, b):
fill = second fill = second
pprint.assign(fill, printing.FunctionPrinter("fill")) pprint.assign(fill, printing.FunctionPrinter(["fill"]))
def ones_like(model, dtype=None, opt=False): def ones_like(model, dtype=None, opt=False):
...@@ -1547,7 +1547,7 @@ class Alloc(COp): ...@@ -1547,7 +1547,7 @@ class Alloc(COp):
alloc = Alloc() alloc = Alloc()
pprint.assign(alloc, printing.FunctionPrinter("alloc")) pprint.assign(alloc, printing.FunctionPrinter(["alloc"]))
def full(shape, fill_value, dtype=None): def full(shape, fill_value, dtype=None):
...@@ -2514,7 +2514,7 @@ class Join(COp): ...@@ -2514,7 +2514,7 @@ class Join(COp):
join_ = Join() join_ = Join()
pprint.assign(Join, printing.FunctionPrinter("join")) pprint.assign(Join, printing.FunctionPrinter(["join"]))
@_get_vector_length.register(Join) @_get_vector_length.register(Join)
......
...@@ -1098,8 +1098,8 @@ gemm_inplace = Gemm(inplace=True) ...@@ -1098,8 +1098,8 @@ gemm_inplace = Gemm(inplace=True)
gemm_no_inplace = Gemm(inplace=False) gemm_no_inplace = Gemm(inplace=False)
# For the user interface. Aesara optimization will make them inplace # For the user interface. Aesara optimization will make them inplace
gemm = gemm_no_inplace gemm = gemm_no_inplace
pprint.assign(gemm_inplace, FunctionPrinter("gemm_inplace")) pprint.assign(gemm_inplace, FunctionPrinter(["gemm_inplace"]))
pprint.assign(gemm_no_inplace, FunctionPrinter("gemm_no_inplace")) pprint.assign(gemm_no_inplace, FunctionPrinter(["gemm_no_inplace"]))
def res_is_a(fgraph, var, op, maxclients=None): def res_is_a(fgraph, var, op, maxclients=None):
......
...@@ -1833,7 +1833,7 @@ def scalar_elemwise(*symbol, nfunc=None, nin=None, nout=None, symbolname=None): ...@@ -1833,7 +1833,7 @@ def scalar_elemwise(*symbol, nfunc=None, nin=None, nout=None, symbolname=None):
rval.__epydoc_asRoutine = symbol rval.__epydoc_asRoutine = symbol
rval.__module__ = symbol.__module__ rval.__module__ = symbol.__module__
pprint.assign(rval, FunctionPrinter(symbolname.replace("_inplace", "="))) pprint.assign(rval, FunctionPrinter([symbolname.replace("_inplace", "=")]))
return rval return rval
......
...@@ -334,7 +334,7 @@ def second_inplace(a): ...@@ -334,7 +334,7 @@ def second_inplace(a):
fill_inplace = second_inplace fill_inplace = second_inplace
pprint.assign(fill_inplace, printing.FunctionPrinter("fill=")) pprint.assign(fill_inplace, printing.FunctionPrinter(["fill="]))
@scalar_elemwise(symbolname="scalar_maximum_inplace") @scalar_elemwise(symbolname="scalar_maximum_inplace")
......
...@@ -2495,7 +2495,7 @@ def sum(input, axis=None, dtype=None, keepdims=False, acc_dtype=None): ...@@ -2495,7 +2495,7 @@ def sum(input, axis=None, dtype=None, keepdims=False, acc_dtype=None):
return out return out
pprint.assign(Sum(), printing.FunctionPrinter("sum")) pprint.assign(Sum, printing.FunctionPrinter(["sum"], ["axis"]))
class Prod(CAReduceDtype): class Prod(CAReduceDtype):
......
...@@ -95,7 +95,7 @@ ultra_fast_sigmoid_inplace = Elemwise( ...@@ -95,7 +95,7 @@ ultra_fast_sigmoid_inplace = Elemwise(
name="ultra_fast_sigmoid_inplace", name="ultra_fast_sigmoid_inplace",
) )
pprint.assign(ultra_fast_sigmoid, printing.FunctionPrinter("ultra_fast_sigmoid")) pprint.assign(ultra_fast_sigmoid, printing.FunctionPrinter(["ultra_fast_sigmoid"]))
# @opt.register_uncanonicalize # @opt.register_uncanonicalize
......
...@@ -22,6 +22,7 @@ from aesara.graph.basic import Variable, applys_between ...@@ -22,6 +22,7 @@ from aesara.graph.basic import Variable, applys_between
from aesara.graph.fg import FunctionGraph from aesara.graph.fg import FunctionGraph
from aesara.link.c.basic import DualLinker from aesara.link.c.basic import DualLinker
from aesara.misc.safe_asarray import _asarray from aesara.misc.safe_asarray import _asarray
from aesara.printing import pprint
from aesara.tensor import blas, blas_c from aesara.tensor import blas, blas_c
from aesara.tensor.basic import ( from aesara.tensor.basic import (
as_tensor_variable, as_tensor_variable,
...@@ -3336,3 +3337,9 @@ def test_logsumexp(shape, axis, keepdims): ...@@ -3336,3 +3337,9 @@ def test_logsumexp(shape, axis, keepdims):
aesara_out, aesara_out,
scipy_out, scipy_out,
) )
def test_pprint():
x = vector("x")
y = aet_sum(x, axis=0)
assert pprint(y) == "sum(x, axis=(0,))"
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论