提交 b4c1bbd5 authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Add Scan input/output type annotations to debugprint output

上级 c7b416e1
...@@ -11,7 +11,7 @@ import sys ...@@ -11,7 +11,7 @@ import sys
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from contextlib import contextmanager from contextlib import contextmanager
from copy import copy from copy import copy
from functools import reduce from functools import reduce, singledispatch
from io import IOBase, StringIO from io import IOBase, StringIO
from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union from typing import Any, Callable, Dict, Iterable, List, Optional, Tuple, Union
...@@ -83,6 +83,26 @@ def char_from_number(number): ...@@ -83,6 +83,26 @@ def char_from_number(number):
return rval return rval
@singledispatch
def op_debug_information(op: Op, node: Apply) -> Dict[Apply, Dict[Variable, str]]:
"""Provide extra debug print information based on the type of `Op` and `Apply` node.
Implementations of this dispatch function should return a ``dict`` keyed by
the `Apply` node, `node`, associated with the given `op`. The value
associated with the `node` is another ``dict`` mapping `Variable` inputs
and/or outputs of `node` to their debug information.
The `node` key allows the information in the ``dict``'s values to be
specific to the given `node`, so that--for instance--the provided debug
information is only ever printed/associated with a given `Variable`
input/output when that `Variable` is displayed as an input/output of `node`
and not in every/any other place where said `Variable` is present in a
graph.
"""
return {}
def debugprint( def debugprint(
obj: Union[ obj: Union[
Union[Variable, Apply, Function], List[Union[Variable, Apply, Function]] Union[Variable, Apply, Function], List[Union[Variable, Apply, Function]]
...@@ -95,8 +115,9 @@ def debugprint( ...@@ -95,8 +115,9 @@ def debugprint(
done: Optional[Dict[Apply, str]] = None, done: Optional[Dict[Apply, str]] = None,
print_storage: bool = False, print_storage: bool = False,
used_ids: Optional[Dict[Variable, str]] = None, used_ids: Optional[Dict[Variable, str]] = None,
print_op_info: bool = False,
) -> Union[str, IOBase]: ) -> Union[str, IOBase]:
"""Print a computation graph as text to stdout or a file. r"""Print a computation graph as text to stdout or a file.
Each line printed represents a Variable in the graph. Each line printed represents a Variable in the graph.
The indentation of lines corresponds to its depth in the symbolic graph. The indentation of lines corresponds to its depth in the symbolic graph.
...@@ -145,6 +166,9 @@ def debugprint( ...@@ -145,6 +166,9 @@ def debugprint(
function, the output will show the intermediate results. function, the output will show the intermediate results.
used_ids used_ids
A map between nodes and their printed ids. A map between nodes and their printed ids.
print_op_info
Print extra information provided by the relevant `Op`\s. For example,
print the tap information for `Scan` inputs and outputs.
Returns Returns
------- -------
...@@ -236,9 +260,15 @@ N.B.: ...@@ -236,9 +260,15 @@ N.B.:
file=_file, file=_file,
) )
op_information = {}
for r, p, s, o in zip(results_to_print, profile_list, smap, order): for r, p, s, o in zip(results_to_print, profile_list, smap, order):
if hasattr(r.owner, "op") and isinstance(r.owner.op, HasInnerGraph):
inner_graph_ops.append(r) if hasattr(r.owner, "op"):
if isinstance(r.owner.op, HasInnerGraph):
inner_graph_ops.append(r)
if print_op_info:
op_information.update(op_debug_information(r.owner.op, r.owner))
_debugprint( _debugprint(
r, r,
...@@ -253,6 +283,9 @@ N.B.: ...@@ -253,6 +283,9 @@ N.B.:
profile=p, profile=p,
smap=s, smap=s,
used_ids=used_ids, used_ids=used_ids,
op_information=op_information,
parent_node=r.owner,
print_op_info=print_op_info,
) )
if len(inner_graph_ops) > 0: if len(inner_graph_ops) > 0:
...@@ -288,6 +321,9 @@ N.B.: ...@@ -288,6 +321,9 @@ N.B.:
else: else:
inner_to_outer_inputs = None inner_to_outer_inputs = None
if print_op_info:
op_information.update(op_debug_information(s.owner.op, s.owner))
print("", file=_file) print("", file=_file)
_debugprint( _debugprint(
...@@ -301,6 +337,9 @@ N.B.: ...@@ -301,6 +337,9 @@ N.B.:
stop_on_name=stop_on_name, stop_on_name=stop_on_name,
inner_to_outer_inputs=inner_to_outer_inputs, inner_to_outer_inputs=inner_to_outer_inputs,
used_ids=used_ids, used_ids=used_ids,
op_information=op_information,
parent_node=s.owner,
print_op_info=print_op_info,
) )
for idx, i in enumerate(inner_outputs): for idx, i in enumerate(inner_outputs):
...@@ -321,6 +360,9 @@ N.B.: ...@@ -321,6 +360,9 @@ N.B.:
inner_graph_ops=inner_graph_ops, inner_graph_ops=inner_graph_ops,
inner_to_outer_inputs=inner_to_outer_inputs, inner_to_outer_inputs=inner_to_outer_inputs,
used_ids=used_ids, used_ids=used_ids,
op_information=op_information,
parent_node=s.owner,
print_op_info=print_op_info,
) )
if file is _file: if file is _file:
...@@ -350,6 +392,9 @@ def _debugprint( ...@@ -350,6 +392,9 @@ def _debugprint(
inner_to_outer_inputs: Optional[Dict[Variable, Variable]] = None, inner_to_outer_inputs: Optional[Dict[Variable, Variable]] = None,
smap: Optional[StorageMapType] = None, smap: Optional[StorageMapType] = None,
used_ids: Optional[Dict[Variable, str]] = None, used_ids: Optional[Dict[Variable, str]] = None,
op_information: Optional[Dict[Apply, Dict[Variable, str]]] = None,
parent_node: Optional[Apply] = None,
print_op_info: bool = False,
) -> IOBase: ) -> IOBase:
r"""Print the graph leading to `r`. r"""Print the graph leading to `r`.
...@@ -395,6 +440,13 @@ def _debugprint( ...@@ -395,6 +440,13 @@ def _debugprint(
A map between nodes and their printed ids. A map between nodes and their printed ids.
It wasn't always printed, but at least a reference to it was printed. It wasn't always printed, but at least a reference to it was printed.
Internal. Used to pass information when recursing. Internal. Used to pass information when recursing.
op_information
Extra `Op`-level information to be added to variable print-outs.
parent_node
The parent node of `r`.
print_op_info
Print extra information provided by the relevant `Op`\s. For example,
print the tap information for `Scan` inputs and outputs.
""" """
if depth == 0: if depth == 0:
return file return file
...@@ -419,6 +471,9 @@ def _debugprint( ...@@ -419,6 +471,9 @@ def _debugprint(
if used_ids is None: if used_ids is None:
used_ids = dict() used_ids = dict()
if op_information is None:
op_information = {}
def get_id_str(obj, get_printed=True) -> str: def get_id_str(obj, get_printed=True) -> str:
id_str: str = "" id_str: str = ""
if obj in used_ids: if obj in used_ids:
...@@ -441,15 +496,16 @@ def _debugprint( ...@@ -441,15 +496,16 @@ def _debugprint(
return id_str return id_str
if hasattr(r.owner, "op"): if hasattr(r.owner, "op"):
# this variable is the output of computation, # This variable is the output of a computation, so just print out the
# so just print out the apply # `Apply` node
a = r.owner a = r.owner
r_name = getattr(r, "name", "") r_name = getattr(r, "name", "")
# normally if the name isn't set, it'll be None, so
# r_name is None here
if r_name is None: if r_name is None:
r_name = "" r_name = ""
if r_name:
r_name = f" '{r_name}'"
if print_destroy_map: if print_destroy_map:
destroy_map_str = str(r.owner.op.destroy_map) destroy_map_str = str(r.owner.op.destroy_map)
...@@ -460,31 +516,45 @@ def _debugprint( ...@@ -460,31 +516,45 @@ def _debugprint(
view_map_str = str(r.owner.op.view_map) view_map_str = str(r.owner.op.view_map)
else: else:
view_map_str = "" view_map_str = ""
if destroy_map_str and destroy_map_str != "{}": if destroy_map_str and destroy_map_str != "{}":
destroy_map_str = "d=" + destroy_map_str destroy_map_str = f" d={destroy_map_str} "
if view_map_str and view_map_str != "{}": if view_map_str and view_map_str != "{}":
view_map_str = "v=" + view_map_str view_map_str = f" v={view_map_str} "
o = ""
if order: if order:
o = str(order.index(r.owner)) o = f" {order.index(r.owner)}"
else:
o = ""
already_printed = a in done # get_id_str put it in the dict already_done = a in done
id_str = get_id_str(a) id_str = get_id_str(a)
if len(a.outputs) == 1: if len(a.outputs) == 1:
idx = "" idx = ""
else: else:
idx = f".{a.outputs.index(r)}" idx = f".{a.outputs.index(r)}"
data = ""
if smap: if id_str:
data = " " + str(smap.get(a.outputs[0], "")) id_str = f" {id_str}"
clients = ""
if smap and a.outputs[0] in smap:
data = f" {smap[a.outputs[0]]}"
else:
data = ""
var_output = f"{prefix}{a.op}{idx}{id_str}{type_str}{r_name}{destroy_map_str}{view_map_str}{o}{data}"
if print_op_info and r.owner not in op_information:
op_information.update(op_debug_information(r.owner.op, r.owner))
node_info = op_information.get(parent_node) or op_information.get(r.owner)
if node_info and r in node_info:
var_output = f"{var_output} ({node_info[r]})"
if profile is None or a not in profile.apply_time: if profile is None or a not in profile.apply_time:
print( print(var_output, file=file)
f"{prefix}{a.op}{idx} {id_str}{type_str} '{r_name}' {destroy_map_str} {view_map_str} {o}{data}{clients}",
file=file,
)
else: else:
op_time = profile.apply_time[a] op_time = profile.apply_time[a]
op_time_percent = (op_time / profile.fct_call_time) * 100 op_time_percent = (op_time / profile.fct_call_time) * 100
...@@ -492,25 +562,10 @@ def _debugprint( ...@@ -492,25 +562,10 @@ def _debugprint(
tot_time = tot_time_dict[a] tot_time = tot_time_dict[a]
tot_time_percent = (tot_time_dict[a] / profile.fct_call_time) * 100 tot_time_percent = (tot_time_dict[a] / profile.fct_call_time) * 100
if len(a.outputs) == 1:
idx = ""
else:
idx = f".{a.outputs.index(r)}"
print( print(
"%s%s%s %s%s '%s' %s %s %s%s%s --> " "%s --> %8.2es %4.1f%% %8.2es %4.1f%%"
"%8.2es %4.1f%% %8.2es %4.1f%%"
% ( % (
prefix, var_output,
a.op,
idx,
id_str,
type_str,
r_name,
destroy_map_str,
view_map_str,
o,
data,
clients,
op_time, op_time,
op_time_percent, op_time_percent,
tot_time, tot_time,
...@@ -519,40 +574,59 @@ def _debugprint( ...@@ -519,40 +574,59 @@ def _debugprint(
file=file, file=file,
) )
if not already_printed: if not already_done and (
if not stop_on_name or not (hasattr(r, "name") and r.name is not None): not stop_on_name or not (hasattr(r, "name") and r.name is not None)
new_prefix = prefix_child + " |" ):
new_prefix_child = prefix_child + " |" new_prefix = prefix_child + " |"
new_prefix_child = prefix_child + " |"
for idx, i in enumerate(a.inputs):
if idx == len(a.inputs) - 1: for idx, i in enumerate(a.inputs):
new_prefix_child = prefix_child + " " if idx == len(a.inputs) - 1:
new_prefix_child = prefix_child + " "
if hasattr(i, "owner") and hasattr(i.owner, "op"):
if isinstance(i.owner.op, HasInnerGraph): if hasattr(i, "owner") and hasattr(i.owner, "op"):
inner_graph_ops.append(i) if isinstance(i.owner.op, HasInnerGraph):
inner_graph_ops.append(i)
_debugprint(
i, _debugprint(
new_prefix, i,
depth=depth - 1, new_prefix,
done=done, depth=depth - 1,
print_type=print_type, done=done,
file=file, print_type=print_type,
order=order, file=file,
ids=ids, order=order,
stop_on_name=stop_on_name, ids=ids,
prefix_child=new_prefix_child, stop_on_name=stop_on_name,
inner_graph_ops=inner_graph_ops, prefix_child=new_prefix_child,
profile=profile, inner_graph_ops=inner_graph_ops,
inner_to_outer_inputs=inner_to_outer_inputs, profile=profile,
smap=smap, inner_to_outer_inputs=inner_to_outer_inputs,
used_ids=used_ids, smap=smap,
) used_ids=used_ids,
op_information=op_information,
parent_node=a,
print_op_info=print_op_info,
)
else: else:
id_str = get_id_str(r)
if id_str:
id_str = f" {id_str}"
if smap and r in smap:
data = f" {smap[r]}"
else:
data = ""
var_output = f"{prefix}{r}{id_str}{type_str}{data}"
if print_op_info and r.owner and r.owner not in op_information:
op_information.update(op_debug_information(r.owner.op, r.owner))
if inner_to_outer_inputs is not None and r in inner_to_outer_inputs: if inner_to_outer_inputs is not None and r in inner_to_outer_inputs:
id_str = get_id_str(r)
outer_r = inner_to_outer_inputs[r] outer_r = inner_to_outer_inputs[r]
if outer_r.owner: if outer_r.owner:
...@@ -560,17 +634,22 @@ def _debugprint( ...@@ -560,17 +634,22 @@ def _debugprint(
else: else:
outer_id_str = get_id_str(outer_r) outer_id_str = get_id_str(outer_r)
print( var_output = f"{var_output} -> {outer_id_str}"
f"{prefix}{r} {id_str}{type_str} -> {outer_id_str}",
file=file, # This is an inner-graph input, so we need to find the outer node
) # it belongs to and get the extra information from that
else: for inner_graph in inner_graph_ops:
# this is an input variable if outer_r in inner_graph.owner.inputs:
data = "" node_info = op_information.get(inner_graph.owner)
if smap: if node_info and r in node_info:
data = " " + str(smap.get(r, "")) var_output = f"{var_output} ({node_info[r]})"
id_str = get_id_str(r) break
print(f"{prefix}{r} {id_str}{type_str}{data}", file=file)
node_info = op_information.get(parent_node) or op_information.get(r.owner)
if node_info and r in node_info:
var_output = f"{var_output} ({node_info[r]})"
print(var_output, file=file)
return file return file
......
...@@ -76,6 +76,7 @@ from aesara.graph.utils import MissingInputError ...@@ -76,6 +76,7 @@ from aesara.graph.utils import MissingInputError
from aesara.link.c.basic import CLinker from aesara.link.c.basic import CLinker
from aesara.link.c.exceptions import MissingGXX from aesara.link.c.exceptions import MissingGXX
from aesara.link.utils import raise_with_op from aesara.link.utils import raise_with_op
from aesara.printing import op_debug_information
from aesara.scan.utils import ScanProfileStats, Validator, forced_replace, safe_new from aesara.scan.utils import ScanProfileStats, Validator, forced_replace, safe_new
from aesara.tensor.basic import as_tensor_variable from aesara.tensor.basic import as_tensor_variable
from aesara.tensor.math import minimum from aesara.tensor.math import minimum
...@@ -3303,3 +3304,44 @@ def profile_printer( ...@@ -3303,3 +3304,44 @@ def profile_printer(
), ),
file=file, file=file,
) )
@op_debug_information.register(Scan) # type: ignore
def _op_debug_information_Scan(op, node):
from typing import Sequence
from aesara.scan.utils import ScanArgs
extra_information = {}
inner_fn = getattr(op, "_fn", None)
if inner_fn:
inner_inputs = inner_fn.maker.fgraph.inputs
inner_outputs = inner_fn.maker.fgraph.outputs
else:
inner_inputs = op.inputs
inner_outputs = op.outputs
scan_args = ScanArgs(
node.inputs,
node.outputs,
inner_inputs,
inner_outputs,
node.op.info,
node.op.as_while,
clone=False,
)
for field_name in scan_args.field_names:
field_vars = getattr(scan_args, field_name)
if isinstance(field_vars, Sequence):
for i, var in enumerate(field_vars):
if isinstance(var, Sequence):
for j, sub_var in enumerate(var):
extra_information[sub_var] = f"{field_name}-{i}-{j}"
else:
extra_information[var] = f"{field_name}-{i}"
else:
extra_information[field_vars] = field_name
return {node: extra_information}
...@@ -544,17 +544,17 @@ def test_debugprint(): ...@@ -544,17 +544,17 @@ def test_debugprint():
output_str = debugprint(out, file="str") output_str = debugprint(out, file="str")
lines = output_str.split("\n") lines = output_str.split("\n")
exp_res = """OpFromGraph{inline=False} [id A] '' exp_res = """OpFromGraph{inline=False} [id A]
|x [id B] |x [id B]
|y [id C] |y [id C]
|z [id D] |z [id D]
Inner graphs: Inner graphs:
OpFromGraph{inline=False} [id A] '' OpFromGraph{inline=False} [id A]
>Elemwise{add,no_inplace} [id E] '' >Elemwise{add,no_inplace} [id E]
> |x [id F] > |x [id F]
> |Elemwise{mul,no_inplace} [id G] '' > |Elemwise{mul,no_inplace} [id G]
> |y [id H] > |y [id H]
> |z [id I] > |z [id I]
""" """
......
...@@ -7,7 +7,7 @@ from aesara.printing import debugprint, pydot_imported, pydotprint ...@@ -7,7 +7,7 @@ from aesara.printing import debugprint, pydot_imported, pydotprint
from aesara.tensor.type import dvector, iscalar, scalar, vector from aesara.tensor.type import dvector, iscalar, scalar, vector
def test_scan_debugprint1(): def test_debugprint_sitsot():
k = iscalar("k") k = iscalar("k")
A = dvector("A") A = dvector("A")
...@@ -20,41 +20,96 @@ def test_scan_debugprint1(): ...@@ -20,41 +20,96 @@ def test_scan_debugprint1():
) )
final_result = result[-1] final_result = result[-1]
output_str = debugprint(final_result, file="str") output_str = debugprint(final_result, file="str", print_op_info=True)
lines = output_str.split("\n") lines = output_str.split("\n")
expected_output = """Subtensor{int64} [id A] '' expected_output = """Subtensor{int64} [id A]
|Subtensor{int64::} [id B] '' |Subtensor{int64::} [id B]
| |for{cpu,scan_fn} [id C] '' | |for{cpu,scan_fn} [id C] (outer_out_sit_sot-0)
| | |k [id D] (n_steps)
| | |IncSubtensor{Set;:int64:} [id E] (outer_in_sit_sot-0)
| | | |AllocEmpty{dtype='float64'} [id F]
| | | | |Elemwise{add,no_inplace} [id G]
| | | | | |k [id D]
| | | | | |Subtensor{int64} [id H]
| | | | | |Shape [id I]
| | | | | | |Rebroadcast{(0, False)} [id J]
| | | | | | |InplaceDimShuffle{x,0} [id K]
| | | | | | |Elemwise{second,no_inplace} [id L]
| | | | | | |A [id M]
| | | | | | |InplaceDimShuffle{x} [id N]
| | | | | | |TensorConstant{1.0} [id O]
| | | | | |ScalarConstant{0} [id P]
| | | | |Subtensor{int64} [id Q]
| | | | |Shape [id R]
| | | | | |Rebroadcast{(0, False)} [id J]
| | | | |ScalarConstant{1} [id S]
| | | |Rebroadcast{(0, False)} [id J]
| | | |ScalarFromTensor [id T]
| | | |Subtensor{int64} [id H]
| | |A [id M] (outer_in_non_seqs-0)
| |ScalarConstant{1} [id U]
|ScalarConstant{-1} [id V]
Inner graphs:
for{cpu,scan_fn} [id C] (outer_out_sit_sot-0)
>Elemwise{mul,no_inplace} [id W] (inner_out_sit_sot-0)
> |<TensorType(float64, (None,))> [id X] -> [id E] (inner_in_sit_sot-0)
> |A_copy [id Y] -> [id M] (inner_in_non_seqs-0)"""
for truth, out in zip(expected_output.split("\n"), lines):
assert truth.strip() == out.strip()
def test_debugprint_sitsot_no_extra_info():
k = iscalar("k")
A = dvector("A")
# Symbolic description of the result
result, updates = aesara.scan(
fn=lambda prior_result, A: prior_result * A,
outputs_info=at.ones_like(A),
non_sequences=A,
n_steps=k,
)
final_result = result[-1]
output_str = debugprint(final_result, file="str", print_op_info=False)
lines = output_str.split("\n")
expected_output = """Subtensor{int64} [id A]
|Subtensor{int64::} [id B]
| |for{cpu,scan_fn} [id C]
| | |k [id D] | | |k [id D]
| | |IncSubtensor{Set;:int64:} [id E] '' | | |IncSubtensor{Set;:int64:} [id E]
| | | |AllocEmpty{dtype='float64'} [id F] '' | | | |AllocEmpty{dtype='float64'} [id F]
| | | | |Elemwise{add,no_inplace} [id G] '' | | | | |Elemwise{add,no_inplace} [id G]
| | | | | |k [id D] | | | | | |k [id D]
| | | | | |Subtensor{int64} [id H] '' | | | | | |Subtensor{int64} [id H]
| | | | | |Shape [id I] '' | | | | | |Shape [id I]
| | | | | | |Rebroadcast{(0, False)} [id J] '' | | | | | | |Rebroadcast{(0, False)} [id J]
| | | | | | |InplaceDimShuffle{x,0} [id K] '' | | | | | | |InplaceDimShuffle{x,0} [id K]
| | | | | | |Elemwise{second,no_inplace} [id L] '' | | | | | | |Elemwise{second,no_inplace} [id L]
| | | | | | |A [id M] | | | | | | |A [id M]
| | | | | | |InplaceDimShuffle{x} [id N] '' | | | | | | |InplaceDimShuffle{x} [id N]
| | | | | | |TensorConstant{1.0} [id O] | | | | | | |TensorConstant{1.0} [id O]
| | | | | |ScalarConstant{0} [id P] | | | | | |ScalarConstant{0} [id P]
| | | | |Subtensor{int64} [id Q] '' | | | | |Subtensor{int64} [id Q]
| | | | |Shape [id R] '' | | | | |Shape [id R]
| | | | | |Rebroadcast{(0, False)} [id J] '' | | | | | |Rebroadcast{(0, False)} [id J]
| | | | |ScalarConstant{1} [id S] | | | | |ScalarConstant{1} [id S]
| | | |Rebroadcast{(0, False)} [id J] '' | | | |Rebroadcast{(0, False)} [id J]
| | | |ScalarFromTensor [id T] '' | | | |ScalarFromTensor [id T]
| | | |Subtensor{int64} [id H] '' | | | |Subtensor{int64} [id H]
| | |A [id M] | | |A [id M]
| |ScalarConstant{1} [id U] | |ScalarConstant{1} [id U]
|ScalarConstant{-1} [id V] |ScalarConstant{-1} [id V]
Inner graphs: Inner graphs:
for{cpu,scan_fn} [id C] '' for{cpu,scan_fn} [id C]
>Elemwise{mul,no_inplace} [id W] '' >Elemwise{mul,no_inplace} [id W]
> |<TensorType(float64, (None,))> [id X] -> [id E] > |<TensorType(float64, (None,))> [id X] -> [id E]
> |A_copy [id Y] -> [id M]""" > |A_copy [id Y] -> [id M]"""
...@@ -62,7 +117,7 @@ def test_scan_debugprint1(): ...@@ -62,7 +117,7 @@ def test_scan_debugprint1():
assert truth.strip() == out.strip() assert truth.strip() == out.strip()
def test_scan_debugprint2(): def test_debugprint_nitsot():
coefficients = vector("coefficients") coefficients = vector("coefficients")
x = scalar("x") x = scalar("x")
...@@ -79,52 +134,52 @@ def test_scan_debugprint2(): ...@@ -79,52 +134,52 @@ def test_scan_debugprint2():
# Sum them up # Sum them up
polynomial = components.sum() polynomial = components.sum()
output_str = debugprint(polynomial, file="str") output_str = debugprint(polynomial, file="str", print_op_info=True)
lines = output_str.split("\n") lines = output_str.split("\n")
expected_output = """Sum{acc_dtype=float64} [id A] '' expected_output = """Sum{acc_dtype=float64} [id A]
|for{cpu,scan_fn} [id B] '' |for{cpu,scan_fn} [id B] (outer_out_nit_sot-0)
|Elemwise{scalar_minimum,no_inplace} [id C] '' |Elemwise{scalar_minimum,no_inplace} [id C] (outer_in_nit_sot-0)
| |Subtensor{int64} [id D] '' | |Subtensor{int64} [id D]
| | |Shape [id E] '' | | |Shape [id E]
| | | |Subtensor{int64::} [id F] 'coefficients[0:]' | | | |Subtensor{int64::} [id F] 'coefficients[0:]'
| | | |coefficients [id G] | | | |coefficients [id G]
| | | |ScalarConstant{0} [id H] | | | |ScalarConstant{0} [id H]
| | |ScalarConstant{0} [id I] | | |ScalarConstant{0} [id I]
| |Subtensor{int64} [id J] '' | |Subtensor{int64} [id J]
| |Shape [id K] '' | |Shape [id K]
| | |Subtensor{int64::} [id L] '' | | |Subtensor{int64::} [id L]
| | |ARange{dtype='int64'} [id M] '' | | |ARange{dtype='int64'} [id M]
| | | |TensorConstant{0} [id N] | | | |TensorConstant{0} [id N]
| | | |TensorConstant{10000} [id O] | | | |TensorConstant{10000} [id O]
| | | |TensorConstant{1} [id P] | | | |TensorConstant{1} [id P]
| | |ScalarConstant{0} [id Q] | | |ScalarConstant{0} [id Q]
| |ScalarConstant{0} [id R] | |ScalarConstant{0} [id R]
|Subtensor{:int64:} [id S] '' |Subtensor{:int64:} [id S] (outer_in_seqs-0)
| |Subtensor{int64::} [id F] 'coefficients[0:]' | |Subtensor{int64::} [id F] 'coefficients[0:]'
| |ScalarFromTensor [id T] '' | |ScalarFromTensor [id T]
| |Elemwise{scalar_minimum,no_inplace} [id C] '' | |Elemwise{scalar_minimum,no_inplace} [id C]
|Subtensor{:int64:} [id U] '' |Subtensor{:int64:} [id U] (outer_in_seqs-1)
| |Subtensor{int64::} [id L] '' | |Subtensor{int64::} [id L]
| |ScalarFromTensor [id V] '' | |ScalarFromTensor [id V]
| |Elemwise{scalar_minimum,no_inplace} [id C] '' | |Elemwise{scalar_minimum,no_inplace} [id C]
|Elemwise{scalar_minimum,no_inplace} [id C] '' |Elemwise{scalar_minimum,no_inplace} [id C] (outer_in_nit_sot-0)
|x [id W] |x [id W] (outer_in_non_seqs-0)
Inner graphs: Inner graphs:
for{cpu,scan_fn} [id B] '' for{cpu,scan_fn} [id B] (outer_out_nit_sot-0)
>Elemwise{mul,no_inplace} [id X] '' >Elemwise{mul,no_inplace} [id X] (inner_out_nit_sot-0)
> |coefficients[t] [id Y] -> [id S] > |coefficients[t] [id Y] -> [id S] (inner_in_seqs-0)
> |Elemwise{pow,no_inplace} [id Z] '' > |Elemwise{pow,no_inplace} [id Z]
> |x_copy [id BA] -> [id W] > |x_copy [id BA] -> [id W] (inner_in_non_seqs-0)
> |<TensorType(int64, ())> [id BB] -> [id U]""" > |<TensorType(int64, ())> [id BB] -> [id U] (inner_in_seqs-1)"""
for truth, out in zip(expected_output.split("\n"), lines): for truth, out in zip(expected_output.split("\n"), lines):
assert truth.strip() == out.strip() assert truth.strip() == out.strip()
def test_scan_debugprint3(): def test_debugprint_nested_scans():
coefficients = dvector("coefficients") coefficients = dvector("coefficients")
max_coefficients_supported = 10 max_coefficients_supported = 10
...@@ -158,86 +213,86 @@ def test_scan_debugprint3(): ...@@ -158,86 +213,86 @@ def test_scan_debugprint3():
final_result = polynomial final_result = polynomial
output_str = debugprint(final_result, file="str") output_str = debugprint(final_result, file="str", print_op_info=True)
lines = output_str.split("\n") lines = output_str.split("\n")
expected_output = """Sum{acc_dtype=float64} [id A] '' expected_output = """Sum{acc_dtype=float64} [id A]
|for{cpu,scan_fn} [id B] '' |for{cpu,scan_fn} [id B] (outer_out_nit_sot-0)
|Elemwise{scalar_minimum,no_inplace} [id C] '' |Elemwise{scalar_minimum,no_inplace} [id C] (outer_in_nit_sot-0)
| |Subtensor{int64} [id D] '' | |Subtensor{int64} [id D]
| | |Shape [id E] '' | | |Shape [id E]
| | | |Subtensor{int64::} [id F] 'coefficients[0:]' | | | |Subtensor{int64::} [id F] 'coefficients[0:]'
| | | |coefficients [id G] | | | |coefficients [id G]
| | | |ScalarConstant{0} [id H] | | | |ScalarConstant{0} [id H]
| | |ScalarConstant{0} [id I] | | |ScalarConstant{0} [id I]
| |Subtensor{int64} [id J] '' | |Subtensor{int64} [id J]
| |Shape [id K] '' | |Shape [id K]
| | |Subtensor{int64::} [id L] '' | | |Subtensor{int64::} [id L]
| | |ARange{dtype='int64'} [id M] '' | | |ARange{dtype='int64'} [id M]
| | | |TensorConstant{0} [id N] | | | |TensorConstant{0} [id N]
| | | |TensorConstant{10} [id O] | | | |TensorConstant{10} [id O]
| | | |TensorConstant{1} [id P] | | | |TensorConstant{1} [id P]
| | |ScalarConstant{0} [id Q] | | |ScalarConstant{0} [id Q]
| |ScalarConstant{0} [id R] | |ScalarConstant{0} [id R]
|Subtensor{:int64:} [id S] '' |Subtensor{:int64:} [id S] (outer_in_seqs-0)
| |Subtensor{int64::} [id F] 'coefficients[0:]' | |Subtensor{int64::} [id F] 'coefficients[0:]'
| |ScalarFromTensor [id T] '' | |ScalarFromTensor [id T]
| |Elemwise{scalar_minimum,no_inplace} [id C] '' | |Elemwise{scalar_minimum,no_inplace} [id C]
|Subtensor{:int64:} [id U] '' |Subtensor{:int64:} [id U] (outer_in_seqs-1)
| |Subtensor{int64::} [id L] '' | |Subtensor{int64::} [id L]
| |ScalarFromTensor [id V] '' | |ScalarFromTensor [id V]
| |Elemwise{scalar_minimum,no_inplace} [id C] '' | |Elemwise{scalar_minimum,no_inplace} [id C]
|Elemwise{scalar_minimum,no_inplace} [id C] '' |Elemwise{scalar_minimum,no_inplace} [id C] (outer_in_nit_sot-0)
|A [id W] |A [id W] (outer_in_non_seqs-0)
|k [id X] |k [id X] (outer_in_non_seqs-1)
Inner graphs: Inner graphs:
for{cpu,scan_fn} [id B] '' for{cpu,scan_fn} [id B] (outer_out_nit_sot-0)
>Elemwise{mul,no_inplace} [id Y] '' >Elemwise{mul,no_inplace} [id Y] (inner_out_nit_sot-0)
> |InplaceDimShuffle{x} [id Z] '' > |InplaceDimShuffle{x} [id Z]
> | |coefficients[t] [id BA] -> [id S] > | |coefficients[t] [id BA] -> [id S] (inner_in_seqs-0)
> |Elemwise{pow,no_inplace} [id BB] '' > |Elemwise{pow,no_inplace} [id BB]
> |Subtensor{int64} [id BC] '' > |Subtensor{int64} [id BC]
> | |Subtensor{int64::} [id BD] '' > | |Subtensor{int64::} [id BD]
> | | |for{cpu,scan_fn} [id BE] '' > | | |for{cpu,scan_fn} [id BE] (outer_out_sit_sot-0)
> | | | |k_copy [id BF] -> [id X] > | | | |k_copy [id BF] -> [id X] (inner_in_non_seqs-1) (n_steps)
> | | | |IncSubtensor{Set;:int64:} [id BG] '' > | | | |IncSubtensor{Set;:int64:} [id BG] (outer_in_sit_sot-0)
> | | | | |AllocEmpty{dtype='float64'} [id BH] '' > | | | | |AllocEmpty{dtype='float64'} [id BH]
> | | | | | |Elemwise{add,no_inplace} [id BI] '' > | | | | | |Elemwise{add,no_inplace} [id BI]
> | | | | | | |k_copy [id BF] -> [id X] > | | | | | | |k_copy [id BF] -> [id X] (inner_in_non_seqs-1)
> | | | | | | |Subtensor{int64} [id BJ] '' > | | | | | | |Subtensor{int64} [id BJ]
> | | | | | | |Shape [id BK] '' > | | | | | | |Shape [id BK]
> | | | | | | | |Rebroadcast{(0, False)} [id BL] '' > | | | | | | | |Rebroadcast{(0, False)} [id BL]
> | | | | | | | |InplaceDimShuffle{x,0} [id BM] '' > | | | | | | | |InplaceDimShuffle{x,0} [id BM]
> | | | | | | | |Elemwise{second,no_inplace} [id BN] '' > | | | | | | | |Elemwise{second,no_inplace} [id BN]
> | | | | | | | |A_copy [id BO] -> [id W] > | | | | | | | |A_copy [id BO] -> [id W] (inner_in_non_seqs-0)
> | | | | | | | |InplaceDimShuffle{x} [id BP] '' > | | | | | | | |InplaceDimShuffle{x} [id BP]
> | | | | | | | |TensorConstant{1.0} [id BQ] > | | | | | | | |TensorConstant{1.0} [id BQ]
> | | | | | | |ScalarConstant{0} [id BR] > | | | | | | |ScalarConstant{0} [id BR]
> | | | | | |Subtensor{int64} [id BS] '' > | | | | | |Subtensor{int64} [id BS]
> | | | | | |Shape [id BT] '' > | | | | | |Shape [id BT]
> | | | | | | |Rebroadcast{(0, False)} [id BL] '' > | | | | | | |Rebroadcast{(0, False)} [id BL]
> | | | | | |ScalarConstant{1} [id BU] > | | | | | |ScalarConstant{1} [id BU]
> | | | | |Rebroadcast{(0, False)} [id BL] '' > | | | | |Rebroadcast{(0, False)} [id BL]
> | | | | |ScalarFromTensor [id BV] '' > | | | | |ScalarFromTensor [id BV]
> | | | | |Subtensor{int64} [id BJ] '' > | | | | |Subtensor{int64} [id BJ]
> | | | |A_copy [id BO] -> [id W] > | | | |A_copy [id BO] -> [id W] (inner_in_non_seqs-0) (outer_in_non_seqs-0)
> | | |ScalarConstant{1} [id BW] > | | |ScalarConstant{1} [id BW]
> | |ScalarConstant{-1} [id BX] > | |ScalarConstant{-1} [id BX]
> |InplaceDimShuffle{x} [id BY] '' > |InplaceDimShuffle{x} [id BY]
> |<TensorType(int64, ())> [id BZ] -> [id U] > |<TensorType(int64, ())> [id BZ] -> [id U] (inner_in_seqs-1)
for{cpu,scan_fn} [id BE] '' for{cpu,scan_fn} [id BE] (outer_out_sit_sot-0)
>Elemwise{mul,no_inplace} [id CA] '' >Elemwise{mul,no_inplace} [id CA] (inner_out_sit_sot-0)
> |<TensorType(float64, (None,))> [id CB] -> [id BG] > |<TensorType(float64, (None,))> [id CB] -> [id BG] (inner_in_sit_sot-0)
> |A_copy [id CC] -> [id BO]""" > |A_copy [id CC] -> [id BO] (inner_in_non_seqs-0)"""
for truth, out in zip(expected_output.split("\n"), lines): for truth, out in zip(expected_output.split("\n"), lines):
assert truth.strip() == out.strip() assert truth.strip() == out.strip()
def test_scan_debugprint4(): def test_debugprint_mitsot():
def fn(a_m2, a_m1, b_m2, b_m1): def fn(a_m2, a_m1, b_m2, b_m1):
return a_m1 + a_m2, b_m1 + b_m2 return a_m1 + a_m2, b_m1 + b_m2
...@@ -254,63 +309,63 @@ def test_scan_debugprint4(): ...@@ -254,63 +309,63 @@ def test_scan_debugprint4():
) )
final_result = a + b final_result = a + b
output_str = debugprint(final_result, file="str") output_str = debugprint(final_result, file="str", print_op_info=True)
lines = output_str.split("\n") lines = output_str.split("\n")
expected_output = """Elemwise{add,no_inplace} [id A] '' expected_output = """Elemwise{add,no_inplace} [id A]
|Subtensor{int64::} [id B] '' |Subtensor{int64::} [id B]
| |for{cpu,scan_fn}.0 [id C] '' | |for{cpu,scan_fn}.0 [id C] (outer_out_mit_sot-0)
| | |TensorConstant{5} [id D] | | |TensorConstant{5} [id D] (n_steps)
| | |IncSubtensor{Set;:int64:} [id E] '' | | |IncSubtensor{Set;:int64:} [id E] (outer_in_mit_sot-0)
| | | |AllocEmpty{dtype='int64'} [id F] '' | | | |AllocEmpty{dtype='int64'} [id F]
| | | | |Elemwise{add,no_inplace} [id G] '' | | | | |Elemwise{add,no_inplace} [id G]
| | | | |TensorConstant{5} [id D] | | | | |TensorConstant{5} [id D]
| | | | |Subtensor{int64} [id H] '' | | | | |Subtensor{int64} [id H]
| | | | |Shape [id I] '' | | | | |Shape [id I]
| | | | | |Subtensor{:int64:} [id J] '' | | | | | |Subtensor{:int64:} [id J]
| | | | | |<TensorType(int64, (None,))> [id K] | | | | | |<TensorType(int64, (None,))> [id K]
| | | | | |ScalarConstant{2} [id L] | | | | | |ScalarConstant{2} [id L]
| | | | |ScalarConstant{0} [id M] | | | | |ScalarConstant{0} [id M]
| | | |Subtensor{:int64:} [id J] '' | | | |Subtensor{:int64:} [id J]
| | | |ScalarFromTensor [id N] '' | | | |ScalarFromTensor [id N]
| | | |Subtensor{int64} [id H] '' | | | |Subtensor{int64} [id H]
| | |IncSubtensor{Set;:int64:} [id O] '' | | |IncSubtensor{Set;:int64:} [id O] (outer_in_mit_sot-1)
| | |AllocEmpty{dtype='int64'} [id P] '' | | |AllocEmpty{dtype='int64'} [id P]
| | | |Elemwise{add,no_inplace} [id Q] '' | | | |Elemwise{add,no_inplace} [id Q]
| | | |TensorConstant{5} [id D] | | | |TensorConstant{5} [id D]
| | | |Subtensor{int64} [id R] '' | | | |Subtensor{int64} [id R]
| | | |Shape [id S] '' | | | |Shape [id S]
| | | | |Subtensor{:int64:} [id T] '' | | | | |Subtensor{:int64:} [id T]
| | | | |<TensorType(int64, (None,))> [id U] | | | | |<TensorType(int64, (None,))> [id U]
| | | | |ScalarConstant{2} [id V] | | | | |ScalarConstant{2} [id V]
| | | |ScalarConstant{0} [id W] | | | |ScalarConstant{0} [id W]
| | |Subtensor{:int64:} [id T] '' | | |Subtensor{:int64:} [id T]
| | |ScalarFromTensor [id X] '' | | |ScalarFromTensor [id X]
| | |Subtensor{int64} [id R] '' | | |Subtensor{int64} [id R]
| |ScalarConstant{2} [id Y] | |ScalarConstant{2} [id Y]
|Subtensor{int64::} [id Z] '' |Subtensor{int64::} [id Z]
|for{cpu,scan_fn}.1 [id C] '' |for{cpu,scan_fn}.1 [id C] (outer_out_mit_sot-1)
|ScalarConstant{2} [id BA] |ScalarConstant{2} [id BA]
Inner graphs: Inner graphs:
for{cpu,scan_fn}.0 [id C] '' for{cpu,scan_fn}.0 [id C] (outer_out_mit_sot-0)
>Elemwise{add,no_inplace} [id BB] '' >Elemwise{add,no_inplace} [id BB] (inner_out_mit_sot-0)
> |<TensorType(int64, ())> [id BC] -> [id E] > |<TensorType(int64, ())> [id BC] -> [id E] (inner_in_mit_sot-0-1)
> |<TensorType(int64, ())> [id BD] -> [id E] > |<TensorType(int64, ())> [id BD] -> [id E] (inner_in_mit_sot-0-0)
>Elemwise{add,no_inplace} [id BE] '' >Elemwise{add,no_inplace} [id BE] (inner_out_mit_sot-1)
> |<TensorType(int64, ())> [id BF] -> [id O] > |<TensorType(int64, ())> [id BF] -> [id O] (inner_in_mit_sot-1-1)
> |<TensorType(int64, ())> [id BG] -> [id O] > |<TensorType(int64, ())> [id BG] -> [id O] (inner_in_mit_sot-1-0)
for{cpu,scan_fn}.1 [id C] '' for{cpu,scan_fn}.1 [id C] (outer_out_mit_sot-1)
>Elemwise{add,no_inplace} [id BB] '' >Elemwise{add,no_inplace} [id BB] (inner_out_mit_sot-0)
>Elemwise{add,no_inplace} [id BE] ''""" >Elemwise{add,no_inplace} [id BE] (inner_out_mit_sot-1)"""
for truth, out in zip(expected_output.split("\n"), lines): for truth, out in zip(expected_output.split("\n"), lines):
assert truth.strip() == out.strip() assert truth.strip() == out.strip()
def test_scan_debugprint5(): def test_debugprint_mitmot():
k = iscalar("k") k = iscalar("k")
A = dvector("A") A = dvector("A")
...@@ -325,129 +380,184 @@ def test_scan_debugprint5(): ...@@ -325,129 +380,184 @@ def test_scan_debugprint5():
final_result = aesara.grad(result[-1].sum(), A) final_result = aesara.grad(result[-1].sum(), A)
output_str = debugprint(final_result, file="str") output_str = debugprint(final_result, file="str", print_op_info=True)
lines = output_str.split("\n") lines = output_str.split("\n")
expected_output = """Subtensor{int64} [id A] '' expected_output = """Subtensor{int64} [id A]
|for{cpu,grad_of_scan_fn}.1 [id B] '' |for{cpu,grad_of_scan_fn}.1 [id B] (outer_out_sit_sot-0)
| |Elemwise{sub,no_inplace} [id C] '' | |Elemwise{sub,no_inplace} [id C] (n_steps)
| | |Subtensor{int64} [id D] '' | | |Subtensor{int64} [id D]
| | | |Shape [id E] '' | | | |Shape [id E]
| | | | |for{cpu,scan_fn} [id F] '' | | | | |for{cpu,scan_fn} [id F] (outer_out_sit_sot-0)
| | | | |k [id G] | | | | |k [id G] (n_steps)
| | | | |IncSubtensor{Set;:int64:} [id H] '' | | | | |IncSubtensor{Set;:int64:} [id H] (outer_in_sit_sot-0)
| | | | | |AllocEmpty{dtype='float64'} [id I] '' | | | | | |AllocEmpty{dtype='float64'} [id I]
| | | | | | |Elemwise{add,no_inplace} [id J] '' | | | | | | |Elemwise{add,no_inplace} [id J]
| | | | | | | |k [id G] | | | | | | | |k [id G]
| | | | | | | |Subtensor{int64} [id K] '' | | | | | | | |Subtensor{int64} [id K]
| | | | | | | |Shape [id L] '' | | | | | | | |Shape [id L]
| | | | | | | | |Rebroadcast{(0, False)} [id M] '' | | | | | | | | |Rebroadcast{(0, False)} [id M]
| | | | | | | | |InplaceDimShuffle{x,0} [id N] '' | | | | | | | | |InplaceDimShuffle{x,0} [id N]
| | | | | | | | |Elemwise{second,no_inplace} [id O] '' | | | | | | | | |Elemwise{second,no_inplace} [id O]
| | | | | | | | |A [id P] | | | | | | | | |A [id P]
| | | | | | | | |InplaceDimShuffle{x} [id Q] '' | | | | | | | | |InplaceDimShuffle{x} [id Q]
| | | | | | | | |TensorConstant{1.0} [id R] | | | | | | | | |TensorConstant{1.0} [id R]
| | | | | | | |ScalarConstant{0} [id S] | | | | | | | |ScalarConstant{0} [id S]
| | | | | | |Subtensor{int64} [id T] '' | | | | | | |Subtensor{int64} [id T]
| | | | | | |Shape [id U] '' | | | | | | |Shape [id U]
| | | | | | | |Rebroadcast{(0, False)} [id M] '' | | | | | | | |Rebroadcast{(0, False)} [id M]
| | | | | | |ScalarConstant{1} [id V] | | | | | | |ScalarConstant{1} [id V]
| | | | | |Rebroadcast{(0, False)} [id M] '' | | | | | |Rebroadcast{(0, False)} [id M]
| | | | | |ScalarFromTensor [id W] '' | | | | | |ScalarFromTensor [id W]
| | | | | |Subtensor{int64} [id K] '' | | | | | |Subtensor{int64} [id K]
| | | | |A [id P] | | | | |A [id P] (outer_in_non_seqs-0)
| | | |ScalarConstant{0} [id X] | | | |ScalarConstant{0} [id X]
| | |TensorConstant{1} [id Y] | | |TensorConstant{1} [id Y]
| |Subtensor{:int64:} [id Z] '' | |Subtensor{:int64:} [id Z] (outer_in_seqs-0)
| | |Subtensor{::int64} [id BA] '' | | |Subtensor{::int64} [id BA]
| | | |Subtensor{:int64:} [id BB] '' | | | |Subtensor{:int64:} [id BB]
| | | | |for{cpu,scan_fn} [id F] '' | | | | |for{cpu,scan_fn} [id F] (outer_out_sit_sot-0)
| | | | |ScalarConstant{-1} [id BC] | | | | |ScalarConstant{-1} [id BC]
| | | |ScalarConstant{-1} [id BD] | | | |ScalarConstant{-1} [id BD]
| | |ScalarFromTensor [id BE] '' | | |ScalarFromTensor [id BE]
| | |Elemwise{sub,no_inplace} [id C] '' | | |Elemwise{sub,no_inplace} [id C]
| |Subtensor{:int64:} [id BF] '' | |Subtensor{:int64:} [id BF] (outer_in_seqs-1)
| | |Subtensor{:int64:} [id BG] '' | | |Subtensor{:int64:} [id BG]
| | | |Subtensor{::int64} [id BH] '' | | | |Subtensor{::int64} [id BH]
| | | | |for{cpu,scan_fn} [id F] '' | | | | |for{cpu,scan_fn} [id F] (outer_out_sit_sot-0)
| | | | |ScalarConstant{-1} [id BI] | | | | |ScalarConstant{-1} [id BI]
| | | |ScalarConstant{-1} [id BJ] | | | |ScalarConstant{-1} [id BJ]
| | |ScalarFromTensor [id BK] '' | | |ScalarFromTensor [id BK]
| | |Elemwise{sub,no_inplace} [id C] '' | | |Elemwise{sub,no_inplace} [id C]
| |Subtensor{::int64} [id BL] '' | |Subtensor{::int64} [id BL] (outer_in_mit_mot-0)
| | |IncSubtensor{Inc;int64::} [id BM] '' | | |IncSubtensor{Inc;int64::} [id BM]
| | | |Elemwise{second,no_inplace} [id BN] '' | | | |Elemwise{second,no_inplace} [id BN]
| | | | |for{cpu,scan_fn} [id F] '' | | | | |for{cpu,scan_fn} [id F] (outer_out_sit_sot-0)
| | | | |InplaceDimShuffle{x,x} [id BO] '' | | | | |InplaceDimShuffle{x,x} [id BO]
| | | | |TensorConstant{0.0} [id BP] | | | | |TensorConstant{0.0} [id BP]
| | | |IncSubtensor{Inc;int64} [id BQ] '' | | | |IncSubtensor{Inc;int64} [id BQ]
| | | | |Elemwise{second,no_inplace} [id BR] '' | | | | |Elemwise{second,no_inplace} [id BR]
| | | | | |Subtensor{int64::} [id BS] '' | | | | | |Subtensor{int64::} [id BS]
| | | | | | |for{cpu,scan_fn} [id F] '' | | | | | | |for{cpu,scan_fn} [id F] (outer_out_sit_sot-0)
| | | | | | |ScalarConstant{1} [id BT] | | | | | | |ScalarConstant{1} [id BT]
| | | | | |InplaceDimShuffle{x,x} [id BU] '' | | | | | |InplaceDimShuffle{x,x} [id BU]
| | | | | |TensorConstant{0.0} [id BV] | | | | | |TensorConstant{0.0} [id BV]
| | | | |Elemwise{second} [id BW] '' | | | | |Elemwise{second} [id BW]
| | | | | |Subtensor{int64} [id BX] '' | | | | | |Subtensor{int64} [id BX]
| | | | | | |Subtensor{int64::} [id BS] '' | | | | | | |Subtensor{int64::} [id BS]
| | | | | | |ScalarConstant{-1} [id BY] | | | | | | |ScalarConstant{-1} [id BY]
| | | | | |InplaceDimShuffle{x} [id BZ] '' | | | | | |InplaceDimShuffle{x} [id BZ]
| | | | | |Elemwise{second,no_inplace} [id CA] '' | | | | | |Elemwise{second,no_inplace} [id CA]
| | | | | |Sum{acc_dtype=float64} [id CB] '' | | | | | |Sum{acc_dtype=float64} [id CB]
| | | | | | |Subtensor{int64} [id BX] '' | | | | | | |Subtensor{int64} [id BX]
| | | | | |TensorConstant{1.0} [id CC] | | | | | |TensorConstant{1.0} [id CC]
| | | | |ScalarConstant{-1} [id BY] | | | | |ScalarConstant{-1} [id BY]
| | | |ScalarConstant{1} [id BT] | | | |ScalarConstant{1} [id BT]
| | |ScalarConstant{-1} [id CD] | | |ScalarConstant{-1} [id CD]
| |Alloc [id CE] '' | |Alloc [id CE] (outer_in_sit_sot-0)
| | |TensorConstant{0.0} [id CF] | | |TensorConstant{0.0} [id CF]
| | |Elemwise{add,no_inplace} [id CG] '' | | |Elemwise{add,no_inplace} [id CG]
| | | |Elemwise{sub,no_inplace} [id C] '' | | | |Elemwise{sub,no_inplace} [id C]
| | | |TensorConstant{1} [id CH] | | | |TensorConstant{1} [id CH]
| | |Subtensor{int64} [id CI] '' | | |Subtensor{int64} [id CI]
| | |Shape [id CJ] '' | | |Shape [id CJ]
| | | |A [id P] | | | |A [id P]
| | |ScalarConstant{0} [id CK] | | |ScalarConstant{0} [id CK]
| |A [id P] | |A [id P] (outer_in_non_seqs-0)
|ScalarConstant{-1} [id CL] |ScalarConstant{-1} [id CL]
Inner graphs: Inner graphs:
for{cpu,grad_of_scan_fn}.1 [id B] '' for{cpu,grad_of_scan_fn}.1 [id B] (outer_out_sit_sot-0)
>Elemwise{add,no_inplace} [id CM] '' >Elemwise{add,no_inplace} [id CM] (inner_out_mit_mot-0-0)
> |Elemwise{mul} [id CN] '' > |Elemwise{mul} [id CN]
> | |<TensorType(float64, (None,))> [id CO] -> [id BL] > | |<TensorType(float64, (None,))> [id CO] -> [id BL] (inner_in_mit_mot-0-0)
> | |A_copy [id CP] -> [id P] > | |A_copy [id CP] -> [id P] (inner_in_non_seqs-0)
> |<TensorType(float64, (None,))> [id CQ] -> [id BL] > |<TensorType(float64, (None,))> [id CQ] -> [id BL] (inner_in_mit_mot-0-1)
>Elemwise{add,no_inplace} [id CR] '' >Elemwise{add,no_inplace} [id CR] (inner_out_sit_sot-0)
> |Elemwise{mul} [id CS] '' > |Elemwise{mul} [id CS]
> | |<TensorType(float64, (None,))> [id CO] -> [id BL] > | |<TensorType(float64, (None,))> [id CO] -> [id BL] (inner_in_mit_mot-0-0)
> | |<TensorType(float64, (None,))> [id CT] -> [id Z] > | |<TensorType(float64, (None,))> [id CT] -> [id Z] (inner_in_seqs-0)
> |<TensorType(float64, (None,))> [id CU] -> [id CE] > |<TensorType(float64, (None,))> [id CU] -> [id CE] (inner_in_sit_sot-0)
for{cpu,scan_fn} [id F] (outer_out_sit_sot-0)
>Elemwise{mul,no_inplace} [id CV] (inner_out_sit_sot-0)
> |<TensorType(float64, (None,))> [id CT] -> [id H] (inner_in_sit_sot-0)
> |A_copy [id CP] -> [id P] (inner_in_non_seqs-0)
for{cpu,scan_fn} [id F] (outer_out_sit_sot-0)
>Elemwise{mul,no_inplace} [id CV] (inner_out_sit_sot-0)
for{cpu,scan_fn} [id F] '' for{cpu,scan_fn} [id F] (outer_out_sit_sot-0)
>Elemwise{mul,no_inplace} [id CV] '' >Elemwise{mul,no_inplace} [id CV] (inner_out_sit_sot-0)
> |<TensorType(float64, (None,))> [id CT] -> [id H]
> |A_copy [id CP] -> [id P]
for{cpu,scan_fn} [id F] '' for{cpu,scan_fn} [id F] (outer_out_sit_sot-0)
>Elemwise{mul,no_inplace} [id CV] '' >Elemwise{mul,no_inplace} [id CV] (inner_out_sit_sot-0)
for{cpu,scan_fn} [id F] '' for{cpu,scan_fn} [id F] (outer_out_sit_sot-0)
>Elemwise{mul,no_inplace} [id CV] '' >Elemwise{mul,no_inplace} [id CV] (inner_out_sit_sot-0)"""
for{cpu,scan_fn} [id F] '' for truth, out in zip(expected_output.split("\n"), lines):
>Elemwise{mul,no_inplace} [id CV] '' assert truth.strip() == out.strip()
def test_debugprint_compiled_fn():
for{cpu,scan_fn} [id F] '' M = at.tensor(np.float64, shape=(20000, 2, 2))
>Elemwise{mul,no_inplace} [id CV] ''""" one = at.as_tensor(1, dtype=np.int64)
zero = at.as_tensor(0, dtype=np.int64)
def no_shared_fn(n, x_tm1, M):
p = M[n, x_tm1]
return at.switch(at.lt(zero, p[0]), one, zero)
out, updates = aesara.scan(
no_shared_fn,
outputs_info=[{"initial": zero, "taps": [-1]}],
sequences=[at.arange(M.shape[0])],
non_sequences=[M],
allow_gc=False,
mode="FAST_RUN",
)
# In this case, `debugprint` should print the compiled inner-graph
# (i.e. from `Scan._fn`)
out = aesara.function([M], out, updates=updates, mode="FAST_RUN")
expected_output = """forall_inplace,cpu,scan_fn} [id A] 2 (outer_out_sit_sot-0)
|TensorConstant{20000} [id B] (n_steps)
|TensorConstant{[ 0 ..998 19999]} [id C] (outer_in_seqs-0)
|IncSubtensor{InplaceSet;:int64:} [id D] 1 (outer_in_sit_sot-0)
| |AllocEmpty{dtype='int64'} [id E] 0
| | |TensorConstant{20000} [id B]
| |TensorConstant{(1,) of 0} [id F]
| |ScalarConstant{1} [id G]
|<TensorType(float64, (20000, 2, 2))> [id H] (outer_in_non_seqs-0)
Inner graphs:
forall_inplace,cpu,scan_fn} [id A] (outer_out_sit_sot-0)
>Elemwise{Composite{Switch(LT(i0, i1), i2, i0)}} [id I] (inner_out_sit_sot-0)
> |TensorConstant{0} [id J]
> |Subtensor{int64, int64, int64} [id K]
> | |<TensorType(float64, (20000, 2, 2))> [id L] -> [id H] (inner_in_non_seqs-0)
> | |ScalarFromTensor [id M]
> | | |<TensorType(int64, ())> [id N] -> [id C] (inner_in_seqs-0)
> | |ScalarFromTensor [id O]
> | | |<TensorType(int64, ())> [id P] -> [id D] (inner_in_sit_sot-0)
> | |ScalarConstant{0} [id Q]
> |TensorConstant{1} [id R]
"""
output_str = debugprint(out, file="str", print_op_info=True)
lines = output_str.split("\n")
for truth, out in zip(expected_output.split("\n"), lines): for truth, out in zip(expected_output.split("\n"), lines):
assert truth.strip() == out.strip() assert truth.strip() == out.strip()
@pytest.mark.skipif(not pydot_imported, reason="pydot not available") @pytest.mark.skipif(not pydot_imported, reason="pydot not available")
def test_printing_scan(): def test_pydotprint():
def f_pow2(x_tm1): def f_pow2(x_tm1):
return 2 * x_tm1 return 2 * x_tm1
......
...@@ -141,11 +141,11 @@ def test_debugprint(): ...@@ -141,11 +141,11 @@ def test_debugprint():
reference = ( reference = (
"\n".join( "\n".join(
[ [
"Elemwise{add,no_inplace} [id 0] '' ", "Elemwise{add,no_inplace} [id 0]",
" |Elemwise{add,no_inplace} [id 1] 'C' ", " |Elemwise{add,no_inplace} [id 1] 'C'",
" | |A [id 2]", " | |A [id 2]",
" | |B [id 3]", " | |B [id 3]",
" |Elemwise{add,no_inplace} [id 4] '' ", " |Elemwise{add,no_inplace} [id 4]",
" |D [id 5]", " |D [id 5]",
" |E [id 6]", " |E [id 6]",
] ]
...@@ -167,11 +167,11 @@ def test_debugprint(): ...@@ -167,11 +167,11 @@ def test_debugprint():
reference = ( reference = (
"\n".join( "\n".join(
[ [
"Elemwise{add,no_inplace} [id A] '' ", "Elemwise{add,no_inplace} [id A]",
" |Elemwise{add,no_inplace} [id B] 'C' ", " |Elemwise{add,no_inplace} [id B] 'C'",
" | |A [id C]", " | |A [id C]",
" | |B [id D]", " | |B [id D]",
" |Elemwise{add,no_inplace} [id E] '' ", " |Elemwise{add,no_inplace} [id E]",
" |D [id F]", " |D [id F]",
" |E [id G]", " |E [id G]",
] ]
...@@ -193,9 +193,9 @@ def test_debugprint(): ...@@ -193,9 +193,9 @@ def test_debugprint():
reference = ( reference = (
"\n".join( "\n".join(
[ [
"Elemwise{add,no_inplace} [id A] '' ", "Elemwise{add,no_inplace} [id A]",
" |Elemwise{add,no_inplace} [id B] 'C' ", " |Elemwise{add,no_inplace} [id B] 'C'",
" |Elemwise{add,no_inplace} [id C] '' ", " |Elemwise{add,no_inplace} [id C]",
" |D [id D]", " |D [id D]",
" |E [id E]", " |E [id E]",
] ]
...@@ -217,13 +217,13 @@ def test_debugprint(): ...@@ -217,13 +217,13 @@ def test_debugprint():
reference = ( reference = (
"\n".join( "\n".join(
[ [
"Elemwise{add,no_inplace} '' ", "Elemwise{add,no_inplace}",
" |Elemwise{add,no_inplace} 'C' ", " |Elemwise{add,no_inplace} 'C'",
" | |A ", " | |A",
" | |B ", " | |B",
" |Elemwise{add,no_inplace} '' ", " |Elemwise{add,no_inplace}",
" |D ", " |D",
" |E ", " |E",
] ]
) )
+ "\n" + "\n"
...@@ -238,15 +238,14 @@ def test_debugprint(): ...@@ -238,15 +238,14 @@ def test_debugprint():
s = StringIO() s = StringIO()
debugprint(g, file=s, ids="", print_storage=True) debugprint(g, file=s, ids="", print_storage=True)
s = s.getvalue() s = s.getvalue()
# The additional white space are needed!
reference = ( reference = (
"\n".join( "\n".join(
[ [
"Elemwise{add,no_inplace} '' 0 [None]", "Elemwise{add,no_inplace} 0 [None]",
" |A [None]", " |A [None]",
" |B [None]", " |B [None]",
" |D [None]", " |D [None]",
" |E [None]", " |E [None]",
] ]
) )
+ "\n" + "\n"
...@@ -269,8 +268,8 @@ def test_debugprint_ids(): ...@@ -269,8 +268,8 @@ def test_debugprint_ids():
debugprint(e_at, ids="auto", file=s) debugprint(e_at, ids="auto", file=s)
s = s.getvalue() s = s.getvalue()
exp_res = f"""Elemwise{{add,no_inplace}} [id {e_at.auto_name}] '' exp_res = f"""Elemwise{{add,no_inplace}} [id {e_at.auto_name}]
|dot [id {d_at.auto_name}] '' |dot [id {d_at.auto_name}]
| |<TensorType(float64, (None, None))> [id {b_at.auto_name}] | |<TensorType(float64, (None, None))> [id {b_at.auto_name}]
| |<TensorType(float64, (None,))> [id {a_at.auto_name}] | |<TensorType(float64, (None,))> [id {a_at.auto_name}]
|<TensorType(float64, (None,))> [id {a_at.auto_name}] |<TensorType(float64, (None,))> [id {a_at.auto_name}]
...@@ -306,13 +305,13 @@ def test_debugprint_inner_graph(): ...@@ -306,13 +305,13 @@ def test_debugprint_inner_graph():
output_str = debugprint(out, file="str") output_str = debugprint(out, file="str")
lines = output_str.split("\n") lines = output_str.split("\n")
exp_res = """MyInnerGraphOp [id A] '' exp_res = """MyInnerGraphOp [id A]
|3 [id B] |3 [id B]
|4 [id C] |4 [id C]
Inner graphs: Inner graphs:
MyInnerGraphOp [id A] '' MyInnerGraphOp [id A]
>op2 [id D] 'igo1' >op2 [id D] 'igo1'
> |4 [id E] > |4 [id E]
> |5 [id F] > |5 [id F]
...@@ -330,17 +329,17 @@ MyInnerGraphOp [id A] '' ...@@ -330,17 +329,17 @@ MyInnerGraphOp [id A] ''
output_str = debugprint(out_2, file="str") output_str = debugprint(out_2, file="str")
lines = output_str.split("\n") lines = output_str.split("\n")
exp_res = """MyInnerGraphOp [id A] '' exp_res = """MyInnerGraphOp [id A]
|5 [id B] |5 [id B]
Inner graphs: Inner graphs:
MyInnerGraphOp [id A] '' MyInnerGraphOp [id A]
>MyInnerGraphOp [id C] '' >MyInnerGraphOp [id C]
> |3 [id D] > |3 [id D]
> |4 [id E] > |4 [id E]
MyInnerGraphOp [id C] '' MyInnerGraphOp [id C]
>op2 [id F] 'igo1' >op2 [id F] 'igo1'
> |4 [id G] > |4 [id G]
> |5 [id H] > |5 [id H]
...@@ -371,13 +370,13 @@ def test_get_var_by_id(): ...@@ -371,13 +370,13 @@ def test_get_var_by_id():
# op1 [id A] 'o1' # op1 [id A] 'o1'
# |1 [id B] # |1 [id B]
# |2 [id C] # |2 [id C]
# MyInnerGraphOp [id D] '' # MyInnerGraphOp [id D]
# |3 [id E] # |3 [id E]
# |op1 [id A] 'o1' # |op1 [id A] 'o1'
# #
# Inner graphs: # Inner graphs:
# #
# MyInnerGraphOp [id D] '' # MyInnerGraphOp [id D]
# >op2 [id F] 'igo1' # >op2 [id F] 'igo1'
# > |4 [id G] # > |4 [id G]
# > |5 [id H] # > |5 [id H]
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论