提交 b74cf3f6 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Simplify string representation of Elemwise and CAReduce

上级 5841c30e
...@@ -234,6 +234,7 @@ class MetaType(ABCMeta): ...@@ -234,6 +234,7 @@ class MetaType(ABCMeta):
dct["__eq__"] = __eq__ dct["__eq__"] = __eq__
# FIXME: This overrides __str__ inheritance when props are provided
if "__str__" not in dct: if "__str__" not in dct:
if len(props) == 0: if len(props) == 0:
......
...@@ -14,7 +14,7 @@ from pytensor.link.c.op import COp, ExternalCOp, OpenMPOp ...@@ -14,7 +14,7 @@ from pytensor.link.c.op import COp, ExternalCOp, OpenMPOp
from pytensor.link.c.params_type import ParamsType from pytensor.link.c.params_type import ParamsType
from pytensor.misc.frozendict import frozendict from pytensor.misc.frozendict import frozendict
from pytensor.misc.safe_asarray import _asarray from pytensor.misc.safe_asarray import _asarray
from pytensor.printing import FunctionPrinter, Printer, pprint from pytensor.printing import Printer, pprint
from pytensor.scalar import get_scalar_type from pytensor.scalar import get_scalar_type
from pytensor.scalar.basic import bool as scalar_bool from pytensor.scalar.basic import bool as scalar_bool
from pytensor.scalar.basic import identity as scalar_identity from pytensor.scalar.basic import identity as scalar_identity
...@@ -498,15 +498,9 @@ class Elemwise(OpenMPOp): ...@@ -498,15 +498,9 @@ class Elemwise(OpenMPOp):
return Apply(self, inputs, outputs) return Apply(self, inputs, outputs)
def __str__(self): def __str__(self):
if self.name is None: if self.name:
if self.inplace_pattern:
items = list(self.inplace_pattern.items())
items.sort()
return f"{type(self).__name__}{{{self.scalar_op}}}{items}"
else:
return f"{type(self).__name__}{{{self.scalar_op}}}"
else:
return self.name return self.name
return str(self.scalar_op).capitalize()
def R_op(self, inputs, eval_points): def R_op(self, inputs, eval_points):
outs = self(*inputs, return_list=True) outs = self(*inputs, return_list=True)
...@@ -1477,23 +1471,17 @@ class CAReduce(COp): ...@@ -1477,23 +1471,17 @@ class CAReduce(COp):
return res return res
def __str__(self): def _axis_str(self):
prefix = f"{type(self).__name__}{{{self.scalar_op}}}" axis = self.axis
extra_params = [] if axis is None:
return "axes=None"
if self.axis is not None: elif len(axis) == 1:
axis = ", ".join(str(x) for x in self.axis) return f"axis={axis[0]}"
extra_params.append(f"axis=[{axis}]")
if self.acc_dtype:
extra_params.append(f"acc_dtype={self.acc_dtype}")
extra_params_str = ", ".join(extra_params)
if extra_params_str:
return f"{prefix}{{{extra_params_str}}}"
else: else:
return f"{prefix}" return f"axes={list(axis)}"
def __str__(self):
return f"{type(self).__name__}{{{self.scalar_op}, {self._axis_str()}}}"
def perform(self, node, inp, out): def perform(self, node, inp, out):
(input,) = inp (input,) = inp
...@@ -1737,21 +1725,17 @@ def scalar_elemwise(*symbol, nfunc=None, nin=None, nout=None, symbolname=None): ...@@ -1737,21 +1725,17 @@ def scalar_elemwise(*symbol, nfunc=None, nin=None, nout=None, symbolname=None):
symbolname = symbolname or symbol.__name__ symbolname = symbolname or symbol.__name__
if symbolname.endswith("_inplace"): if symbolname.endswith("_inplace"):
elemwise_name = f"Elemwise{{{symbolname},inplace}}" base_symbol_name = symbolname[: -len("_inplace")]
scalar_op = getattr(scalar, symbolname[: -len("_inplace")]) scalar_op = getattr(scalar, base_symbol_name)
inplace_scalar_op = scalar_op.__class__(transfer_type(0)) inplace_scalar_op = scalar_op.__class__(transfer_type(0))
rval = Elemwise( rval = Elemwise(
inplace_scalar_op, inplace_scalar_op,
{0: 0}, {0: 0},
name=elemwise_name,
nfunc_spec=(nfunc and (nfunc, nin, nout)), nfunc_spec=(nfunc and (nfunc, nin, nout)),
) )
else: else:
elemwise_name = f"Elemwise{{{symbolname},no_inplace}}"
scalar_op = getattr(scalar, symbolname) scalar_op = getattr(scalar, symbolname)
rval = Elemwise( rval = Elemwise(scalar_op, nfunc_spec=(nfunc and (nfunc, nin, nout)))
scalar_op, name=elemwise_name, nfunc_spec=(nfunc and (nfunc, nin, nout))
)
if getattr(symbol, "__doc__"): if getattr(symbol, "__doc__"):
rval.__doc__ = symbol.__doc__ + "\n\n " + rval.__doc__ rval.__doc__ = symbol.__doc__ + "\n\n " + rval.__doc__
...@@ -1761,8 +1745,6 @@ def scalar_elemwise(*symbol, nfunc=None, nin=None, nout=None, symbolname=None): ...@@ -1761,8 +1745,6 @@ def scalar_elemwise(*symbol, nfunc=None, nin=None, nout=None, symbolname=None):
rval.__epydoc_asRoutine = symbol rval.__epydoc_asRoutine = symbol
rval.__module__ = symbol.__module__ rval.__module__ = symbol.__module__
pprint.assign(rval, FunctionPrinter([symbolname.replace("_inplace", "=")]))
return rval return rval
if symbol: if symbol:
......
...@@ -583,7 +583,12 @@ def max_and_argmax(a, axis=None, keepdims=False): ...@@ -583,7 +583,12 @@ def max_and_argmax(a, axis=None, keepdims=False):
return [out, argout] return [out, argout]
class NonZeroCAReduce(CAReduce): class FixedOpCAReduce(CAReduce):
def __str__(self):
return f"{type(self).__name__}{{{self._axis_str()}}}"
class NonZeroDimsCAReduce(FixedOpCAReduce):
def _c_all(self, node, name, inames, onames, sub): def _c_all(self, node, name, inames, onames, sub):
decl, checks, alloc, loop, end = super()._c_all(node, name, inames, onames, sub) decl, checks, alloc, loop, end = super()._c_all(node, name, inames, onames, sub)
...@@ -614,7 +619,7 @@ class NonZeroCAReduce(CAReduce): ...@@ -614,7 +619,7 @@ class NonZeroCAReduce(CAReduce):
return decl, checks, alloc, loop, end return decl, checks, alloc, loop, end
class Max(NonZeroCAReduce): class Max(NonZeroDimsCAReduce):
nfunc_spec = ("max", 1, 1) nfunc_spec = ("max", 1, 1)
def __init__(self, axis): def __init__(self, axis):
...@@ -625,7 +630,7 @@ class Max(NonZeroCAReduce): ...@@ -625,7 +630,7 @@ class Max(NonZeroCAReduce):
return type(self)(axis=axis) return type(self)(axis=axis)
class Min(NonZeroCAReduce): class Min(NonZeroDimsCAReduce):
nfunc_spec = ("min", 1, 1) nfunc_spec = ("min", 1, 1)
def __init__(self, axis): def __init__(self, axis):
...@@ -1496,7 +1501,7 @@ def complex_from_polar(abs, angle): ...@@ -1496,7 +1501,7 @@ def complex_from_polar(abs, angle):
"""Return complex-valued tensor from polar coordinate specification.""" """Return complex-valued tensor from polar coordinate specification."""
class Mean(CAReduce): class Mean(FixedOpCAReduce):
__props__ = ("axis",) __props__ = ("axis",)
nfunc_spec = ("mean", 1, 1) nfunc_spec = ("mean", 1, 1)
...@@ -2356,7 +2361,7 @@ def outer(x, y): ...@@ -2356,7 +2361,7 @@ def outer(x, y):
return dot(x.dimshuffle(0, "x"), y.dimshuffle("x", 0)) return dot(x.dimshuffle(0, "x"), y.dimshuffle("x", 0))
class All(CAReduce): class All(FixedOpCAReduce):
"""Applies `logical and` to all the values of a tensor along the """Applies `logical and` to all the values of a tensor along the
specified axis(es). specified axis(es).
...@@ -2370,12 +2375,6 @@ class All(CAReduce): ...@@ -2370,12 +2375,6 @@ class All(CAReduce):
def _output_dtype(self, idtype): def _output_dtype(self, idtype):
return "bool" return "bool"
def __str__(self):
if self.axis is None:
return "All"
else:
return "All{%s}" % ", ".join(map(str, self.axis))
def make_node(self, input): def make_node(self, input):
input = as_tensor_variable(input) input = as_tensor_variable(input)
if input.dtype != "bool": if input.dtype != "bool":
...@@ -2392,7 +2391,7 @@ class All(CAReduce): ...@@ -2392,7 +2391,7 @@ class All(CAReduce):
return type(self)(axis=axis) return type(self)(axis=axis)
class Any(CAReduce): class Any(FixedOpCAReduce):
"""Applies `bitwise or` to all the values of a tensor along the """Applies `bitwise or` to all the values of a tensor along the
specified axis(es). specified axis(es).
...@@ -2406,12 +2405,6 @@ class Any(CAReduce): ...@@ -2406,12 +2405,6 @@ class Any(CAReduce):
def _output_dtype(self, idtype): def _output_dtype(self, idtype):
return "bool" return "bool"
def __str__(self):
if self.axis is None:
return "Any"
else:
return "Any{%s}" % ", ".join(map(str, self.axis))
def make_node(self, input): def make_node(self, input):
input = as_tensor_variable(input) input = as_tensor_variable(input)
if input.dtype != "bool": if input.dtype != "bool":
...@@ -2428,7 +2421,7 @@ class Any(CAReduce): ...@@ -2428,7 +2421,7 @@ class Any(CAReduce):
return type(self)(axis=axis) return type(self)(axis=axis)
class Sum(CAReduce): class Sum(FixedOpCAReduce):
""" """
Sums all the values of a tensor along the specified axis(es). Sums all the values of a tensor along the specified axis(es).
...@@ -2449,14 +2442,6 @@ class Sum(CAReduce): ...@@ -2449,14 +2442,6 @@ class Sum(CAReduce):
upcast_discrete_output=True, upcast_discrete_output=True,
) )
def __str__(self):
name = self.__class__.__name__
axis = ""
if self.axis is not None:
axis = ", ".join(str(x) for x in self.axis)
axis = f"axis=[{axis}], "
return f"{name}{{{axis}acc_dtype={self.acc_dtype}}}"
def L_op(self, inp, out, grads): def L_op(self, inp, out, grads):
(x,) = inp (x,) = inp
...@@ -2526,7 +2511,7 @@ def sum(input, axis=None, dtype=None, keepdims=False, acc_dtype=None): ...@@ -2526,7 +2511,7 @@ def sum(input, axis=None, dtype=None, keepdims=False, acc_dtype=None):
pprint.assign(Sum, printing.FunctionPrinter(["sum"], ["axis"])) pprint.assign(Sum, printing.FunctionPrinter(["sum"], ["axis"]))
class Prod(CAReduce): class Prod(FixedOpCAReduce):
""" """
Multiplies all the values of a tensor along the specified axis(es). Multiplies all the values of a tensor along the specified axis(es).
...@@ -2537,7 +2522,6 @@ class Prod(CAReduce): ...@@ -2537,7 +2522,6 @@ class Prod(CAReduce):
""" """
__props__ = ("scalar_op", "axis", "dtype", "acc_dtype", "no_zeros_in_input") __props__ = ("scalar_op", "axis", "dtype", "acc_dtype", "no_zeros_in_input")
nfunc_spec = ("prod", 1, 1) nfunc_spec = ("prod", 1, 1)
def __init__(self, axis=None, dtype=None, acc_dtype=None, no_zeros_in_input=False): def __init__(self, axis=None, dtype=None, acc_dtype=None, no_zeros_in_input=False):
...@@ -2683,6 +2667,14 @@ class Prod(CAReduce): ...@@ -2683,6 +2667,14 @@ class Prod(CAReduce):
no_zeros_in_input=no_zeros_in_input, no_zeros_in_input=no_zeros_in_input,
) )
def __str__(self):
if self.no_zeros_in_input:
return f"{super().__str__()[:-1]}, no_zeros_in_input}})"
return super().__str__()
def __repr__(self):
return f"{super().__repr__()[:-1]}, no_zeros_in_input={self.no_zeros_in_input})"
def prod( def prod(
input, input,
...@@ -2751,7 +2743,7 @@ class MulWithoutZeros(BinaryScalarOp): ...@@ -2751,7 +2743,7 @@ class MulWithoutZeros(BinaryScalarOp):
mul_without_zeros = MulWithoutZeros(aes.upcast_out, name="mul_without_zeros") mul_without_zeros = MulWithoutZeros(aes.upcast_out, name="mul_without_zeros")
class ProdWithoutZeros(CAReduce): class ProdWithoutZeros(FixedOpCAReduce):
def __init__(self, axis=None, dtype=None, acc_dtype=None): def __init__(self, axis=None, dtype=None, acc_dtype=None):
super().__init__( super().__init__(
mul_without_zeros, mul_without_zeros,
......
...@@ -42,7 +42,8 @@ from pytensor.tensor.math import ( ...@@ -42,7 +42,8 @@ from pytensor.tensor.math import (
All, All,
Any, Any,
Dot, Dot,
NonZeroCAReduce, FixedOpCAReduce,
NonZeroDimsCAReduce,
Prod, Prod,
ProdWithoutZeros, ProdWithoutZeros,
Sum, Sum,
...@@ -1671,7 +1672,8 @@ ALL_REDUCE = ( ...@@ -1671,7 +1672,8 @@ ALL_REDUCE = (
ProdWithoutZeros, ProdWithoutZeros,
] ]
+ CAReduce.__subclasses__() + CAReduce.__subclasses__()
+ NonZeroCAReduce.__subclasses__() + FixedOpCAReduce.__subclasses__()
+ NonZeroDimsCAReduce.__subclasses__()
) )
......
...@@ -579,9 +579,9 @@ def test_debugprint(): ...@@ -579,9 +579,9 @@ def test_debugprint():
Inner graphs: Inner graphs:
OpFromGraph{inline=False} [id A] OpFromGraph{inline=False} [id A]
Elemwise{add,no_inplace} [id E] Add [id E]
├─ *0-<TensorType(float64, (?, ?))> [id F] ├─ *0-<TensorType(float64, (?, ?))> [id F]
└─ Elemwise{mul,no_inplace} [id G] └─ Mul [id G]
├─ *1-<TensorType(float64, (?, ?))> [id H] ├─ *1-<TensorType(float64, (?, ?))> [id H]
└─ *2-<TensorType(float64, (?, ?))> [id I] └─ *2-<TensorType(float64, (?, ?))> [id I]
""" """
......
...@@ -156,7 +156,7 @@ def test_fgraph_to_python_multiline_str(): ...@@ -156,7 +156,7 @@ def test_fgraph_to_python_multiline_str():
assert ( assert (
""" """
# Elemwise{add,no_inplace}(Test # Add(Test
# Op().0, Test # Op().0, Test
# Op().1) # Op().1)
""" """
......
...@@ -676,14 +676,9 @@ class TestCAReduce(unittest_tools.InferShapeTester): ...@@ -676,14 +676,9 @@ class TestCAReduce(unittest_tools.InferShapeTester):
def test_str(self): def test_str(self):
op = CAReduce(aes.add, axis=None) op = CAReduce(aes.add, axis=None)
assert str(op) == "CAReduce{add}" assert str(op) == "CAReduce{add, axes=None}"
op = CAReduce(aes.add, axis=(1,)) op = CAReduce(aes.add, axis=(1,))
assert str(op) == "CAReduce{add}{axis=[1]}" assert str(op) == "CAReduce{add, axis=1}"
op = CAReduce(aes.add, axis=None, acc_dtype="float64")
assert str(op) == "CAReduce{add}{acc_dtype=float64}"
op = CAReduce(aes.add, axis=(1,), acc_dtype="float64")
assert str(op) == "CAReduce{add}{axis=[1], acc_dtype=float64}"
def test_repeated_axis(self): def test_repeated_axis(self):
x = vector("x") x = vector("x")
...@@ -802,10 +797,8 @@ class TestElemwise(unittest_tools.InferShapeTester): ...@@ -802,10 +797,8 @@ class TestElemwise(unittest_tools.InferShapeTester):
self.check_input_dimensions_match(Mode(linker="c")) self.check_input_dimensions_match(Mode(linker="c"))
def test_str(self): def test_str(self):
op = Elemwise(aes.add, inplace_pattern=None, name=None)
assert str(op) == "Elemwise{add}"
op = Elemwise(aes.add, inplace_pattern={0: 0}, name=None) op = Elemwise(aes.add, inplace_pattern={0: 0}, name=None)
assert str(op) == "Elemwise{add}[(0, 0)]" assert str(op) == "Add"
op = Elemwise(aes.add, inplace_pattern=None, name="my_op") op = Elemwise(aes.add, inplace_pattern=None, name="my_op")
assert str(op) == "my_op" assert str(op) == "my_op"
......
...@@ -106,9 +106,9 @@ def test_min_informative_str(): ...@@ -106,9 +106,9 @@ def test_min_informative_str():
mis = min_informative_str(G).replace("\t", " ") mis = min_informative_str(G).replace("\t", " ")
reference = """A. Elemwise{add,no_inplace} reference = """A. Add
B. C B. C
C. Elemwise{add,no_inplace} C. Add
D. D D. D
E. E""" E. E"""
...@@ -144,11 +144,11 @@ def test_debugprint(): ...@@ -144,11 +144,11 @@ def test_debugprint():
s = s.getvalue() s = s.getvalue()
reference = dedent( reference = dedent(
r""" r"""
Elemwise{add,no_inplace} [id 0] Add [id 0]
├─ Elemwise{add,no_inplace} [id 1] 'C' ├─ Add [id 1] 'C'
│ ├─ A [id 2] │ ├─ A [id 2]
│ └─ B [id 3] │ └─ B [id 3]
└─ Elemwise{add,no_inplace} [id 4] └─ Add [id 4]
├─ D [id 5] ├─ D [id 5]
└─ E [id 6] └─ E [id 6]
""" """
...@@ -162,11 +162,11 @@ def test_debugprint(): ...@@ -162,11 +162,11 @@ def test_debugprint():
# The additional white space are needed! # The additional white space are needed!
reference = dedent( reference = dedent(
r""" r"""
Elemwise{add,no_inplace} [id A] Add [id A]
├─ Elemwise{add,no_inplace} [id B] 'C' ├─ Add [id B] 'C'
│ ├─ A [id C] │ ├─ A [id C]
│ └─ B [id D] │ └─ B [id D]
└─ Elemwise{add,no_inplace} [id E] └─ Add [id E]
├─ D [id F] ├─ D [id F]
└─ E [id G] └─ E [id G]
""" """
...@@ -180,10 +180,10 @@ def test_debugprint(): ...@@ -180,10 +180,10 @@ def test_debugprint():
# The additional white space are needed! # The additional white space are needed!
reference = dedent( reference = dedent(
r""" r"""
Elemwise{add,no_inplace} [id A] Add [id A]
├─ Elemwise{add,no_inplace} [id B] 'C' ├─ Add [id B] 'C'
│ └─ ··· │ └─ ···
└─ Elemwise{add,no_inplace} [id C] └─ Add [id C]
├─ D [id D] ├─ D [id D]
└─ E [id E] └─ E [id E]
""" """
...@@ -196,11 +196,11 @@ def test_debugprint(): ...@@ -196,11 +196,11 @@ def test_debugprint():
s = s.getvalue() s = s.getvalue()
reference = dedent( reference = dedent(
r""" r"""
Elemwise{add,no_inplace} Add
├─ Elemwise{add,no_inplace} 'C' ├─ Add 'C'
│ ├─ A │ ├─ A
│ └─ B │ └─ B
└─ Elemwise{add,no_inplace} └─ Add
├─ D ├─ D
└─ E └─ E
""" """
...@@ -213,7 +213,7 @@ def test_debugprint(): ...@@ -213,7 +213,7 @@ def test_debugprint():
s = s.getvalue() s = s.getvalue()
reference = dedent( reference = dedent(
r""" r"""
Elemwise{add,no_inplace} 0 [None] Add 0 [None]
├─ A [None] ├─ A [None]
├─ B [None] ├─ B [None]
├─ D [None] ├─ D [None]
...@@ -231,7 +231,7 @@ def test_debugprint(): ...@@ -231,7 +231,7 @@ def test_debugprint():
s = s.getvalue() s = s.getvalue()
reference = dedent( reference = dedent(
r""" r"""
Elemwise{add,no_inplace} 0 [None] Add 0 [None]
├─ A [None] ├─ A [None]
├─ B [None] ├─ B [None]
├─ D [None] ├─ D [None]
...@@ -249,7 +249,7 @@ def test_debugprint(): ...@@ -249,7 +249,7 @@ def test_debugprint():
s = s.getvalue() s = s.getvalue()
reference = dedent( reference = dedent(
r""" r"""
Elemwise{add,no_inplace} 0 [None] Add 0 [None]
├─ A [None] ├─ A [None]
├─ B [None] ├─ B [None]
├─ D [None] ├─ D [None]
...@@ -274,7 +274,7 @@ def test_debugprint(): ...@@ -274,7 +274,7 @@ def test_debugprint():
s = s.getvalue() s = s.getvalue()
exp_res = dedent( exp_res = dedent(
r""" r"""
Elemwise{Composite{(i2 + (i0 - i1))}} 4 Composite{(i2 + (i0 - i1))} 4
├─ ExpandDims{axis=0} v={0: [0]} 3 ├─ ExpandDims{axis=0} v={0: [0]} 3
│ └─ CGemv{inplace} d={0: [0]} 2 │ └─ CGemv{inplace} d={0: [0]} 2
│ ├─ AllocEmpty{dtype='float64'} 1 │ ├─ AllocEmpty{dtype='float64'} 1
...@@ -289,7 +289,7 @@ def test_debugprint(): ...@@ -289,7 +289,7 @@ def test_debugprint():
Inner graphs: Inner graphs:
Elemwise{Composite{(i2 + (i0 - i1))}} Composite{(i2 + (i0 - i1))}
← add 'o0' ← add 'o0'
├─ i2 ├─ i2
└─ sub └─ sub
...@@ -314,7 +314,7 @@ def test_debugprint_id_type(): ...@@ -314,7 +314,7 @@ def test_debugprint_id_type():
debugprint(e_at, id_type="auto", file=s) debugprint(e_at, id_type="auto", file=s)
s = s.getvalue() s = s.getvalue()
exp_res = f"""Elemwise{{add,no_inplace}} [id {e_at.auto_name}] exp_res = f"""Add [id {e_at.auto_name}]
├─ dot [id {d_at.auto_name}] ├─ dot [id {d_at.auto_name}]
│ ├─ <TensorType(float64, (?, ?))> [id {b_at.auto_name}] │ ├─ <TensorType(float64, (?, ?))> [id {b_at.auto_name}]
│ └─ <TensorType(float64, (?,))> [id {a_at.auto_name}] │ └─ <TensorType(float64, (?,))> [id {a_at.auto_name}]
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论