提交 b74cf3f6 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Simplify string representation of Elemwise and CAReduce

上级 5841c30e
......@@ -234,6 +234,7 @@ class MetaType(ABCMeta):
dct["__eq__"] = __eq__
# FIXME: This overrides __str__ inheritance when props are provided
if "__str__" not in dct:
if len(props) == 0:
......
......@@ -14,7 +14,7 @@ from pytensor.link.c.op import COp, ExternalCOp, OpenMPOp
from pytensor.link.c.params_type import ParamsType
from pytensor.misc.frozendict import frozendict
from pytensor.misc.safe_asarray import _asarray
from pytensor.printing import FunctionPrinter, Printer, pprint
from pytensor.printing import Printer, pprint
from pytensor.scalar import get_scalar_type
from pytensor.scalar.basic import bool as scalar_bool
from pytensor.scalar.basic import identity as scalar_identity
......@@ -498,15 +498,9 @@ class Elemwise(OpenMPOp):
return Apply(self, inputs, outputs)
def __str__(self):
if self.name is None:
if self.inplace_pattern:
items = list(self.inplace_pattern.items())
items.sort()
return f"{type(self).__name__}{{{self.scalar_op}}}{items}"
else:
return f"{type(self).__name__}{{{self.scalar_op}}}"
else:
if self.name:
return self.name
return str(self.scalar_op).capitalize()
def R_op(self, inputs, eval_points):
outs = self(*inputs, return_list=True)
......@@ -1477,23 +1471,17 @@ class CAReduce(COp):
return res
def __str__(self):
prefix = f"{type(self).__name__}{{{self.scalar_op}}}"
extra_params = []
if self.axis is not None:
axis = ", ".join(str(x) for x in self.axis)
extra_params.append(f"axis=[{axis}]")
if self.acc_dtype:
extra_params.append(f"acc_dtype={self.acc_dtype}")
extra_params_str = ", ".join(extra_params)
if extra_params_str:
return f"{prefix}{{{extra_params_str}}}"
def _axis_str(self):
axis = self.axis
if axis is None:
return "axes=None"
elif len(axis) == 1:
return f"axis={axis[0]}"
else:
return f"{prefix}"
return f"axes={list(axis)}"
def __str__(self):
return f"{type(self).__name__}{{{self.scalar_op}, {self._axis_str()}}}"
def perform(self, node, inp, out):
(input,) = inp
......@@ -1737,21 +1725,17 @@ def scalar_elemwise(*symbol, nfunc=None, nin=None, nout=None, symbolname=None):
symbolname = symbolname or symbol.__name__
if symbolname.endswith("_inplace"):
elemwise_name = f"Elemwise{{{symbolname},inplace}}"
scalar_op = getattr(scalar, symbolname[: -len("_inplace")])
base_symbol_name = symbolname[: -len("_inplace")]
scalar_op = getattr(scalar, base_symbol_name)
inplace_scalar_op = scalar_op.__class__(transfer_type(0))
rval = Elemwise(
inplace_scalar_op,
{0: 0},
name=elemwise_name,
nfunc_spec=(nfunc and (nfunc, nin, nout)),
)
else:
elemwise_name = f"Elemwise{{{symbolname},no_inplace}}"
scalar_op = getattr(scalar, symbolname)
rval = Elemwise(
scalar_op, name=elemwise_name, nfunc_spec=(nfunc and (nfunc, nin, nout))
)
rval = Elemwise(scalar_op, nfunc_spec=(nfunc and (nfunc, nin, nout)))
if getattr(symbol, "__doc__"):
rval.__doc__ = symbol.__doc__ + "\n\n " + rval.__doc__
......@@ -1761,8 +1745,6 @@ def scalar_elemwise(*symbol, nfunc=None, nin=None, nout=None, symbolname=None):
rval.__epydoc_asRoutine = symbol
rval.__module__ = symbol.__module__
pprint.assign(rval, FunctionPrinter([symbolname.replace("_inplace", "=")]))
return rval
if symbol:
......
......@@ -583,7 +583,12 @@ def max_and_argmax(a, axis=None, keepdims=False):
return [out, argout]
class NonZeroCAReduce(CAReduce):
class FixedOpCAReduce(CAReduce):
def __str__(self):
return f"{type(self).__name__}{{{self._axis_str()}}}"
class NonZeroDimsCAReduce(FixedOpCAReduce):
def _c_all(self, node, name, inames, onames, sub):
decl, checks, alloc, loop, end = super()._c_all(node, name, inames, onames, sub)
......@@ -614,7 +619,7 @@ class NonZeroCAReduce(CAReduce):
return decl, checks, alloc, loop, end
class Max(NonZeroCAReduce):
class Max(NonZeroDimsCAReduce):
nfunc_spec = ("max", 1, 1)
def __init__(self, axis):
......@@ -625,7 +630,7 @@ class Max(NonZeroCAReduce):
return type(self)(axis=axis)
class Min(NonZeroCAReduce):
class Min(NonZeroDimsCAReduce):
nfunc_spec = ("min", 1, 1)
def __init__(self, axis):
......@@ -1496,7 +1501,7 @@ def complex_from_polar(abs, angle):
"""Return complex-valued tensor from polar coordinate specification."""
class Mean(CAReduce):
class Mean(FixedOpCAReduce):
__props__ = ("axis",)
nfunc_spec = ("mean", 1, 1)
......@@ -2356,7 +2361,7 @@ def outer(x, y):
return dot(x.dimshuffle(0, "x"), y.dimshuffle("x", 0))
class All(CAReduce):
class All(FixedOpCAReduce):
"""Applies `logical and` to all the values of a tensor along the
specified axis(es).
......@@ -2370,12 +2375,6 @@ class All(CAReduce):
def _output_dtype(self, idtype):
return "bool"
def __str__(self):
if self.axis is None:
return "All"
else:
return "All{%s}" % ", ".join(map(str, self.axis))
def make_node(self, input):
input = as_tensor_variable(input)
if input.dtype != "bool":
......@@ -2392,7 +2391,7 @@ class All(CAReduce):
return type(self)(axis=axis)
class Any(CAReduce):
class Any(FixedOpCAReduce):
"""Applies `bitwise or` to all the values of a tensor along the
specified axis(es).
......@@ -2406,12 +2405,6 @@ class Any(CAReduce):
def _output_dtype(self, idtype):
return "bool"
def __str__(self):
if self.axis is None:
return "Any"
else:
return "Any{%s}" % ", ".join(map(str, self.axis))
def make_node(self, input):
input = as_tensor_variable(input)
if input.dtype != "bool":
......@@ -2428,7 +2421,7 @@ class Any(CAReduce):
return type(self)(axis=axis)
class Sum(CAReduce):
class Sum(FixedOpCAReduce):
"""
Sums all the values of a tensor along the specified axis(es).
......@@ -2449,14 +2442,6 @@ class Sum(CAReduce):
upcast_discrete_output=True,
)
def __str__(self):
name = self.__class__.__name__
axis = ""
if self.axis is not None:
axis = ", ".join(str(x) for x in self.axis)
axis = f"axis=[{axis}], "
return f"{name}{{{axis}acc_dtype={self.acc_dtype}}}"
def L_op(self, inp, out, grads):
(x,) = inp
......@@ -2526,7 +2511,7 @@ def sum(input, axis=None, dtype=None, keepdims=False, acc_dtype=None):
pprint.assign(Sum, printing.FunctionPrinter(["sum"], ["axis"]))
class Prod(CAReduce):
class Prod(FixedOpCAReduce):
"""
Multiplies all the values of a tensor along the specified axis(es).
......@@ -2537,7 +2522,6 @@ class Prod(CAReduce):
"""
__props__ = ("scalar_op", "axis", "dtype", "acc_dtype", "no_zeros_in_input")
nfunc_spec = ("prod", 1, 1)
def __init__(self, axis=None, dtype=None, acc_dtype=None, no_zeros_in_input=False):
......@@ -2683,6 +2667,14 @@ class Prod(CAReduce):
no_zeros_in_input=no_zeros_in_input,
)
def __str__(self):
if self.no_zeros_in_input:
return f"{super().__str__()[:-1]}, no_zeros_in_input}})"
return super().__str__()
def __repr__(self):
return f"{super().__repr__()[:-1]}, no_zeros_in_input={self.no_zeros_in_input})"
def prod(
input,
......@@ -2751,7 +2743,7 @@ class MulWithoutZeros(BinaryScalarOp):
mul_without_zeros = MulWithoutZeros(aes.upcast_out, name="mul_without_zeros")
class ProdWithoutZeros(CAReduce):
class ProdWithoutZeros(FixedOpCAReduce):
def __init__(self, axis=None, dtype=None, acc_dtype=None):
super().__init__(
mul_without_zeros,
......
......@@ -42,7 +42,8 @@ from pytensor.tensor.math import (
All,
Any,
Dot,
NonZeroCAReduce,
FixedOpCAReduce,
NonZeroDimsCAReduce,
Prod,
ProdWithoutZeros,
Sum,
......@@ -1671,7 +1672,8 @@ ALL_REDUCE = (
ProdWithoutZeros,
]
+ CAReduce.__subclasses__()
+ NonZeroCAReduce.__subclasses__()
+ FixedOpCAReduce.__subclasses__()
+ NonZeroDimsCAReduce.__subclasses__()
)
......
......@@ -579,9 +579,9 @@ def test_debugprint():
Inner graphs:
OpFromGraph{inline=False} [id A]
Elemwise{add,no_inplace} [id E]
Add [id E]
├─ *0-<TensorType(float64, (?, ?))> [id F]
└─ Elemwise{mul,no_inplace} [id G]
└─ Mul [id G]
├─ *1-<TensorType(float64, (?, ?))> [id H]
└─ *2-<TensorType(float64, (?, ?))> [id I]
"""
......
......@@ -156,7 +156,7 @@ def test_fgraph_to_python_multiline_str():
assert (
"""
# Elemwise{add,no_inplace}(Test
# Add(Test
# Op().0, Test
# Op().1)
"""
......
......@@ -676,14 +676,9 @@ class TestCAReduce(unittest_tools.InferShapeTester):
def test_str(self):
op = CAReduce(aes.add, axis=None)
assert str(op) == "CAReduce{add}"
assert str(op) == "CAReduce{add, axes=None}"
op = CAReduce(aes.add, axis=(1,))
assert str(op) == "CAReduce{add}{axis=[1]}"
op = CAReduce(aes.add, axis=None, acc_dtype="float64")
assert str(op) == "CAReduce{add}{acc_dtype=float64}"
op = CAReduce(aes.add, axis=(1,), acc_dtype="float64")
assert str(op) == "CAReduce{add}{axis=[1], acc_dtype=float64}"
assert str(op) == "CAReduce{add, axis=1}"
def test_repeated_axis(self):
x = vector("x")
......@@ -802,10 +797,8 @@ class TestElemwise(unittest_tools.InferShapeTester):
self.check_input_dimensions_match(Mode(linker="c"))
def test_str(self):
op = Elemwise(aes.add, inplace_pattern=None, name=None)
assert str(op) == "Elemwise{add}"
op = Elemwise(aes.add, inplace_pattern={0: 0}, name=None)
assert str(op) == "Elemwise{add}[(0, 0)]"
assert str(op) == "Add"
op = Elemwise(aes.add, inplace_pattern=None, name="my_op")
assert str(op) == "my_op"
......
......@@ -106,9 +106,9 @@ def test_min_informative_str():
mis = min_informative_str(G).replace("\t", " ")
reference = """A. Elemwise{add,no_inplace}
reference = """A. Add
B. C
C. Elemwise{add,no_inplace}
C. Add
D. D
E. E"""
......@@ -144,11 +144,11 @@ def test_debugprint():
s = s.getvalue()
reference = dedent(
r"""
Elemwise{add,no_inplace} [id 0]
├─ Elemwise{add,no_inplace} [id 1] 'C'
Add [id 0]
├─ Add [id 1] 'C'
│ ├─ A [id 2]
│ └─ B [id 3]
└─ Elemwise{add,no_inplace} [id 4]
└─ Add [id 4]
├─ D [id 5]
└─ E [id 6]
"""
......@@ -162,11 +162,11 @@ def test_debugprint():
# The additional white space are needed!
reference = dedent(
r"""
Elemwise{add,no_inplace} [id A]
├─ Elemwise{add,no_inplace} [id B] 'C'
Add [id A]
├─ Add [id B] 'C'
│ ├─ A [id C]
│ └─ B [id D]
└─ Elemwise{add,no_inplace} [id E]
└─ Add [id E]
├─ D [id F]
└─ E [id G]
"""
......@@ -180,10 +180,10 @@ def test_debugprint():
# The additional white space are needed!
reference = dedent(
r"""
Elemwise{add,no_inplace} [id A]
├─ Elemwise{add,no_inplace} [id B] 'C'
Add [id A]
├─ Add [id B] 'C'
│ └─ ···
└─ Elemwise{add,no_inplace} [id C]
└─ Add [id C]
├─ D [id D]
└─ E [id E]
"""
......@@ -196,11 +196,11 @@ def test_debugprint():
s = s.getvalue()
reference = dedent(
r"""
Elemwise{add,no_inplace}
├─ Elemwise{add,no_inplace} 'C'
Add
├─ Add 'C'
│ ├─ A
│ └─ B
└─ Elemwise{add,no_inplace}
└─ Add
├─ D
└─ E
"""
......@@ -213,7 +213,7 @@ def test_debugprint():
s = s.getvalue()
reference = dedent(
r"""
Elemwise{add,no_inplace} 0 [None]
Add 0 [None]
├─ A [None]
├─ B [None]
├─ D [None]
......@@ -231,7 +231,7 @@ def test_debugprint():
s = s.getvalue()
reference = dedent(
r"""
Elemwise{add,no_inplace} 0 [None]
Add 0 [None]
├─ A [None]
├─ B [None]
├─ D [None]
......@@ -249,7 +249,7 @@ def test_debugprint():
s = s.getvalue()
reference = dedent(
r"""
Elemwise{add,no_inplace} 0 [None]
Add 0 [None]
├─ A [None]
├─ B [None]
├─ D [None]
......@@ -274,7 +274,7 @@ def test_debugprint():
s = s.getvalue()
exp_res = dedent(
r"""
Elemwise{Composite{(i2 + (i0 - i1))}} 4
Composite{(i2 + (i0 - i1))} 4
├─ ExpandDims{axis=0} v={0: [0]} 3
│ └─ CGemv{inplace} d={0: [0]} 2
│ ├─ AllocEmpty{dtype='float64'} 1
......@@ -289,7 +289,7 @@ def test_debugprint():
Inner graphs:
Elemwise{Composite{(i2 + (i0 - i1))}}
Composite{(i2 + (i0 - i1))}
← add 'o0'
├─ i2
└─ sub
......@@ -314,7 +314,7 @@ def test_debugprint_id_type():
debugprint(e_at, id_type="auto", file=s)
s = s.getvalue()
exp_res = f"""Elemwise{{add,no_inplace}} [id {e_at.auto_name}]
exp_res = f"""Add [id {e_at.auto_name}]
├─ dot [id {d_at.auto_name}]
│ ├─ <TensorType(float64, (?, ?))> [id {b_at.auto_name}]
│ └─ <TensorType(float64, (?,))> [id {a_at.auto_name}]
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论