Commit c8922630, authored by Brandon T. Willard, committed by Brandon T. Willard

Use type name in Elemwise and CAReduce __str__ implementations

Parent 9023e2b3
@@ -474,9 +474,9 @@ second dimension
             if self.inplace_pattern:
                 items = list(self.inplace_pattern.items())
                 items.sort()
-                return f"Elemwise{{{self.scalar_op}}}{items}"
+                return f"{type(self).__name__}{{{self.scalar_op}}}{items}"
             else:
-                return "Elemwise{%s}" % (self.scalar_op)
+                return f"{type(self).__name__}{{{self.scalar_op}}}"
         else:
             return self.name
@@ -1340,13 +1340,12 @@ class CAReduce(COp):
         self.set_ufunc(self.scalar_op)

     def __str__(self):
+        prefix = f"{type(self).__name__}{{{self.scalar_op}}}"
         if self.axis is not None:
-            return "Reduce{{{}}}{{{}}}".format(
-                self.scalar_op,
-                ", ".join(str(x) for x in self.axis),
-            )
+            axes_str = ", ".join(str(x) for x in self.axis)
+            return f"{prefix}{{{axes_str}}}"
         else:
-            return "Reduce{%s}" % self.scalar_op
+            return f"{prefix}"

     def perform(self, node, inp, out):
         (input,) = inp
@@ -1750,14 +1749,12 @@ class CAReduceDtype(CAReduce):
         return super(CAReduceDtype, op).make_node(input)

     def __str__(self):
-        name = self.__class__.__name__
-        if self.__class__.__name__ == "CAReduceDtype":
-            name = ("ReduceDtype{%s}" % self.scalar_op,)
-        axis = ""
+        prefix = f"{type(self).__name__}{{{self.scalar_op}}}"
         if self.axis is not None:
             axis = ", ".join(str(x) for x in self.axis)
-            axis = f"axis=[{axis}], "
-        return f"{name}{{{axis}acc_dtype={self.acc_dtype}}}"
+            return f"{prefix}{{axis=[{axis}], acc_dtype={self.acc_dtype}}}"
+        else:
+            return f"{prefix}{{acc_dtype={self.acc_dtype}}}"

 def scalar_elemwise(*symbol, nfunc=None, nin=None, nout=None, symbolname=None):
@@ -1578,7 +1578,7 @@ ALL_REDUCE = (
 @local_optimizer(ALL_REDUCE)
 def local_reduce_join(fgraph, node):
     """
-    Reduce{scalar.op}(Join(axis=0, a, b), axis=0) -> Elemwise{scalar.op}(a, b)
+    CAReduce{scalar.op}(Join(axis=0, a, b), axis=0) -> Elemwise{scalar.op}(a, b)

     Notes
     -----
@@ -16,7 +16,7 @@ from aesara.link.basic import PerformLinker
 from aesara.link.c.basic import CLinker, OpWiseCLinker
 from aesara.tensor import as_tensor_variable
 from aesara.tensor.basic import second
-from aesara.tensor.elemwise import CAReduce, DimShuffle, Elemwise
+from aesara.tensor.elemwise import CAReduce, CAReduceDtype, DimShuffle, Elemwise
 from aesara.tensor.math import all as at_all
 from aesara.tensor.math import any as at_any
 from aesara.tensor.type import (
@@ -622,6 +622,17 @@ class TestCAReduce(unittest_tools.InferShapeTester):
             warn=0 not in xsh,
         )

+    def test_str(self):
+        op = CAReduce(aes.add, axis=None)
+        assert str(op) == "CAReduce{add}"
+        op = CAReduce(aes.add, axis=(1,))
+        assert str(op) == "CAReduce{add}{1}"
+
+        op = CAReduceDtype(aes.add, axis=None, acc_dtype="float64")
+        assert str(op) == "CAReduceDtype{add}{acc_dtype=float64}"
+        op = CAReduceDtype(aes.add, axis=(1,), acc_dtype="float64")
+        assert str(op) == "CAReduceDtype{add}{axis=[1], acc_dtype=float64}"
+

 class TestBitOpReduceGrad:
     def setup_method(self):
@@ -722,6 +733,14 @@ class TestElemwise(unittest_tools.InferShapeTester):
     def test_input_dimensions_match_c(self):
         self.check_input_dimensions_match(Mode(linker="c"))

+    def test_str(self):
+        op = Elemwise(aes.add, inplace_pattern=None, name=None)
+        assert str(op) == "Elemwise{add}"
+        op = Elemwise(aes.add, inplace_pattern={0: 0}, name=None)
+        assert str(op) == "Elemwise{add}[(0, 0)]"
+        op = Elemwise(aes.add, inplace_pattern=None, name="my_op")
+        assert str(op) == "my_op"
+

 def test_not_implemented_elemwise_grad():
     # Regression test for unimplemented gradient in an Elemwise Op.
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论