提交 b74cf3f6 authored 作者: Ricardo Vieira's avatar Ricardo Vieira 提交者: Ricardo Vieira

Simplify string representation of Elemwise and CAReduce

上级 5841c30e
...@@ -234,6 +234,7 @@ class MetaType(ABCMeta): ...@@ -234,6 +234,7 @@ class MetaType(ABCMeta):
dct["__eq__"] = __eq__ dct["__eq__"] = __eq__
# FIXME: This overrides __str__ inheritance when props are provided
if "__str__" not in dct: if "__str__" not in dct:
if len(props) == 0: if len(props) == 0:
......
...@@ -14,7 +14,7 @@ from pytensor.link.c.op import COp, ExternalCOp, OpenMPOp ...@@ -14,7 +14,7 @@ from pytensor.link.c.op import COp, ExternalCOp, OpenMPOp
from pytensor.link.c.params_type import ParamsType from pytensor.link.c.params_type import ParamsType
from pytensor.misc.frozendict import frozendict from pytensor.misc.frozendict import frozendict
from pytensor.misc.safe_asarray import _asarray from pytensor.misc.safe_asarray import _asarray
from pytensor.printing import FunctionPrinter, Printer, pprint from pytensor.printing import Printer, pprint
from pytensor.scalar import get_scalar_type from pytensor.scalar import get_scalar_type
from pytensor.scalar.basic import bool as scalar_bool from pytensor.scalar.basic import bool as scalar_bool
from pytensor.scalar.basic import identity as scalar_identity from pytensor.scalar.basic import identity as scalar_identity
...@@ -498,15 +498,9 @@ class Elemwise(OpenMPOp): ...@@ -498,15 +498,9 @@ class Elemwise(OpenMPOp):
return Apply(self, inputs, outputs) return Apply(self, inputs, outputs)
def __str__(self): def __str__(self):
if self.name is None: if self.name:
if self.inplace_pattern:
items = list(self.inplace_pattern.items())
items.sort()
return f"{type(self).__name__}{{{self.scalar_op}}}{items}"
else:
return f"{type(self).__name__}{{{self.scalar_op}}}"
else:
return self.name return self.name
return str(self.scalar_op).capitalize()
def R_op(self, inputs, eval_points): def R_op(self, inputs, eval_points):
outs = self(*inputs, return_list=True) outs = self(*inputs, return_list=True)
...@@ -1477,23 +1471,17 @@ class CAReduce(COp): ...@@ -1477,23 +1471,17 @@ class CAReduce(COp):
return res return res
def __str__(self): def _axis_str(self):
prefix = f"{type(self).__name__}{{{self.scalar_op}}}" axis = self.axis
extra_params = [] if axis is None:
return "axes=None"
if self.axis is not None: elif len(axis) == 1:
axis = ", ".join(str(x) for x in self.axis) return f"axis={axis[0]}"
extra_params.append(f"axis=[{axis}]")
if self.acc_dtype:
extra_params.append(f"acc_dtype={self.acc_dtype}")
extra_params_str = ", ".join(extra_params)
if extra_params_str:
return f"{prefix}{{{extra_params_str}}}"
else: else:
return f"{prefix}" return f"axes={list(axis)}"
def __str__(self):
return f"{type(self).__name__}{{{self.scalar_op}, {self._axis_str()}}}"
def perform(self, node, inp, out): def perform(self, node, inp, out):
(input,) = inp (input,) = inp
...@@ -1737,21 +1725,17 @@ def scalar_elemwise(*symbol, nfunc=None, nin=None, nout=None, symbolname=None): ...@@ -1737,21 +1725,17 @@ def scalar_elemwise(*symbol, nfunc=None, nin=None, nout=None, symbolname=None):
symbolname = symbolname or symbol.__name__ symbolname = symbolname or symbol.__name__
if symbolname.endswith("_inplace"): if symbolname.endswith("_inplace"):
elemwise_name = f"Elemwise{{{symbolname},inplace}}" base_symbol_name = symbolname[: -len("_inplace")]
scalar_op = getattr(scalar, symbolname[: -len("_inplace")]) scalar_op = getattr(scalar, base_symbol_name)
inplace_scalar_op = scalar_op.__class__(transfer_type(0)) inplace_scalar_op = scalar_op.__class__(transfer_type(0))
rval = Elemwise( rval = Elemwise(
inplace_scalar_op, inplace_scalar_op,
{0: 0}, {0: 0},
name=elemwise_name,
nfunc_spec=(nfunc and (nfunc, nin, nout)), nfunc_spec=(nfunc and (nfunc, nin, nout)),
) )
else: else:
elemwise_name = f"Elemwise{{{symbolname},no_inplace}}"
scalar_op = getattr(scalar, symbolname) scalar_op = getattr(scalar, symbolname)
rval = Elemwise( rval = Elemwise(scalar_op, nfunc_spec=(nfunc and (nfunc, nin, nout)))
scalar_op, name=elemwise_name, nfunc_spec=(nfunc and (nfunc, nin, nout))
)
if getattr(symbol, "__doc__"): if getattr(symbol, "__doc__"):
rval.__doc__ = symbol.__doc__ + "\n\n " + rval.__doc__ rval.__doc__ = symbol.__doc__ + "\n\n " + rval.__doc__
...@@ -1761,8 +1745,6 @@ def scalar_elemwise(*symbol, nfunc=None, nin=None, nout=None, symbolname=None): ...@@ -1761,8 +1745,6 @@ def scalar_elemwise(*symbol, nfunc=None, nin=None, nout=None, symbolname=None):
rval.__epydoc_asRoutine = symbol rval.__epydoc_asRoutine = symbol
rval.__module__ = symbol.__module__ rval.__module__ = symbol.__module__
pprint.assign(rval, FunctionPrinter([symbolname.replace("_inplace", "=")]))
return rval return rval
if symbol: if symbol:
......
...@@ -583,7 +583,12 @@ def max_and_argmax(a, axis=None, keepdims=False): ...@@ -583,7 +583,12 @@ def max_and_argmax(a, axis=None, keepdims=False):
return [out, argout] return [out, argout]
class NonZeroCAReduce(CAReduce): class FixedOpCAReduce(CAReduce):
def __str__(self):
return f"{type(self).__name__}{{{self._axis_str()}}}"
class NonZeroDimsCAReduce(FixedOpCAReduce):
def _c_all(self, node, name, inames, onames, sub): def _c_all(self, node, name, inames, onames, sub):
decl, checks, alloc, loop, end = super()._c_all(node, name, inames, onames, sub) decl, checks, alloc, loop, end = super()._c_all(node, name, inames, onames, sub)
...@@ -614,7 +619,7 @@ class NonZeroCAReduce(CAReduce): ...@@ -614,7 +619,7 @@ class NonZeroCAReduce(CAReduce):
return decl, checks, alloc, loop, end return decl, checks, alloc, loop, end
class Max(NonZeroCAReduce): class Max(NonZeroDimsCAReduce):
nfunc_spec = ("max", 1, 1) nfunc_spec = ("max", 1, 1)
def __init__(self, axis): def __init__(self, axis):
...@@ -625,7 +630,7 @@ class Max(NonZeroCAReduce): ...@@ -625,7 +630,7 @@ class Max(NonZeroCAReduce):
return type(self)(axis=axis) return type(self)(axis=axis)
class Min(NonZeroCAReduce): class Min(NonZeroDimsCAReduce):
nfunc_spec = ("min", 1, 1) nfunc_spec = ("min", 1, 1)
def __init__(self, axis): def __init__(self, axis):
...@@ -1496,7 +1501,7 @@ def complex_from_polar(abs, angle): ...@@ -1496,7 +1501,7 @@ def complex_from_polar(abs, angle):
"""Return complex-valued tensor from polar coordinate specification.""" """Return complex-valued tensor from polar coordinate specification."""
class Mean(CAReduce): class Mean(FixedOpCAReduce):
__props__ = ("axis",) __props__ = ("axis",)
nfunc_spec = ("mean", 1, 1) nfunc_spec = ("mean", 1, 1)
...@@ -2356,7 +2361,7 @@ def outer(x, y): ...@@ -2356,7 +2361,7 @@ def outer(x, y):
return dot(x.dimshuffle(0, "x"), y.dimshuffle("x", 0)) return dot(x.dimshuffle(0, "x"), y.dimshuffle("x", 0))
class All(CAReduce): class All(FixedOpCAReduce):
"""Applies `logical and` to all the values of a tensor along the """Applies `logical and` to all the values of a tensor along the
specified axis(es). specified axis(es).
...@@ -2370,12 +2375,6 @@ class All(CAReduce): ...@@ -2370,12 +2375,6 @@ class All(CAReduce):
def _output_dtype(self, idtype): def _output_dtype(self, idtype):
return "bool" return "bool"
def __str__(self):
if self.axis is None:
return "All"
else:
return "All{%s}" % ", ".join(map(str, self.axis))
def make_node(self, input): def make_node(self, input):
input = as_tensor_variable(input) input = as_tensor_variable(input)
if input.dtype != "bool": if input.dtype != "bool":
...@@ -2392,7 +2391,7 @@ class All(CAReduce): ...@@ -2392,7 +2391,7 @@ class All(CAReduce):
return type(self)(axis=axis) return type(self)(axis=axis)
class Any(CAReduce): class Any(FixedOpCAReduce):
"""Applies `bitwise or` to all the values of a tensor along the """Applies `bitwise or` to all the values of a tensor along the
specified axis(es). specified axis(es).
...@@ -2406,12 +2405,6 @@ class Any(CAReduce): ...@@ -2406,12 +2405,6 @@ class Any(CAReduce):
def _output_dtype(self, idtype): def _output_dtype(self, idtype):
return "bool" return "bool"
def __str__(self):
if self.axis is None:
return "Any"
else:
return "Any{%s}" % ", ".join(map(str, self.axis))
def make_node(self, input): def make_node(self, input):
input = as_tensor_variable(input) input = as_tensor_variable(input)
if input.dtype != "bool": if input.dtype != "bool":
...@@ -2428,7 +2421,7 @@ class Any(CAReduce): ...@@ -2428,7 +2421,7 @@ class Any(CAReduce):
return type(self)(axis=axis) return type(self)(axis=axis)
class Sum(CAReduce): class Sum(FixedOpCAReduce):
""" """
Sums all the values of a tensor along the specified axis(es). Sums all the values of a tensor along the specified axis(es).
...@@ -2449,14 +2442,6 @@ class Sum(CAReduce): ...@@ -2449,14 +2442,6 @@ class Sum(CAReduce):
upcast_discrete_output=True, upcast_discrete_output=True,
) )
def __str__(self):
name = self.__class__.__name__
axis = ""
if self.axis is not None:
axis = ", ".join(str(x) for x in self.axis)
axis = f"axis=[{axis}], "
return f"{name}{{{axis}acc_dtype={self.acc_dtype}}}"
def L_op(self, inp, out, grads): def L_op(self, inp, out, grads):
(x,) = inp (x,) = inp
...@@ -2526,7 +2511,7 @@ def sum(input, axis=None, dtype=None, keepdims=False, acc_dtype=None): ...@@ -2526,7 +2511,7 @@ def sum(input, axis=None, dtype=None, keepdims=False, acc_dtype=None):
pprint.assign(Sum, printing.FunctionPrinter(["sum"], ["axis"])) pprint.assign(Sum, printing.FunctionPrinter(["sum"], ["axis"]))
class Prod(CAReduce): class Prod(FixedOpCAReduce):
""" """
Multiplies all the values of a tensor along the specified axis(es). Multiplies all the values of a tensor along the specified axis(es).
...@@ -2537,7 +2522,6 @@ class Prod(CAReduce): ...@@ -2537,7 +2522,6 @@ class Prod(CAReduce):
""" """
__props__ = ("scalar_op", "axis", "dtype", "acc_dtype", "no_zeros_in_input") __props__ = ("scalar_op", "axis", "dtype", "acc_dtype", "no_zeros_in_input")
nfunc_spec = ("prod", 1, 1) nfunc_spec = ("prod", 1, 1)
def __init__(self, axis=None, dtype=None, acc_dtype=None, no_zeros_in_input=False): def __init__(self, axis=None, dtype=None, acc_dtype=None, no_zeros_in_input=False):
...@@ -2683,6 +2667,14 @@ class Prod(CAReduce): ...@@ -2683,6 +2667,14 @@ class Prod(CAReduce):
no_zeros_in_input=no_zeros_in_input, no_zeros_in_input=no_zeros_in_input,
) )
def __str__(self):
if self.no_zeros_in_input:
return f"{super().__str__()[:-1]}, no_zeros_in_input}})"
return super().__str__()
def __repr__(self):
return f"{super().__repr__()[:-1]}, no_zeros_in_input={self.no_zeros_in_input})"
def prod( def prod(
input, input,
...@@ -2751,7 +2743,7 @@ class MulWithoutZeros(BinaryScalarOp): ...@@ -2751,7 +2743,7 @@ class MulWithoutZeros(BinaryScalarOp):
mul_without_zeros = MulWithoutZeros(aes.upcast_out, name="mul_without_zeros") mul_without_zeros = MulWithoutZeros(aes.upcast_out, name="mul_without_zeros")
class ProdWithoutZeros(CAReduce): class ProdWithoutZeros(FixedOpCAReduce):
def __init__(self, axis=None, dtype=None, acc_dtype=None): def __init__(self, axis=None, dtype=None, acc_dtype=None):
super().__init__( super().__init__(
mul_without_zeros, mul_without_zeros,
......
...@@ -42,7 +42,8 @@ from pytensor.tensor.math import ( ...@@ -42,7 +42,8 @@ from pytensor.tensor.math import (
All, All,
Any, Any,
Dot, Dot,
NonZeroCAReduce, FixedOpCAReduce,
NonZeroDimsCAReduce,
Prod, Prod,
ProdWithoutZeros, ProdWithoutZeros,
Sum, Sum,
...@@ -1671,7 +1672,8 @@ ALL_REDUCE = ( ...@@ -1671,7 +1672,8 @@ ALL_REDUCE = (
ProdWithoutZeros, ProdWithoutZeros,
] ]
+ CAReduce.__subclasses__() + CAReduce.__subclasses__()
+ NonZeroCAReduce.__subclasses__() + FixedOpCAReduce.__subclasses__()
+ NonZeroDimsCAReduce.__subclasses__()
) )
......
...@@ -579,9 +579,9 @@ def test_debugprint(): ...@@ -579,9 +579,9 @@ def test_debugprint():
Inner graphs: Inner graphs:
OpFromGraph{inline=False} [id A] OpFromGraph{inline=False} [id A]
Elemwise{add,no_inplace} [id E] Add [id E]
├─ *0-<TensorType(float64, (?, ?))> [id F] ├─ *0-<TensorType(float64, (?, ?))> [id F]
└─ Elemwise{mul,no_inplace} [id G] └─ Mul [id G]
├─ *1-<TensorType(float64, (?, ?))> [id H] ├─ *1-<TensorType(float64, (?, ?))> [id H]
└─ *2-<TensorType(float64, (?, ?))> [id I] └─ *2-<TensorType(float64, (?, ?))> [id I]
""" """
......
...@@ -156,7 +156,7 @@ def test_fgraph_to_python_multiline_str(): ...@@ -156,7 +156,7 @@ def test_fgraph_to_python_multiline_str():
assert ( assert (
""" """
# Elemwise{add,no_inplace}(Test # Add(Test
# Op().0, Test # Op().0, Test
# Op().1) # Op().1)
""" """
......
...@@ -32,13 +32,13 @@ def test_debugprint_sitsot(): ...@@ -32,13 +32,13 @@ def test_debugprint_sitsot():
│ │ ├─ k [id D] (n_steps) │ │ ├─ k [id D] (n_steps)
│ │ ├─ IncSubtensor{Set;:int64:} [id E] (outer_in_sit_sot-0) │ │ ├─ IncSubtensor{Set;:int64:} [id E] (outer_in_sit_sot-0)
│ │ │ ├─ AllocEmpty{dtype='float64'} [id F] │ │ │ ├─ AllocEmpty{dtype='float64'} [id F]
│ │ │ │ ├─ Elemwise{add,no_inplace} [id G] │ │ │ │ ├─ Add [id G]
│ │ │ │ │ ├─ k [id D] │ │ │ │ │ ├─ k [id D]
│ │ │ │ │ └─ Subtensor{int64} [id H] │ │ │ │ │ └─ Subtensor{int64} [id H]
│ │ │ │ │ ├─ Shape [id I] │ │ │ │ │ ├─ Shape [id I]
│ │ │ │ │ │ └─ Unbroadcast{0} [id J] │ │ │ │ │ │ └─ Unbroadcast{0} [id J]
│ │ │ │ │ │ └─ ExpandDims{axis=0} [id K] │ │ │ │ │ │ └─ ExpandDims{axis=0} [id K]
│ │ │ │ │ │ └─ Elemwise{second,no_inplace} [id L] │ │ │ │ │ │ └─ Second [id L]
│ │ │ │ │ │ ├─ A [id M] │ │ │ │ │ │ ├─ A [id M]
│ │ │ │ │ │ └─ ExpandDims{axis=0} [id N] │ │ │ │ │ │ └─ ExpandDims{axis=0} [id N]
│ │ │ │ │ │ └─ TensorConstant{1.0} [id O] │ │ │ │ │ │ └─ TensorConstant{1.0} [id O]
...@@ -60,7 +60,7 @@ def test_debugprint_sitsot(): ...@@ -60,7 +60,7 @@ def test_debugprint_sitsot():
Inner graphs: Inner graphs:
for{cpu,scan_fn} [id C] for{cpu,scan_fn} [id C]
Elemwise{mul,no_inplace} [id W] (inner_out_sit_sot-0) Mul [id W] (inner_out_sit_sot-0)
├─ *0-<TensorType(float64, (?,))> [id X] -> [id E] (inner_in_sit_sot-0) ├─ *0-<TensorType(float64, (?,))> [id X] -> [id E] (inner_in_sit_sot-0)
└─ *1-<TensorType(float64, (?,))> [id Y] -> [id M] (inner_in_non_seqs-0)""" └─ *1-<TensorType(float64, (?,))> [id Y] -> [id M] (inner_in_non_seqs-0)"""
...@@ -90,13 +90,13 @@ def test_debugprint_sitsot_no_extra_info(): ...@@ -90,13 +90,13 @@ def test_debugprint_sitsot_no_extra_info():
│ │ ├─ k [id D] │ │ ├─ k [id D]
│ │ ├─ IncSubtensor{Set;:int64:} [id E] │ │ ├─ IncSubtensor{Set;:int64:} [id E]
│ │ │ ├─ AllocEmpty{dtype='float64'} [id F] │ │ │ ├─ AllocEmpty{dtype='float64'} [id F]
│ │ │ │ ├─ Elemwise{add,no_inplace} [id G] │ │ │ │ ├─ Add [id G]
│ │ │ │ │ ├─ k [id D] │ │ │ │ │ ├─ k [id D]
│ │ │ │ │ └─ Subtensor{int64} [id H] │ │ │ │ │ └─ Subtensor{int64} [id H]
│ │ │ │ │ ├─ Shape [id I] │ │ │ │ │ ├─ Shape [id I]
│ │ │ │ │ │ └─ Unbroadcast{0} [id J] │ │ │ │ │ │ └─ Unbroadcast{0} [id J]
│ │ │ │ │ │ └─ ExpandDims{axis=0} [id K] │ │ │ │ │ │ └─ ExpandDims{axis=0} [id K]
│ │ │ │ │ │ └─ Elemwise{second,no_inplace} [id L] │ │ │ │ │ │ └─ Second [id L]
│ │ │ │ │ │ ├─ A [id M] │ │ │ │ │ │ ├─ A [id M]
│ │ │ │ │ │ └─ ExpandDims{axis=0} [id N] │ │ │ │ │ │ └─ ExpandDims{axis=0} [id N]
│ │ │ │ │ │ └─ TensorConstant{1.0} [id O] │ │ │ │ │ │ └─ TensorConstant{1.0} [id O]
...@@ -118,7 +118,7 @@ def test_debugprint_sitsot_no_extra_info(): ...@@ -118,7 +118,7 @@ def test_debugprint_sitsot_no_extra_info():
Inner graphs: Inner graphs:
for{cpu,scan_fn} [id C] for{cpu,scan_fn} [id C]
Elemwise{mul,no_inplace} [id W] Mul [id W]
├─ *0-<TensorType(float64, (?,))> [id X] -> [id E] ├─ *0-<TensorType(float64, (?,))> [id X] -> [id E]
└─ *1-<TensorType(float64, (?,))> [id Y] -> [id M]""" └─ *1-<TensorType(float64, (?,))> [id Y] -> [id M]"""
...@@ -147,9 +147,9 @@ def test_debugprint_nitsot(): ...@@ -147,9 +147,9 @@ def test_debugprint_nitsot():
output_str = debugprint(polynomial, file="str", print_op_info=True) output_str = debugprint(polynomial, file="str", print_op_info=True)
lines = output_str.split("\n") lines = output_str.split("\n")
expected_output = """Sum{acc_dtype=float64} [id A] expected_output = """Sum{axes=None} [id A]
└─ for{cpu,scan_fn} [id B] (outer_out_nit_sot-0) └─ for{cpu,scan_fn} [id B] (outer_out_nit_sot-0)
├─ Elemwise{scalar_minimum,no_inplace} [id C] (outer_in_nit_sot-0) ├─ Minimum [id C] (outer_in_nit_sot-0)
│ ├─ Subtensor{int64} [id D] │ ├─ Subtensor{int64} [id D]
│ │ ├─ Shape [id E] │ │ ├─ Shape [id E]
│ │ │ └─ Subtensor{int64::} [id F] 'coefficients[0:]' │ │ │ └─ Subtensor{int64::} [id F] 'coefficients[0:]'
...@@ -169,24 +169,24 @@ def test_debugprint_nitsot(): ...@@ -169,24 +169,24 @@ def test_debugprint_nitsot():
│ ├─ Subtensor{int64::} [id F] 'coefficients[0:]' │ ├─ Subtensor{int64::} [id F] 'coefficients[0:]'
│ │ └─ ··· │ │ └─ ···
│ └─ ScalarFromTensor [id T] │ └─ ScalarFromTensor [id T]
│ └─ Elemwise{scalar_minimum,no_inplace} [id C] │ └─ Minimum [id C]
│ └─ ··· │ └─ ···
├─ Subtensor{:int64:} [id U] (outer_in_seqs-1) ├─ Subtensor{:int64:} [id U] (outer_in_seqs-1)
│ ├─ Subtensor{int64::} [id L] │ ├─ Subtensor{int64::} [id L]
│ │ └─ ··· │ │ └─ ···
│ └─ ScalarFromTensor [id V] │ └─ ScalarFromTensor [id V]
│ └─ Elemwise{scalar_minimum,no_inplace} [id C] │ └─ Minimum [id C]
│ └─ ··· │ └─ ···
├─ Elemwise{scalar_minimum,no_inplace} [id C] (outer_in_nit_sot-0) ├─ Minimum [id C] (outer_in_nit_sot-0)
│ └─ ··· │ └─ ···
└─ x [id W] (outer_in_non_seqs-0) └─ x [id W] (outer_in_non_seqs-0)
Inner graphs: Inner graphs:
for{cpu,scan_fn} [id B] for{cpu,scan_fn} [id B]
Elemwise{mul,no_inplace} [id X] (inner_out_nit_sot-0) Mul [id X] (inner_out_nit_sot-0)
├─ *0-<TensorType(float64, ())> [id Y] -> [id S] (inner_in_seqs-0) ├─ *0-<TensorType(float64, ())> [id Y] -> [id S] (inner_in_seqs-0)
└─ Elemwise{pow,no_inplace} [id Z] └─ Pow [id Z]
├─ *2-<TensorType(float64, ())> [id BA] -> [id W] (inner_in_non_seqs-0) ├─ *2-<TensorType(float64, ())> [id BA] -> [id W] (inner_in_non_seqs-0)
└─ *1-<TensorType(int64, ())> [id BB] -> [id U] (inner_in_seqs-1)""" └─ *1-<TensorType(int64, ())> [id BB] -> [id U] (inner_in_seqs-1)"""
...@@ -225,9 +225,9 @@ def test_debugprint_nested_scans(): ...@@ -225,9 +225,9 @@ def test_debugprint_nested_scans():
output_str = debugprint(final_result, file="str", print_op_info=True) output_str = debugprint(final_result, file="str", print_op_info=True)
lines = output_str.split("\n") lines = output_str.split("\n")
expected_output = """Sum{acc_dtype=float64} [id A] expected_output = """Sum{axes=None} [id A]
└─ for{cpu,scan_fn} [id B] (outer_out_nit_sot-0) └─ for{cpu,scan_fn} [id B] (outer_out_nit_sot-0)
├─ Elemwise{scalar_minimum,no_inplace} [id C] (outer_in_nit_sot-0) ├─ Minimum [id C] (outer_in_nit_sot-0)
│ ├─ Subtensor{int64} [id D] │ ├─ Subtensor{int64} [id D]
│ │ ├─ Shape [id E] │ │ ├─ Shape [id E]
│ │ │ └─ Subtensor{int64::} [id F] 'c[0:]' │ │ │ └─ Subtensor{int64::} [id F] 'c[0:]'
...@@ -247,15 +247,15 @@ def test_debugprint_nested_scans(): ...@@ -247,15 +247,15 @@ def test_debugprint_nested_scans():
│ ├─ Subtensor{int64::} [id F] 'c[0:]' │ ├─ Subtensor{int64::} [id F] 'c[0:]'
│ │ └─ ··· │ │ └─ ···
│ └─ ScalarFromTensor [id T] │ └─ ScalarFromTensor [id T]
│ └─ Elemwise{scalar_minimum,no_inplace} [id C] │ └─ Minimum [id C]
│ └─ ··· │ └─ ···
├─ Subtensor{:int64:} [id U] (outer_in_seqs-1) ├─ Subtensor{:int64:} [id U] (outer_in_seqs-1)
│ ├─ Subtensor{int64::} [id L] │ ├─ Subtensor{int64::} [id L]
│ │ └─ ··· │ │ └─ ···
│ └─ ScalarFromTensor [id V] │ └─ ScalarFromTensor [id V]
│ └─ Elemwise{scalar_minimum,no_inplace} [id C] │ └─ Minimum [id C]
│ └─ ··· │ └─ ···
├─ Elemwise{scalar_minimum,no_inplace} [id C] (outer_in_nit_sot-0) ├─ Minimum [id C] (outer_in_nit_sot-0)
│ └─ ··· │ └─ ···
├─ A [id W] (outer_in_non_seqs-0) ├─ A [id W] (outer_in_non_seqs-0)
└─ k [id X] (outer_in_non_seqs-1) └─ k [id X] (outer_in_non_seqs-1)
...@@ -263,23 +263,23 @@ def test_debugprint_nested_scans(): ...@@ -263,23 +263,23 @@ def test_debugprint_nested_scans():
Inner graphs: Inner graphs:
for{cpu,scan_fn} [id B] for{cpu,scan_fn} [id B]
Elemwise{mul,no_inplace} [id Y] (inner_out_nit_sot-0) Mul [id Y] (inner_out_nit_sot-0)
├─ ExpandDims{axis=0} [id Z] ├─ ExpandDims{axis=0} [id Z]
│ └─ *0-<TensorType(float64, ())> [id BA] -> [id S] (inner_in_seqs-0) │ └─ *0-<TensorType(float64, ())> [id BA] -> [id S] (inner_in_seqs-0)
└─ Elemwise{pow,no_inplace} [id BB] └─ Pow [id BB]
├─ Subtensor{int64} [id BC] ├─ Subtensor{int64} [id BC]
│ ├─ Subtensor{int64::} [id BD] │ ├─ Subtensor{int64::} [id BD]
│ │ ├─ for{cpu,scan_fn} [id BE] (outer_out_sit_sot-0) │ │ ├─ for{cpu,scan_fn} [id BE] (outer_out_sit_sot-0)
│ │ │ ├─ *3-<TensorType(int32, ())> [id BF] -> [id X] (inner_in_non_seqs-1) (n_steps) │ │ │ ├─ *3-<TensorType(int32, ())> [id BF] -> [id X] (inner_in_non_seqs-1) (n_steps)
│ │ │ ├─ IncSubtensor{Set;:int64:} [id BG] (outer_in_sit_sot-0) │ │ │ ├─ IncSubtensor{Set;:int64:} [id BG] (outer_in_sit_sot-0)
│ │ │ │ ├─ AllocEmpty{dtype='float64'} [id BH] │ │ │ │ ├─ AllocEmpty{dtype='float64'} [id BH]
│ │ │ │ │ ├─ Elemwise{add,no_inplace} [id BI] │ │ │ │ │ ├─ Add [id BI]
│ │ │ │ │ │ ├─ *3-<TensorType(int32, ())> [id BF] -> [id X] (inner_in_non_seqs-1) │ │ │ │ │ │ ├─ *3-<TensorType(int32, ())> [id BF] -> [id X] (inner_in_non_seqs-1)
│ │ │ │ │ │ └─ Subtensor{int64} [id BJ] │ │ │ │ │ │ └─ Subtensor{int64} [id BJ]
│ │ │ │ │ │ ├─ Shape [id BK] │ │ │ │ │ │ ├─ Shape [id BK]
│ │ │ │ │ │ │ └─ Unbroadcast{0} [id BL] │ │ │ │ │ │ │ └─ Unbroadcast{0} [id BL]
│ │ │ │ │ │ │ └─ ExpandDims{axis=0} [id BM] │ │ │ │ │ │ │ └─ ExpandDims{axis=0} [id BM]
│ │ │ │ │ │ │ └─ Elemwise{second,no_inplace} [id BN] │ │ │ │ │ │ │ └─ Second [id BN]
│ │ │ │ │ │ │ ├─ *2-<TensorType(float64, (?,))> [id BO] -> [id W] (inner_in_non_seqs-0) │ │ │ │ │ │ │ ├─ *2-<TensorType(float64, (?,))> [id BO] -> [id W] (inner_in_non_seqs-0)
│ │ │ │ │ │ │ └─ ExpandDims{axis=0} [id BP] │ │ │ │ │ │ │ └─ ExpandDims{axis=0} [id BP]
│ │ │ │ │ │ │ └─ TensorConstant{1.0} [id BQ] │ │ │ │ │ │ │ └─ TensorConstant{1.0} [id BQ]
...@@ -301,7 +301,7 @@ def test_debugprint_nested_scans(): ...@@ -301,7 +301,7 @@ def test_debugprint_nested_scans():
└─ *1-<TensorType(int64, ())> [id BZ] -> [id U] (inner_in_seqs-1) └─ *1-<TensorType(int64, ())> [id BZ] -> [id U] (inner_in_seqs-1)
for{cpu,scan_fn} [id BE] for{cpu,scan_fn} [id BE]
Elemwise{mul,no_inplace} [id CA] (inner_out_sit_sot-0) Mul [id CA] (inner_out_sit_sot-0)
├─ *0-<TensorType(float64, (?,))> [id CB] -> [id BG] (inner_in_sit_sot-0) ├─ *0-<TensorType(float64, (?,))> [id CB] -> [id BG] (inner_in_sit_sot-0)
└─ *1-<TensorType(float64, (?,))> [id CC] -> [id BO] (inner_in_non_seqs-0)""" └─ *1-<TensorType(float64, (?,))> [id CC] -> [id BO] (inner_in_non_seqs-0)"""
...@@ -318,9 +318,9 @@ def test_debugprint_nested_scans(): ...@@ -318,9 +318,9 @@ def test_debugprint_nested_scans():
expected_output = """→ c [id A] expected_output = """→ c [id A]
→ k [id B] → k [id B]
→ A [id C] → A [id C]
Sum{acc_dtype=float64} [id D] 13 Sum{axes=None} [id D] 13
└─ for{cpu,scan_fn} [id E] 12 (outer_out_nit_sot-0) └─ for{cpu,scan_fn} [id E] 12 (outer_out_nit_sot-0)
├─ Elemwise{scalar_minimum,no_inplace} [id F] 7 (outer_in_nit_sot-0) ├─ Minimum [id F] 7 (outer_in_nit_sot-0)
│ ├─ Subtensor{int64} [id G] 6 │ ├─ Subtensor{int64} [id G] 6
│ │ ├─ Shape [id H] 5 │ │ ├─ Shape [id H] 5
│ │ │ └─ Subtensor{int64::} [id I] 'c[0:]' 4 │ │ │ └─ Subtensor{int64::} [id I] 'c[0:]' 4
...@@ -340,15 +340,15 @@ def test_debugprint_nested_scans(): ...@@ -340,15 +340,15 @@ def test_debugprint_nested_scans():
│ ├─ Subtensor{int64::} [id I] 'c[0:]' 4 │ ├─ Subtensor{int64::} [id I] 'c[0:]' 4
│ │ └─ ··· │ │ └─ ···
│ └─ ScalarFromTensor [id V] 10 │ └─ ScalarFromTensor [id V] 10
│ └─ Elemwise{scalar_minimum,no_inplace} [id F] 7 │ └─ Minimum [id F] 7
│ └─ ··· │ └─ ···
├─ Subtensor{:int64:} [id W] 9 (outer_in_seqs-1) ├─ Subtensor{:int64:} [id W] 9 (outer_in_seqs-1)
│ ├─ Subtensor{int64::} [id N] 1 │ ├─ Subtensor{int64::} [id N] 1
│ │ └─ ··· │ │ └─ ···
│ └─ ScalarFromTensor [id X] 8 │ └─ ScalarFromTensor [id X] 8
│ └─ Elemwise{scalar_minimum,no_inplace} [id F] 7 │ └─ Minimum [id F] 7
│ └─ ··· │ └─ ···
├─ Elemwise{scalar_minimum,no_inplace} [id F] 7 (outer_in_nit_sot-0) ├─ Minimum [id F] 7 (outer_in_nit_sot-0)
│ └─ ··· │ └─ ···
├─ A [id C] (outer_in_non_seqs-0) ├─ A [id C] (outer_in_non_seqs-0)
└─ k [id B] (outer_in_non_seqs-1) └─ k [id B] (outer_in_non_seqs-1)
...@@ -360,23 +360,23 @@ def test_debugprint_nested_scans(): ...@@ -360,23 +360,23 @@ def test_debugprint_nested_scans():
→ *1-<TensorType(int64, ())> [id Z] -> [id W] (inner_in_seqs-1) → *1-<TensorType(int64, ())> [id Z] -> [id W] (inner_in_seqs-1)
→ *2-<TensorType(float64, (?,))> [id BA] -> [id C] (inner_in_non_seqs-0) → *2-<TensorType(float64, (?,))> [id BA] -> [id C] (inner_in_non_seqs-0)
→ *3-<TensorType(int32, ())> [id BB] -> [id B] (inner_in_non_seqs-1) → *3-<TensorType(int32, ())> [id BB] -> [id B] (inner_in_non_seqs-1)
Elemwise{mul,no_inplace} [id BC] (inner_out_nit_sot-0) Mul [id BC] (inner_out_nit_sot-0)
├─ ExpandDims{axis=0} [id BD] ├─ ExpandDims{axis=0} [id BD]
│ └─ *0-<TensorType(float64, ())> [id Y] (inner_in_seqs-0) │ └─ *0-<TensorType(float64, ())> [id Y] (inner_in_seqs-0)
└─ Elemwise{pow,no_inplace} [id BE] └─ Pow [id BE]
├─ Subtensor{int64} [id BF] ├─ Subtensor{int64} [id BF]
│ ├─ Subtensor{int64::} [id BG] │ ├─ Subtensor{int64::} [id BG]
│ │ ├─ for{cpu,scan_fn} [id BH] (outer_out_sit_sot-0) │ │ ├─ for{cpu,scan_fn} [id BH] (outer_out_sit_sot-0)
│ │ │ ├─ *3-<TensorType(int32, ())> [id BB] (inner_in_non_seqs-1) (n_steps) │ │ │ ├─ *3-<TensorType(int32, ())> [id BB] (inner_in_non_seqs-1) (n_steps)
│ │ │ ├─ IncSubtensor{Set;:int64:} [id BI] (outer_in_sit_sot-0) │ │ │ ├─ IncSubtensor{Set;:int64:} [id BI] (outer_in_sit_sot-0)
│ │ │ │ ├─ AllocEmpty{dtype='float64'} [id BJ] │ │ │ │ ├─ AllocEmpty{dtype='float64'} [id BJ]
│ │ │ │ │ ├─ Elemwise{add,no_inplace} [id BK] │ │ │ │ │ ├─ Add [id BK]
│ │ │ │ │ │ ├─ *3-<TensorType(int32, ())> [id BB] (inner_in_non_seqs-1) │ │ │ │ │ │ ├─ *3-<TensorType(int32, ())> [id BB] (inner_in_non_seqs-1)
│ │ │ │ │ │ └─ Subtensor{int64} [id BL] │ │ │ │ │ │ └─ Subtensor{int64} [id BL]
│ │ │ │ │ │ ├─ Shape [id BM] │ │ │ │ │ │ ├─ Shape [id BM]
│ │ │ │ │ │ │ └─ Unbroadcast{0} [id BN] │ │ │ │ │ │ │ └─ Unbroadcast{0} [id BN]
│ │ │ │ │ │ │ └─ ExpandDims{axis=0} [id BO] │ │ │ │ │ │ │ └─ ExpandDims{axis=0} [id BO]
│ │ │ │ │ │ │ └─ Elemwise{second,no_inplace} [id BP] │ │ │ │ │ │ │ └─ Second [id BP]
│ │ │ │ │ │ │ ├─ *2-<TensorType(float64, (?,))> [id BA] (inner_in_non_seqs-0) │ │ │ │ │ │ │ ├─ *2-<TensorType(float64, (?,))> [id BA] (inner_in_non_seqs-0)
│ │ │ │ │ │ │ └─ ExpandDims{axis=0} [id BQ] │ │ │ │ │ │ │ └─ ExpandDims{axis=0} [id BQ]
│ │ │ │ │ │ │ └─ TensorConstant{1.0} [id BR] │ │ │ │ │ │ │ └─ TensorConstant{1.0} [id BR]
...@@ -400,7 +400,7 @@ def test_debugprint_nested_scans(): ...@@ -400,7 +400,7 @@ def test_debugprint_nested_scans():
for{cpu,scan_fn} [id BH] for{cpu,scan_fn} [id BH]
→ *0-<TensorType(float64, (?,))> [id CA] -> [id BI] (inner_in_sit_sot-0) → *0-<TensorType(float64, (?,))> [id CA] -> [id BI] (inner_in_sit_sot-0)
→ *1-<TensorType(float64, (?,))> [id CB] -> [id BA] (inner_in_non_seqs-0) → *1-<TensorType(float64, (?,))> [id CB] -> [id BA] (inner_in_non_seqs-0)
Elemwise{mul,no_inplace} [id CC] (inner_out_sit_sot-0) Mul [id CC] (inner_out_sit_sot-0)
├─ *0-<TensorType(float64, (?,))> [id CA] (inner_in_sit_sot-0) ├─ *0-<TensorType(float64, (?,))> [id CA] (inner_in_sit_sot-0)
└─ *1-<TensorType(float64, (?,))> [id CB] (inner_in_non_seqs-0)""" └─ *1-<TensorType(float64, (?,))> [id CB] (inner_in_non_seqs-0)"""
...@@ -429,13 +429,13 @@ def test_debugprint_mitsot(): ...@@ -429,13 +429,13 @@ def test_debugprint_mitsot():
output_str = debugprint(final_result, file="str", print_op_info=True) output_str = debugprint(final_result, file="str", print_op_info=True)
lines = output_str.split("\n") lines = output_str.split("\n")
expected_output = """Elemwise{add,no_inplace} [id A] expected_output = """Add [id A]
├─ Subtensor{int64::} [id B] ├─ Subtensor{int64::} [id B]
│ ├─ for{cpu,scan_fn}.0 [id C] (outer_out_mit_sot-0) │ ├─ for{cpu,scan_fn}.0 [id C] (outer_out_mit_sot-0)
│ │ ├─ TensorConstant{5} [id D] (n_steps) │ │ ├─ TensorConstant{5} [id D] (n_steps)
│ │ ├─ IncSubtensor{Set;:int64:} [id E] (outer_in_mit_sot-0) │ │ ├─ IncSubtensor{Set;:int64:} [id E] (outer_in_mit_sot-0)
│ │ │ ├─ AllocEmpty{dtype='int64'} [id F] │ │ │ ├─ AllocEmpty{dtype='int64'} [id F]
│ │ │ │ └─ Elemwise{add,no_inplace} [id G] │ │ │ │ └─ Add [id G]
│ │ │ │ ├─ TensorConstant{5} [id D] │ │ │ │ ├─ TensorConstant{5} [id D]
│ │ │ │ └─ Subtensor{int64} [id H] │ │ │ │ └─ Subtensor{int64} [id H]
│ │ │ │ ├─ Shape [id I] │ │ │ │ ├─ Shape [id I]
...@@ -450,7 +450,7 @@ def test_debugprint_mitsot(): ...@@ -450,7 +450,7 @@ def test_debugprint_mitsot():
│ │ │ └─ ··· │ │ │ └─ ···
│ │ └─ IncSubtensor{Set;:int64:} [id O] (outer_in_mit_sot-1) │ │ └─ IncSubtensor{Set;:int64:} [id O] (outer_in_mit_sot-1)
│ │ ├─ AllocEmpty{dtype='int64'} [id P] │ │ ├─ AllocEmpty{dtype='int64'} [id P]
│ │ │ └─ Elemwise{add,no_inplace} [id Q] │ │ │ └─ Add [id Q]
│ │ │ ├─ TensorConstant{5} [id D] │ │ │ ├─ TensorConstant{5} [id D]
│ │ │ └─ Subtensor{int64} [id R] │ │ │ └─ Subtensor{int64} [id R]
│ │ │ ├─ Shape [id S] │ │ │ ├─ Shape [id S]
...@@ -472,10 +472,10 @@ def test_debugprint_mitsot(): ...@@ -472,10 +472,10 @@ def test_debugprint_mitsot():
Inner graphs: Inner graphs:
for{cpu,scan_fn} [id C] for{cpu,scan_fn} [id C]
Elemwise{add,no_inplace} [id BB] (inner_out_mit_sot-0) Add [id BB] (inner_out_mit_sot-0)
├─ *1-<TensorType(int64, ())> [id BC] -> [id E] (inner_in_mit_sot-0-1) ├─ *1-<TensorType(int64, ())> [id BC] -> [id E] (inner_in_mit_sot-0-1)
└─ *0-<TensorType(int64, ())> [id BD] -> [id E] (inner_in_mit_sot-0-0) └─ *0-<TensorType(int64, ())> [id BD] -> [id E] (inner_in_mit_sot-0-0)
Elemwise{add,no_inplace} [id BE] (inner_out_mit_sot-1) Add [id BE] (inner_out_mit_sot-1)
├─ *3-<TensorType(int64, ())> [id BF] -> [id O] (inner_in_mit_sot-1-1) ├─ *3-<TensorType(int64, ())> [id BF] -> [id O] (inner_in_mit_sot-1-1)
└─ *2-<TensorType(int64, ())> [id BG] -> [id O] (inner_in_mit_sot-1-0)""" └─ *2-<TensorType(int64, ())> [id BG] -> [id O] (inner_in_mit_sot-1-0)"""
...@@ -503,20 +503,20 @@ def test_debugprint_mitmot(): ...@@ -503,20 +503,20 @@ def test_debugprint_mitmot():
expected_output = """Subtensor{int64} [id A] expected_output = """Subtensor{int64} [id A]
├─ for{cpu,grad_of_scan_fn}.1 [id B] (outer_out_sit_sot-0) ├─ for{cpu,grad_of_scan_fn}.1 [id B] (outer_out_sit_sot-0)
│ ├─ Elemwise{sub,no_inplace} [id C] (n_steps) │ ├─ Sub [id C] (n_steps)
│ │ ├─ Subtensor{int64} [id D] │ │ ├─ Subtensor{int64} [id D]
│ │ │ ├─ Shape [id E] │ │ │ ├─ Shape [id E]
│ │ │ │ └─ for{cpu,scan_fn} [id F] (outer_out_sit_sot-0) │ │ │ │ └─ for{cpu,scan_fn} [id F] (outer_out_sit_sot-0)
│ │ │ │ ├─ k [id G] (n_steps) │ │ │ │ ├─ k [id G] (n_steps)
│ │ │ │ ├─ IncSubtensor{Set;:int64:} [id H] (outer_in_sit_sot-0) │ │ │ │ ├─ IncSubtensor{Set;:int64:} [id H] (outer_in_sit_sot-0)
│ │ │ │ │ ├─ AllocEmpty{dtype='float64'} [id I] │ │ │ │ │ ├─ AllocEmpty{dtype='float64'} [id I]
│ │ │ │ │ │ ├─ Elemwise{add,no_inplace} [id J] │ │ │ │ │ │ ├─ Add [id J]
│ │ │ │ │ │ │ ├─ k [id G] │ │ │ │ │ │ │ ├─ k [id G]
│ │ │ │ │ │ │ └─ Subtensor{int64} [id K] │ │ │ │ │ │ │ └─ Subtensor{int64} [id K]
│ │ │ │ │ │ │ ├─ Shape [id L] │ │ │ │ │ │ │ ├─ Shape [id L]
│ │ │ │ │ │ │ │ └─ Unbroadcast{0} [id M] │ │ │ │ │ │ │ │ └─ Unbroadcast{0} [id M]
│ │ │ │ │ │ │ │ └─ ExpandDims{axis=0} [id N] │ │ │ │ │ │ │ │ └─ ExpandDims{axis=0} [id N]
│ │ │ │ │ │ │ │ └─ Elemwise{second,no_inplace} [id O] │ │ │ │ │ │ │ │ └─ Second [id O]
│ │ │ │ │ │ │ │ ├─ A [id P] │ │ │ │ │ │ │ │ ├─ A [id P]
│ │ │ │ │ │ │ │ └─ ExpandDims{axis=0} [id Q] │ │ │ │ │ │ │ │ └─ ExpandDims{axis=0} [id Q]
│ │ │ │ │ │ │ │ └─ TensorConstant{1.0} [id R] │ │ │ │ │ │ │ │ └─ TensorConstant{1.0} [id R]
...@@ -542,7 +542,7 @@ def test_debugprint_mitmot(): ...@@ -542,7 +542,7 @@ def test_debugprint_mitmot():
│ │ │ │ └─ ScalarConstant{-1} [id BC] │ │ │ │ └─ ScalarConstant{-1} [id BC]
│ │ │ └─ ScalarConstant{-1} [id BD] │ │ │ └─ ScalarConstant{-1} [id BD]
│ │ └─ ScalarFromTensor [id BE] │ │ └─ ScalarFromTensor [id BE]
│ │ └─ Elemwise{sub,no_inplace} [id C] │ │ └─ Sub [id C]
│ │ └─ ··· │ │ └─ ···
│ ├─ Subtensor{:int64:} [id BF] (outer_in_seqs-1) │ ├─ Subtensor{:int64:} [id BF] (outer_in_seqs-1)
│ │ ├─ Subtensor{:int64:} [id BG] │ │ ├─ Subtensor{:int64:} [id BG]
...@@ -552,31 +552,31 @@ def test_debugprint_mitmot(): ...@@ -552,31 +552,31 @@ def test_debugprint_mitmot():
│ │ │ │ └─ ScalarConstant{-1} [id BI] │ │ │ │ └─ ScalarConstant{-1} [id BI]
│ │ │ └─ ScalarConstant{-1} [id BJ] │ │ │ └─ ScalarConstant{-1} [id BJ]
│ │ └─ ScalarFromTensor [id BK] │ │ └─ ScalarFromTensor [id BK]
│ │ └─ Elemwise{sub,no_inplace} [id C] │ │ └─ Sub [id C]
│ │ └─ ··· │ │ └─ ···
│ ├─ Subtensor{::int64} [id BL] (outer_in_mit_mot-0) │ ├─ Subtensor{::int64} [id BL] (outer_in_mit_mot-0)
│ │ ├─ IncSubtensor{Inc;int64::} [id BM] │ │ ├─ IncSubtensor{Inc;int64::} [id BM]
│ │ │ ├─ Elemwise{second,no_inplace} [id BN] │ │ │ ├─ Second [id BN]
│ │ │ │ ├─ for{cpu,scan_fn} [id F] (outer_out_sit_sot-0) │ │ │ │ ├─ for{cpu,scan_fn} [id F] (outer_out_sit_sot-0)
│ │ │ │ │ └─ ··· │ │ │ │ │ └─ ···
│ │ │ │ └─ ExpandDims{axes=[0, 1]} [id BO] │ │ │ │ └─ ExpandDims{axes=[0, 1]} [id BO]
│ │ │ │ └─ TensorConstant{0.0} [id BP] │ │ │ │ └─ TensorConstant{0.0} [id BP]
│ │ │ ├─ IncSubtensor{Inc;int64} [id BQ] │ │ │ ├─ IncSubtensor{Inc;int64} [id BQ]
│ │ │ │ ├─ Elemwise{second,no_inplace} [id BR] │ │ │ │ ├─ Second [id BR]
│ │ │ │ │ ├─ Subtensor{int64::} [id BS] │ │ │ │ │ ├─ Subtensor{int64::} [id BS]
│ │ │ │ │ │ ├─ for{cpu,scan_fn} [id F] (outer_out_sit_sot-0) │ │ │ │ │ │ ├─ for{cpu,scan_fn} [id F] (outer_out_sit_sot-0)
│ │ │ │ │ │ │ └─ ··· │ │ │ │ │ │ │ └─ ···
│ │ │ │ │ │ └─ ScalarConstant{1} [id BT] │ │ │ │ │ │ └─ ScalarConstant{1} [id BT]
│ │ │ │ │ └─ ExpandDims{axes=[0, 1]} [id BU] │ │ │ │ │ └─ ExpandDims{axes=[0, 1]} [id BU]
│ │ │ │ │ └─ TensorConstant{0.0} [id BV] │ │ │ │ │ └─ TensorConstant{0.0} [id BV]
│ │ │ │ ├─ Elemwise{second} [id BW] │ │ │ │ ├─ Second [id BW]
│ │ │ │ │ ├─ Subtensor{int64} [id BX] │ │ │ │ │ ├─ Subtensor{int64} [id BX]
│ │ │ │ │ │ ├─ Subtensor{int64::} [id BS] │ │ │ │ │ │ ├─ Subtensor{int64::} [id BS]
│ │ │ │ │ │ │ └─ ··· │ │ │ │ │ │ │ └─ ···
│ │ │ │ │ │ └─ ScalarConstant{-1} [id BY] │ │ │ │ │ │ └─ ScalarConstant{-1} [id BY]
│ │ │ │ │ └─ ExpandDims{axis=0} [id BZ] │ │ │ │ │ └─ ExpandDims{axis=0} [id BZ]
│ │ │ │ │ └─ Elemwise{second,no_inplace} [id CA] │ │ │ │ │ └─ Second [id CA]
│ │ │ │ │ ├─ Sum{acc_dtype=float64} [id CB] │ │ │ │ │ ├─ Sum{axes=None} [id CB]
│ │ │ │ │ │ └─ Subtensor{int64} [id BX] │ │ │ │ │ │ └─ Subtensor{int64} [id BX]
│ │ │ │ │ │ └─ ··· │ │ │ │ │ │ └─ ···
│ │ │ │ │ └─ TensorConstant{1.0} [id CC] │ │ │ │ │ └─ TensorConstant{1.0} [id CC]
...@@ -585,8 +585,8 @@ def test_debugprint_mitmot(): ...@@ -585,8 +585,8 @@ def test_debugprint_mitmot():
│ │ └─ ScalarConstant{-1} [id CD] │ │ └─ ScalarConstant{-1} [id CD]
│ ├─ Alloc [id CE] (outer_in_sit_sot-0) │ ├─ Alloc [id CE] (outer_in_sit_sot-0)
│ │ ├─ TensorConstant{0.0} [id CF] │ │ ├─ TensorConstant{0.0} [id CF]
│ │ ├─ Elemwise{add,no_inplace} [id CG] │ │ ├─ Add [id CG]
│ │ │ ├─ Elemwise{sub,no_inplace} [id C] │ │ │ ├─ Sub [id C]
│ │ │ │ └─ ··· │ │ │ │ └─ ···
│ │ │ └─ TensorConstant{1} [id CH] │ │ │ └─ TensorConstant{1} [id CH]
│ │ └─ Subtensor{int64} [id CI] │ │ └─ Subtensor{int64} [id CI]
...@@ -599,19 +599,19 @@ def test_debugprint_mitmot(): ...@@ -599,19 +599,19 @@ def test_debugprint_mitmot():
Inner graphs: Inner graphs:
for{cpu,grad_of_scan_fn} [id B] for{cpu,grad_of_scan_fn} [id B]
Elemwise{add,no_inplace} [id CM] (inner_out_mit_mot-0-0) Add [id CM] (inner_out_mit_mot-0-0)
├─ Elemwise{mul} [id CN] ├─ Mul [id CN]
│ ├─ *2-<TensorType(float64, (?,))> [id CO] -> [id BL] (inner_in_mit_mot-0-0) │ ├─ *2-<TensorType(float64, (?,))> [id CO] -> [id BL] (inner_in_mit_mot-0-0)
│ └─ *5-<TensorType(float64, (?,))> [id CP] -> [id P] (inner_in_non_seqs-0) │ └─ *5-<TensorType(float64, (?,))> [id CP] -> [id P] (inner_in_non_seqs-0)
└─ *3-<TensorType(float64, (?,))> [id CQ] -> [id BL] (inner_in_mit_mot-0-1) └─ *3-<TensorType(float64, (?,))> [id CQ] -> [id BL] (inner_in_mit_mot-0-1)
Elemwise{add,no_inplace} [id CR] (inner_out_sit_sot-0) Add [id CR] (inner_out_sit_sot-0)
├─ Elemwise{mul} [id CS] ├─ Mul [id CS]
│ ├─ *2-<TensorType(float64, (?,))> [id CO] -> [id BL] (inner_in_mit_mot-0-0) │ ├─ *2-<TensorType(float64, (?,))> [id CO] -> [id BL] (inner_in_mit_mot-0-0)
│ └─ *0-<TensorType(float64, (?,))> [id CT] -> [id Z] (inner_in_seqs-0) │ └─ *0-<TensorType(float64, (?,))> [id CT] -> [id Z] (inner_in_seqs-0)
└─ *4-<TensorType(float64, (?,))> [id CU] -> [id CE] (inner_in_sit_sot-0) └─ *4-<TensorType(float64, (?,))> [id CU] -> [id CE] (inner_in_sit_sot-0)
for{cpu,scan_fn} [id F] for{cpu,scan_fn} [id F]
Elemwise{mul,no_inplace} [id CV] (inner_out_sit_sot-0) Mul [id CV] (inner_out_sit_sot-0)
├─ *0-<TensorType(float64, (?,))> [id CT] -> [id H] (inner_in_sit_sot-0) ├─ *0-<TensorType(float64, (?,))> [id CT] -> [id H] (inner_in_sit_sot-0)
└─ *1-<TensorType(float64, (?,))> [id CW] -> [id P] (inner_in_non_seqs-0)""" └─ *1-<TensorType(float64, (?,))> [id CW] -> [id P] (inner_in_non_seqs-0)"""
...@@ -654,7 +654,7 @@ def test_debugprint_compiled_fn(): ...@@ -654,7 +654,7 @@ def test_debugprint_compiled_fn():
Inner graphs: Inner graphs:
forall_inplace,cpu,scan_fn} [id A] forall_inplace,cpu,scan_fn} [id A]
Elemwise{Composite{Switch(LT(i0, i1), i2, i0)}} [id I] (inner_out_sit_sot-0) Composite{switch(lt(i0, i1), i2, i0)} [id I] (inner_out_sit_sot-0)
├─ TensorConstant{0} [id J] ├─ TensorConstant{0} [id J]
├─ Subtensor{int64, int64, uint8} [id K] ├─ Subtensor{int64, int64, uint8} [id K]
│ ├─ *2-<TensorType(float64, (20000, 2, 2))> [id L] -> [id H] (inner_in_non_seqs-0) │ ├─ *2-<TensorType(float64, (20000, 2, 2))> [id L] -> [id H] (inner_in_non_seqs-0)
...@@ -665,7 +665,7 @@ def test_debugprint_compiled_fn(): ...@@ -665,7 +665,7 @@ def test_debugprint_compiled_fn():
│ └─ ScalarConstant{0} [id Q] │ └─ ScalarConstant{0} [id Q]
└─ TensorConstant{1} [id R] └─ TensorConstant{1} [id R]
Elemwise{Composite{Switch(LT(i0, i1), i2, i0)}} [id I] Composite{switch(lt(i0, i1), i2, i0)} [id I]
← Switch [id S] 'o0' ← Switch [id S] 'o0'
├─ LT [id T] ├─ LT [id T]
│ ├─ i0 [id U] │ ├─ i0 [id U]
......
...@@ -676,14 +676,9 @@ class TestCAReduce(unittest_tools.InferShapeTester): ...@@ -676,14 +676,9 @@ class TestCAReduce(unittest_tools.InferShapeTester):
def test_str(self): def test_str(self):
op = CAReduce(aes.add, axis=None) op = CAReduce(aes.add, axis=None)
assert str(op) == "CAReduce{add}" assert str(op) == "CAReduce{add, axes=None}"
op = CAReduce(aes.add, axis=(1,)) op = CAReduce(aes.add, axis=(1,))
assert str(op) == "CAReduce{add}{axis=[1]}" assert str(op) == "CAReduce{add, axis=1}"
op = CAReduce(aes.add, axis=None, acc_dtype="float64")
assert str(op) == "CAReduce{add}{acc_dtype=float64}"
op = CAReduce(aes.add, axis=(1,), acc_dtype="float64")
assert str(op) == "CAReduce{add}{axis=[1], acc_dtype=float64}"
def test_repeated_axis(self): def test_repeated_axis(self):
x = vector("x") x = vector("x")
...@@ -802,10 +797,8 @@ class TestElemwise(unittest_tools.InferShapeTester): ...@@ -802,10 +797,8 @@ class TestElemwise(unittest_tools.InferShapeTester):
self.check_input_dimensions_match(Mode(linker="c")) self.check_input_dimensions_match(Mode(linker="c"))
def test_str(self): def test_str(self):
op = Elemwise(aes.add, inplace_pattern=None, name=None)
assert str(op) == "Elemwise{add}"
op = Elemwise(aes.add, inplace_pattern={0: 0}, name=None) op = Elemwise(aes.add, inplace_pattern={0: 0}, name=None)
assert str(op) == "Elemwise{add}[(0, 0)]" assert str(op) == "Add"
op = Elemwise(aes.add, inplace_pattern=None, name="my_op") op = Elemwise(aes.add, inplace_pattern=None, name="my_op")
assert str(op) == "my_op" assert str(op) == "my_op"
......
...@@ -106,9 +106,9 @@ def test_min_informative_str(): ...@@ -106,9 +106,9 @@ def test_min_informative_str():
mis = min_informative_str(G).replace("\t", " ") mis = min_informative_str(G).replace("\t", " ")
reference = """A. Elemwise{add,no_inplace} reference = """A. Add
B. C B. C
C. Elemwise{add,no_inplace} C. Add
D. D D. D
E. E""" E. E"""
...@@ -144,11 +144,11 @@ def test_debugprint(): ...@@ -144,11 +144,11 @@ def test_debugprint():
s = s.getvalue() s = s.getvalue()
reference = dedent( reference = dedent(
r""" r"""
Elemwise{add,no_inplace} [id 0] Add [id 0]
├─ Elemwise{add,no_inplace} [id 1] 'C' ├─ Add [id 1] 'C'
│ ├─ A [id 2] │ ├─ A [id 2]
│ └─ B [id 3] │ └─ B [id 3]
└─ Elemwise{add,no_inplace} [id 4] └─ Add [id 4]
├─ D [id 5] ├─ D [id 5]
└─ E [id 6] └─ E [id 6]
""" """
...@@ -162,11 +162,11 @@ def test_debugprint(): ...@@ -162,11 +162,11 @@ def test_debugprint():
# The additional white space are needed! # The additional white space are needed!
reference = dedent( reference = dedent(
r""" r"""
Elemwise{add,no_inplace} [id A] Add [id A]
├─ Elemwise{add,no_inplace} [id B] 'C' ├─ Add [id B] 'C'
│ ├─ A [id C] │ ├─ A [id C]
│ └─ B [id D] │ └─ B [id D]
└─ Elemwise{add,no_inplace} [id E] └─ Add [id E]
├─ D [id F] ├─ D [id F]
└─ E [id G] └─ E [id G]
""" """
...@@ -180,10 +180,10 @@ def test_debugprint(): ...@@ -180,10 +180,10 @@ def test_debugprint():
# The additional white space are needed! # The additional white space are needed!
reference = dedent( reference = dedent(
r""" r"""
Elemwise{add,no_inplace} [id A] Add [id A]
├─ Elemwise{add,no_inplace} [id B] 'C' ├─ Add [id B] 'C'
│ └─ ··· │ └─ ···
└─ Elemwise{add,no_inplace} [id C] └─ Add [id C]
├─ D [id D] ├─ D [id D]
└─ E [id E] └─ E [id E]
""" """
...@@ -196,11 +196,11 @@ def test_debugprint(): ...@@ -196,11 +196,11 @@ def test_debugprint():
s = s.getvalue() s = s.getvalue()
reference = dedent( reference = dedent(
r""" r"""
Elemwise{add,no_inplace} Add
├─ Elemwise{add,no_inplace} 'C' ├─ Add 'C'
│ ├─ A │ ├─ A
│ └─ B │ └─ B
└─ Elemwise{add,no_inplace} └─ Add
├─ D ├─ D
└─ E └─ E
""" """
...@@ -213,7 +213,7 @@ def test_debugprint(): ...@@ -213,7 +213,7 @@ def test_debugprint():
s = s.getvalue() s = s.getvalue()
reference = dedent( reference = dedent(
r""" r"""
Elemwise{add,no_inplace} 0 [None] Add 0 [None]
├─ A [None] ├─ A [None]
├─ B [None] ├─ B [None]
├─ D [None] ├─ D [None]
...@@ -231,7 +231,7 @@ def test_debugprint(): ...@@ -231,7 +231,7 @@ def test_debugprint():
s = s.getvalue() s = s.getvalue()
reference = dedent( reference = dedent(
r""" r"""
Elemwise{add,no_inplace} 0 [None] Add 0 [None]
├─ A [None] ├─ A [None]
├─ B [None] ├─ B [None]
├─ D [None] ├─ D [None]
...@@ -249,7 +249,7 @@ def test_debugprint(): ...@@ -249,7 +249,7 @@ def test_debugprint():
s = s.getvalue() s = s.getvalue()
reference = dedent( reference = dedent(
r""" r"""
Elemwise{add,no_inplace} 0 [None] Add 0 [None]
├─ A [None] ├─ A [None]
├─ B [None] ├─ B [None]
├─ D [None] ├─ D [None]
...@@ -274,7 +274,7 @@ def test_debugprint(): ...@@ -274,7 +274,7 @@ def test_debugprint():
s = s.getvalue() s = s.getvalue()
exp_res = dedent( exp_res = dedent(
r""" r"""
Elemwise{Composite{(i2 + (i0 - i1))}} 4 Composite{(i2 + (i0 - i1))} 4
├─ ExpandDims{axis=0} v={0: [0]} 3 ├─ ExpandDims{axis=0} v={0: [0]} 3
│ └─ CGemv{inplace} d={0: [0]} 2 │ └─ CGemv{inplace} d={0: [0]} 2
│ ├─ AllocEmpty{dtype='float64'} 1 │ ├─ AllocEmpty{dtype='float64'} 1
...@@ -289,7 +289,7 @@ def test_debugprint(): ...@@ -289,7 +289,7 @@ def test_debugprint():
Inner graphs: Inner graphs:
Elemwise{Composite{(i2 + (i0 - i1))}} Composite{(i2 + (i0 - i1))}
← add 'o0' ← add 'o0'
├─ i2 ├─ i2
└─ sub └─ sub
...@@ -314,7 +314,7 @@ def test_debugprint_id_type(): ...@@ -314,7 +314,7 @@ def test_debugprint_id_type():
debugprint(e_at, id_type="auto", file=s) debugprint(e_at, id_type="auto", file=s)
s = s.getvalue() s = s.getvalue()
exp_res = f"""Elemwise{{add,no_inplace}} [id {e_at.auto_name}] exp_res = f"""Add [id {e_at.auto_name}]
├─ dot [id {d_at.auto_name}] ├─ dot [id {d_at.auto_name}]
│ ├─ <TensorType(float64, (?, ?))> [id {b_at.auto_name}] │ ├─ <TensorType(float64, (?, ?))> [id {b_at.auto_name}]
│ └─ <TensorType(float64, (?,))> [id {a_at.auto_name}] │ └─ <TensorType(float64, (?,))> [id {a_at.auto_name}]
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论