提交 b979dd6e authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Add FunctionGraphs to the HasInnerGraph interface and OpFromGraph, Scan

上级 980ecacf
......@@ -13,10 +13,12 @@ from aesara.gradient import DisconnectedType, Rop, grad
from aesara.graph.basic import (
Apply,
Constant,
NominalVariable,
Variable,
clone_replace,
graph_inputs,
io_connection_pattern,
replace_nominals_with_dummies,
)
from aesara.graph.fg import FunctionGraph
from aesara.graph.null_type import NullType
......@@ -349,17 +351,32 @@ class OpFromGraph(Op, HasInnerGraph):
raise NotImplementedError("Updates and givens are not allowed here")
self.is_inline = inline
# To correctly support shared variables the inner fct should
# not see them. Otherwise there is a problem with the gradient.
self.shared_inputs = [
var for var in graph_inputs(outputs) if isinstance(var, SharedVariable)
self.shared_inputs = []
for var in graph_inputs(outputs):
if isinstance(var, SharedVariable):
self.shared_inputs.append(var)
inputs, outputs = replace_nominals_with_dummies(inputs, outputs)
# The inputs should be `NominalVariable`s, so that graphs can be merged
replacements = {}
for n, v in enumerate(inputs):
replacements[v] = NominalVariable(n, v.type)
shared_vars = [
NominalVariable(n, var.type)
for n, var in enumerate(self.shared_inputs, start=len(inputs) + 1)
]
shared_vars = [var.type() for var in self.shared_inputs]
replacements.update(dict(zip(self.shared_inputs, shared_vars)))
new = rebuild_collect_shared(
cast(Sequence[Variable], outputs),
inputs=inputs + shared_vars,
replace=dict(zip(self.shared_inputs, shared_vars)),
replace=replacements,
copy_inputs_over=False,
)
(
......@@ -374,10 +391,7 @@ class OpFromGraph(Op, HasInnerGraph):
assert not update_expr
assert not shared_inputs
self._inner_inputs = local_inputs
self._inner_outputs = local_outputs
self.inputs = inputs
self.outputs = outputs
self.fgraph = FunctionGraph(local_inputs, local_outputs, clone=False)
self.kwargs = kwargs
self.input_types = [inp.type for inp in inputs]
self.output_types = [out.type for out in outputs]
......@@ -778,29 +792,23 @@ class OpFromGraph(Op, HasInnerGraph):
# The shared variables are not equal to the original shared
# variables, so we construct a new `Op` that uses the new shared
# variables instead.
# All this is really doing is making the unused (internally, at
# least) `self.outputs` and `self.shared_inputs` consistent.
# We could just as easily `copy` this `Op`, update
# `self.shared_inputs`, and avoid cloning anything, but this is a
# more "change-proof" approach, because it still works when/if those
# attributes end up being used.
replace = dict(inner_and_input_shareds)
replace = dict(
zip(self.inner_inputs[num_expected_inps:], new_shared_inputs)
)
# If the new shared variables are inconsistent with the inner-graph,
# such errors should arise in this step
new_inner_outputs = clone_replace(
self.outputs, replace=replace, share_inputs=True
self.inner_outputs, replace=replace, share_inputs=True
)
# `self.inputs` should not contain any shared variables, so we know
# that those are inputs to `new_outputs`, because we chose not to
# clone inputs; however, it's possible that the new shared variable
# inputs aren't actually shared variables. When they aren't we
# need to add them as new inputs.
# It's possible that the new shared variable inputs aren't actually
# shared variables. When they aren't we need to add them as new
# inputs.
unshared_inputs = [
inp for inp in new_shared_inputs if not isinstance(inp, SharedVariable)
]
new_inner_inputs = self.inputs + unshared_inputs
new_inner_inputs = self.inner_inputs[:num_expected_inps] + unshared_inputs
new_op = type(self)(
inputs=new_inner_inputs,
......@@ -901,11 +909,11 @@ class OpFromGraph(Op, HasInnerGraph):
@property
def inner_inputs(self):
return self._inner_inputs
return self.fgraph.inputs
@property
def inner_outputs(self):
return self._inner_outputs
return self.fgraph.outputs
def perform(self, node, inputs, outputs):
variables = self.fn(*inputs)
......
......@@ -13,7 +13,7 @@ import aesara
from aesara.compile.ops import ViewOp
from aesara.configdefaults import config
from aesara.graph import utils
from aesara.graph.basic import Variable
from aesara.graph.basic import NominalVariable, Variable
from aesara.graph.null_type import NullType, null_type
from aesara.graph.op import get_test_values
from aesara.graph.type import Type
......@@ -1295,15 +1295,16 @@ def _populate_grad_dict(var_to_app_to_idx, grad_dict, wrt, cost_name=None):
# has the right shape
if hasattr(term, "shape"):
orig_ipt = inputs[i]
for orig_ipt_v, term_v in get_test_values(orig_ipt, term):
i_shape = orig_ipt_v.shape
t_shape = term_v.shape
if i_shape != t_shape:
raise ValueError(
f"{node.op}.grad returned object of "
f"shape {t_shape} as gradient term on input {int(i)} "
f"of shape {i_shape}"
)
if not isinstance(orig_ipt, NominalVariable):
for orig_ipt_v, term_v in get_test_values(orig_ipt, term):
i_shape = orig_ipt_v.shape
t_shape = term_v.shape
if i_shape != t_shape:
raise ValueError(
f"{node.op}.grad returned object of "
f"shape {t_shape} as gradient term on input {int(i)} "
f"of shape {i_shape}"
)
if not isinstance(term.type, (NullType, DisconnectedType)):
if term.type.dtype not in aesara.tensor.type.float_dtypes:
......
......@@ -1755,3 +1755,38 @@ def get_var_by_name(
results += (var,)
return results
def replace_nominals_with_dummies(inputs, outputs):
    """Replace nominal inputs with dummy variables.

    When constructing a new graph with nominal inputs from an existing graph,
    pre-existing nominal inputs need to be replaced with dummy variables
    beforehand; otherwise, sequential ID ordering (i.e. when nominals are IDed
    based on the ordered inputs to which they correspond) of the nominals could
    be broken, and/or circular replacements could manifest.

    FYI: This function assumes that all the nominal variables in the subgraphs
    between `inputs` and `outputs` are present in `inputs`.
    """
    # Seed the clone memo: each pre-existing nominal input maps to a fresh
    # dummy variable of the same type.
    dummy_map = {}
    for inp in inputs:
        if isinstance(inp, NominalVariable):
            dummy_map[inp] = inp.type()

    if not dummy_map:
        # No nominal inputs present; the graph can be used as-is.
        return inputs, outputs

    # Clone the graph between `inputs` and `outputs` with the nominals
    # swapped out.  `clone_get_equiv` populates `dummy_map` in-place with
    # the old-to-new mapping for every variable it visits, so afterwards it
    # can be used to look up the replacement for each input and output.
    clone_get_equiv(
        inputs,
        outputs,
        copy_inputs=False,
        copy_orphans=False,
        memo=dummy_map,
    )

    new_inputs = [dummy_map[i] for i in inputs]
    new_outputs = [dummy_map[o] for o in outputs]
    return new_inputs, new_outputs
......@@ -615,6 +615,9 @@ class _NoPythonOp(Op):
class HasInnerGraph:
r"""A mixin for an `Op` that contain an inner graph."""
fgraph: "FunctionGraph"
"""A `FunctionGraph` of the inner function."""
@property
@abstractmethod
def fn(self) -> "Function":
......
......@@ -375,6 +375,7 @@ N.B.:
print_op_info=print_op_info,
print_destroy_map=print_destroy_map,
print_view_map=print_view_map,
inner_graph_node=s.owner,
)
if file is _file:
......@@ -407,6 +408,7 @@ def _debugprint(
op_information: Optional[Dict[Apply, Dict[Variable, str]]] = None,
parent_node: Optional[Apply] = None,
print_op_info: bool = False,
inner_graph_node: Optional[Apply] = None,
) -> IOBase:
r"""Print the graph leading to `r`.
......@@ -459,6 +461,8 @@ def _debugprint(
print_op_info
Print extra information provided by the relevant `Op`\s. For example,
print the tap information for `Scan` inputs and outputs.
inner_graph_node
The inner-graph node in which `r` is contained.
"""
if depth == 0:
return file
......@@ -615,6 +619,7 @@ def _debugprint(
print_op_info=print_op_info,
print_destroy_map=print_destroy_map,
print_view_map=print_view_map,
inner_graph_node=inner_graph_node,
)
else:
......@@ -644,14 +649,9 @@ def _debugprint(
var_output = f"{var_output} -> {outer_id_str}"
# This is an inner-graph input, so we need to find the outer node
# it belongs to and get the extra information from that
for inner_graph in inner_graph_ops:
if outer_r in inner_graph.owner.inputs:
node_info = op_information.get(inner_graph.owner)
if node_info and r in node_info:
var_output = f"{var_output} ({node_info[r]})"
break
node_info = op_information.get(inner_graph_node)
if node_info and r in node_info:
var_output = f"{var_output} ({node_info[r]})"
node_info = op_information.get(parent_node) or op_information.get(r.owner)
if node_info and r in node_info:
......
......@@ -54,6 +54,7 @@ import numpy as np
import aesara
from aesara import tensor as at
from aesara.compile import SharedVariable
from aesara.compile.builders import infer_shape
from aesara.compile.function import function
from aesara.compile.io import In, Out
......@@ -64,13 +65,16 @@ from aesara.gradient import DisconnectedType, NullType, Rop, grad, grad_undefine
from aesara.graph.basic import (
Apply,
Constant,
NominalVariable,
Variable,
clone_replace,
equal_computations,
graph_inputs,
io_connection_pattern,
replace_nominals_with_dummies,
)
from aesara.graph.features import NoOutputFromInplace
from aesara.graph.fg import FunctionGraph
from aesara.graph.op import HasInnerGraph, Op
from aesara.graph.utils import MissingInputError
from aesara.link.c.basic import CLinker
......@@ -757,8 +761,27 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
If ``True``, all the shared variables used in the inner-graph must be provided.
"""
self.inputs = inputs
self.outputs = outputs
inputs, outputs = replace_nominals_with_dummies(inputs, outputs)
input_replacements = []
for n, v in enumerate(inputs):
if not isinstance(v, (SharedVariable, Constant)):
input_replacements.append((v, NominalVariable(n, v.type)))
assert not isinstance(v, NominalVariable)
outputs = clone_replace(outputs, replace=input_replacements)
if input_replacements:
_, inputs_ = zip(*input_replacements)
inputs = list(inputs_)
else:
inputs = []
self.fgraph = FunctionGraph(inputs, outputs, clone=False)
self.inputs = self.fgraph.inputs
self.outputs = self.fgraph.outputs
self.info = info
self.truncate_gradient = truncate_gradient
self.name = name
......@@ -3416,8 +3439,8 @@ def _op_debug_information_Scan(op, node):
inner_inputs = inner_fn.maker.fgraph.inputs
inner_outputs = inner_fn.maker.fgraph.outputs
else:
inner_inputs = op.inputs
inner_outputs = op.outputs
inner_inputs = op.inner_inputs
inner_outputs = op.inner_outputs
scan_args = ScanArgs(
node.inputs,
......
......@@ -466,7 +466,6 @@ class TestOpFromGraph(unittest_tools.InferShapeTester):
y = shared(1.0, name="y")
test_ofg = OpFromGraph([x], [x + y], on_unused_input="ignore")
assert test_ofg.inputs == [x]
assert test_ofg.shared_inputs == [y]
out = test_ofg(x)
......@@ -478,7 +477,6 @@ class TestOpFromGraph(unittest_tools.InferShapeTester):
out_new = test_ofg.make_node(*(out.owner.inputs[:1] + [y_clone])).outputs[0]
assert "on_unused_input" in out_new.owner.op.kwargs
assert out_new.owner.op.inputs == [x]
assert out_new.owner.op.shared_inputs == [y_clone]
out_fn = function([x], out_new)
......@@ -497,7 +495,6 @@ class TestOpFromGraph(unittest_tools.InferShapeTester):
y = shared(1.0, name="y")
test_ofg = OpFromGraph([x], [x + y])
assert test_ofg.inputs == [x]
assert test_ofg.shared_inputs == [y]
out = test_ofg(at.as_tensor(1.0, dtype=config.floatX))
......@@ -517,7 +514,6 @@ class TestOpFromGraph(unittest_tools.InferShapeTester):
y = shared(1.0, name="y")
test_ofg = OpFromGraph([], [y])
assert test_ofg.inputs == []
assert test_ofg.shared_inputs == [y]
out_1_fn = function([], test_ofg())
......@@ -526,7 +522,6 @@ class TestOpFromGraph(unittest_tools.InferShapeTester):
assert np.array_equal(res_1, 1.0)
test_ofg_new = test_ofg.make_node(x)
assert test_ofg_new.op.inputs == [x]
assert test_ofg_new.op.shared_inputs == []
out_2_fn = function([x], test_ofg_new.outputs[0])
......@@ -535,6 +530,7 @@ class TestOpFromGraph(unittest_tools.InferShapeTester):
assert np.array_equal(res_2, 1.0)
@config.change_flags(floatX="float64")
def test_debugprint():
x, y, z = matrices("xyz")
e = x + y * z
......@@ -553,10 +549,10 @@ Inner graphs:
OpFromGraph{inline=False} [id A]
>Elemwise{add,no_inplace} [id E]
> |x [id F]
> |*0-<TensorType(float64, (None, None))> [id F]
> |Elemwise{mul,no_inplace} [id G]
> |y [id H]
> |z [id I]
> |*1-<TensorType(float64, (None, None))> [id H]
> |*2-<TensorType(float64, (None, None))> [id I]
"""
for truth, out in zip(exp_res.split("\n"), lines):
......
......@@ -2355,9 +2355,11 @@ def test_compute_test_values():
assert np.array_equal(z_grad.tag.test_value, np.r_[9.0, 9.0, 9.0])
@pytest.mark.xfail(reason="NominalVariables don't support test values")
def test_compute_test_value_grad():
# Test case originally reported by Bitton Tenessi
# https://groups.google.com/d/msg/theano-users/fAP3i2CbskQ/3OgBf4yjqiQJ
"""
See https://groups.google.com/d/msg/theano-users/fAP3i2CbskQ/3OgBf4yjqiQJ
"""
WEIGHT = np.array([1, 2, 1, 3, 4, 1, 5, 6, 1, 7, 8, 1], dtype="float32")
with config.change_flags(compute_test_value="raise", exception_verbosity="high"):
......@@ -2395,10 +2397,12 @@ def test_compute_test_value_grad():
grad(loss, W_flat)
@pytest.mark.xfail(reason="NominalVariables don't support test values")
def test_compute_test_value_grad_cast():
# Test for test values when variables have to be casted
# Reported by Daniel Renshaw at
# https://groups.google.com/d/topic/theano-users/o4jK9xDe5WI/discussion
"""Test for test values when variables have to be casted.
See https://groups.google.com/d/topic/theano-users/o4jK9xDe5WI/discussion
"""
with config.change_flags(compute_test_value="raise"):
h = matrix("h")
h.tag.test_value = np.array([[1, 2, 3, 4], [5, 6, 7, 8]], dtype=config.floatX)
......@@ -2434,7 +2438,7 @@ def test_constant_folding_n_steps():
def test_outputs_taps_check():
# Checks that errors are raised with bad output_info taps.
"""Checks that errors are raised with bad output_info taps."""
x = fvector("x")
y = fvector("y")
......@@ -2462,7 +2466,6 @@ def test_inconsistent_broadcast_error():
grad(y.sum(), x)
@pytest.mark.xfail(raises=MissingInputError)
def test_missing_input_error():
c = shared(0.0)
inc = scalar("inc")
......@@ -2470,8 +2473,8 @@ def test_missing_input_error():
def count_up():
return at.zeros(()), {c: c + inc}
_, updates = scan(count_up, n_steps=20)
function(inputs=[inc], outputs=[], updates=updates)
with pytest.raises(MissingInputError):
_, updates = scan(count_up, n_steps=20)
class TestGradUntil:
......
......@@ -3,10 +3,12 @@ import pytest
import aesara
import aesara.tensor as at
from aesara.configdefaults import config
from aesara.printing import debugprint, pydot_imported, pydotprint
from aesara.tensor.type import dvector, iscalar, scalar, vector
@config.change_flags(floatX="float64")
def test_debugprint_sitsot():
k = iscalar("k")
A = dvector("A")
......@@ -55,8 +57,8 @@ def test_debugprint_sitsot():
for{cpu,scan_fn} [id C] (outer_out_sit_sot-0)
>Elemwise{mul,no_inplace} [id W] (inner_out_sit_sot-0)
> |<TensorType(float64, (None,))> [id X] -> [id E] (inner_in_sit_sot-0)
> |A_copy [id Y] -> [id M] (inner_in_non_seqs-0)"""
> |*0-<TensorType(float64, (None,))> [id X] -> [id E] (inner_in_sit_sot-0)
> |*1-<TensorType(float64, (None,))> [id Y] -> [id M] (inner_in_non_seqs-0)"""
for truth, out in zip(expected_output.split("\n"), lines):
assert truth.strip() == out.strip()
......@@ -110,13 +112,14 @@ def test_debugprint_sitsot_no_extra_info():
for{cpu,scan_fn} [id C]
>Elemwise{mul,no_inplace} [id W]
> |<TensorType(float64, (None,))> [id X] -> [id E]
> |A_copy [id Y] -> [id M]"""
> |*0-<TensorType(float64, (None,))> [id X] -> [id E]
> |*1-<TensorType(float64, (None,))> [id Y] -> [id M]"""
for truth, out in zip(expected_output.split("\n"), lines):
assert truth.strip() == out.strip()
@config.change_flags(floatX="float64")
def test_debugprint_nitsot():
coefficients = vector("coefficients")
x = scalar("x")
......@@ -170,15 +173,16 @@ def test_debugprint_nitsot():
for{cpu,scan_fn} [id B] (outer_out_nit_sot-0)
>Elemwise{mul,no_inplace} [id X] (inner_out_nit_sot-0)
> |coefficients[t] [id Y] -> [id S] (inner_in_seqs-0)
> |*0-<TensorType(float64, ())> [id Y] -> [id S] (inner_in_seqs-0)
> |Elemwise{pow,no_inplace} [id Z]
> |x_copy [id BA] -> [id W] (inner_in_non_seqs-0)
> |<TensorType(int64, ())> [id BB] -> [id U] (inner_in_seqs-1)"""
> |*2-<TensorType(float64, ())> [id BA] -> [id W] (inner_in_non_seqs-0)
> |*1-<TensorType(int64, ())> [id BB] -> [id U] (inner_in_seqs-1)"""
for truth, out in zip(expected_output.split("\n"), lines):
assert truth.strip() == out.strip()
@config.change_flags(floatX="float64")
def test_debugprint_nested_scans():
coefficients = dvector("coefficients")
max_coefficients_supported = 10
......@@ -251,22 +255,22 @@ def test_debugprint_nested_scans():
for{cpu,scan_fn} [id B] (outer_out_nit_sot-0)
>Elemwise{mul,no_inplace} [id Y] (inner_out_nit_sot-0)
> |InplaceDimShuffle{x} [id Z]
> | |coefficients[t] [id BA] -> [id S] (inner_in_seqs-0)
> | |*0-<TensorType(float64, ())> [id BA] -> [id S] (inner_in_seqs-0)
> |Elemwise{pow,no_inplace} [id BB]
> |Subtensor{int64} [id BC]
> | |Subtensor{int64::} [id BD]
> | | |for{cpu,scan_fn} [id BE] (outer_out_sit_sot-0)
> | | | |k_copy [id BF] -> [id X] (inner_in_non_seqs-1) (n_steps)
> | | | |*3-<TensorType(int32, ())> [id BF] -> [id X] (inner_in_non_seqs-1) (n_steps)
> | | | |IncSubtensor{Set;:int64:} [id BG] (outer_in_sit_sot-0)
> | | | | |AllocEmpty{dtype='float64'} [id BH]
> | | | | | |Elemwise{add,no_inplace} [id BI]
> | | | | | | |k_copy [id BF] -> [id X] (inner_in_non_seqs-1)
> | | | | | | |*3-<TensorType(int32, ())> [id BF] -> [id X] (inner_in_non_seqs-1)
> | | | | | | |Subtensor{int64} [id BJ]
> | | | | | | |Shape [id BK]
> | | | | | | | |Rebroadcast{(0, False)} [id BL]
> | | | | | | | |InplaceDimShuffle{x,0} [id BM]
> | | | | | | | |Elemwise{second,no_inplace} [id BN]
> | | | | | | | |A_copy [id BO] -> [id W] (inner_in_non_seqs-0)
> | | | | | | | |*2-<TensorType(float64, (None,))> [id BO] -> [id W] (inner_in_non_seqs-0)
> | | | | | | | |InplaceDimShuffle{x} [id BP]
> | | | | | | | |TensorConstant{1.0} [id BQ]
> | | | | | | |ScalarConstant{0} [id BR]
......@@ -277,21 +281,22 @@ def test_debugprint_nested_scans():
> | | | | |Rebroadcast{(0, False)} [id BL]
> | | | | |ScalarFromTensor [id BV]
> | | | | |Subtensor{int64} [id BJ]
> | | | |A_copy [id BO] -> [id W] (inner_in_non_seqs-0) (outer_in_non_seqs-0)
> | | | |*2-<TensorType(float64, (None,))> [id BO] -> [id W] (inner_in_non_seqs-0) (outer_in_non_seqs-0)
> | | |ScalarConstant{1} [id BW]
> | |ScalarConstant{-1} [id BX]
> |InplaceDimShuffle{x} [id BY]
> |<TensorType(int64, ())> [id BZ] -> [id U] (inner_in_seqs-1)
> |*1-<TensorType(int64, ())> [id BZ] -> [id U] (inner_in_seqs-1)
for{cpu,scan_fn} [id BE] (outer_out_sit_sot-0)
>Elemwise{mul,no_inplace} [id CA] (inner_out_sit_sot-0)
> |<TensorType(float64, (None,))> [id CB] -> [id BG] (inner_in_sit_sot-0)
> |A_copy [id CC] -> [id BO] (inner_in_non_seqs-0)"""
> |*0-<TensorType(float64, (None,))> [id CB] -> [id BG] (inner_in_sit_sot-0)
> |*1-<TensorType(float64, (None,))> [id CC] -> [id BO] (inner_in_non_seqs-0)"""
for truth, out in zip(expected_output.split("\n"), lines):
assert truth.strip() == out.strip()
@config.change_flags(floatX="float64")
def test_debugprint_mitsot():
def fn(a_m2, a_m1, b_m2, b_m1):
return a_m1 + a_m2, b_m1 + b_m2
......@@ -351,11 +356,11 @@ def test_debugprint_mitsot():
for{cpu,scan_fn}.0 [id C] (outer_out_mit_sot-0)
>Elemwise{add,no_inplace} [id BB] (inner_out_mit_sot-0)
> |<TensorType(int64, ())> [id BC] -> [id E] (inner_in_mit_sot-0-1)
> |<TensorType(int64, ())> [id BD] -> [id E] (inner_in_mit_sot-0-0)
> |*1-<TensorType(int64, ())> [id BC] -> [id E] (inner_in_mit_sot-0-1)
> |*0-<TensorType(int64, ())> [id BD] -> [id E] (inner_in_mit_sot-0-0)
>Elemwise{add,no_inplace} [id BE] (inner_out_mit_sot-1)
> |<TensorType(int64, ())> [id BF] -> [id O] (inner_in_mit_sot-1-1)
> |<TensorType(int64, ())> [id BG] -> [id O] (inner_in_mit_sot-1-0)
> |*3-<TensorType(int64, ())> [id BF] -> [id O] (inner_in_mit_sot-1-1)
> |*2-<TensorType(int64, ())> [id BG] -> [id O] (inner_in_mit_sot-1-0)
for{cpu,scan_fn}.1 [id C] (outer_out_mit_sot-1)
>Elemwise{add,no_inplace} [id BB] (inner_out_mit_sot-0)
......@@ -365,6 +370,7 @@ def test_debugprint_mitsot():
assert truth.strip() == out.strip()
@config.change_flags(floatX="float64")
def test_debugprint_mitmot():
k = iscalar("k")
......@@ -471,19 +477,19 @@ def test_debugprint_mitmot():
for{cpu,grad_of_scan_fn}.1 [id B] (outer_out_sit_sot-0)
>Elemwise{add,no_inplace} [id CM] (inner_out_mit_mot-0-0)
> |Elemwise{mul} [id CN]
> | |<TensorType(float64, (None,))> [id CO] -> [id BL] (inner_in_mit_mot-0-0)
> | |A_copy [id CP] -> [id P] (inner_in_non_seqs-0)
> |<TensorType(float64, (None,))> [id CQ] -> [id BL] (inner_in_mit_mot-0-1)
> | |*2-<TensorType(float64, (None,))> [id CO] -> [id BL] (inner_in_mit_mot-0-0)
> | |*5-<TensorType(float64, (None,))> [id CP] -> [id P] (inner_in_non_seqs-0)
> |*3-<TensorType(float64, (None,))> [id CQ] -> [id BL] (inner_in_mit_mot-0-1)
>Elemwise{add,no_inplace} [id CR] (inner_out_sit_sot-0)
> |Elemwise{mul} [id CS]
> | |<TensorType(float64, (None,))> [id CO] -> [id BL] (inner_in_mit_mot-0-0)
> | |<TensorType(float64, (None,))> [id CT] -> [id Z] (inner_in_seqs-0)
> |<TensorType(float64, (None,))> [id CU] -> [id CE] (inner_in_sit_sot-0)
> | |*2-<TensorType(float64, (None,))> [id CO] -> [id BL] (inner_in_mit_mot-0-0)
> | |*0-<TensorType(float64, (None,))> [id CT] -> [id Z] (inner_in_seqs-0)
> |*4-<TensorType(float64, (None,))> [id CU] -> [id CE] (inner_in_sit_sot-0)
for{cpu,scan_fn} [id F] (outer_out_sit_sot-0)
>Elemwise{mul,no_inplace} [id CV] (inner_out_sit_sot-0)
> |<TensorType(float64, (None,))> [id CT] -> [id H] (inner_in_sit_sot-0)
> |A_copy [id CP] -> [id P] (inner_in_non_seqs-0)
> |*0-<TensorType(float64, (None,))> [id CT] -> [id H] (inner_in_sit_sot-0)
> |*1-<TensorType(float64, (None,))> [id CW] -> [id P] (inner_in_non_seqs-0)
for{cpu,scan_fn} [id F] (outer_out_sit_sot-0)
>Elemwise{mul,no_inplace} [id CV] (inner_out_sit_sot-0)
......@@ -540,11 +546,11 @@ def test_debugprint_compiled_fn():
>Elemwise{Composite{Switch(LT(i0, i1), i2, i0)}} [id I] (inner_out_sit_sot-0)
> |TensorConstant{0} [id J]
> |Subtensor{int64, int64, int64} [id K]
> | |<TensorType(float64, (20000, 2, 2))> [id L] -> [id H] (inner_in_non_seqs-0)
> | |*2-<TensorType(float64, (20000, 2, 2))> [id L] -> [id H] (inner_in_non_seqs-0)
> | |ScalarFromTensor [id M]
> | | |<TensorType(int64, ())> [id N] -> [id C] (inner_in_seqs-0)
> | | |*0-<TensorType(int64, ())> [id N] -> [id C] (inner_in_seqs-0)
> | |ScalarFromTensor [id O]
> | | |<TensorType(int64, ())> [id P] -> [id D] (inner_in_sit_sot-0)
> | | |*1-<TensorType(int64, ())> [id P] -> [id D] (inner_in_sit_sot-0)
> | |ScalarConstant{0} [id Q]
> |TensorConstant{1} [id R]
"""
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论