提交 b979dd6e authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Add FunctionGraphs to the HasInnerGraph interface and OpFromGraph, Scan

上级 980ecacf
...@@ -13,10 +13,12 @@ from aesara.gradient import DisconnectedType, Rop, grad ...@@ -13,10 +13,12 @@ from aesara.gradient import DisconnectedType, Rop, grad
from aesara.graph.basic import ( from aesara.graph.basic import (
Apply, Apply,
Constant, Constant,
NominalVariable,
Variable, Variable,
clone_replace, clone_replace,
graph_inputs, graph_inputs,
io_connection_pattern, io_connection_pattern,
replace_nominals_with_dummies,
) )
from aesara.graph.fg import FunctionGraph from aesara.graph.fg import FunctionGraph
from aesara.graph.null_type import NullType from aesara.graph.null_type import NullType
...@@ -349,17 +351,32 @@ class OpFromGraph(Op, HasInnerGraph): ...@@ -349,17 +351,32 @@ class OpFromGraph(Op, HasInnerGraph):
raise NotImplementedError("Updates and givens are not allowed here") raise NotImplementedError("Updates and givens are not allowed here")
self.is_inline = inline self.is_inline = inline
# To correctly support shared variables the inner fct should # To correctly support shared variables the inner fct should
# not see them. Otherwise there is a problem with the gradient. # not see them. Otherwise there is a problem with the gradient.
self.shared_inputs = [ self.shared_inputs = []
var for var in graph_inputs(outputs) if isinstance(var, SharedVariable) for var in graph_inputs(outputs):
if isinstance(var, SharedVariable):
self.shared_inputs.append(var)
inputs, outputs = replace_nominals_with_dummies(inputs, outputs)
# The inputs should be `NominalVariable`s, so that graphs can be merged
replacements = {}
for n, v in enumerate(inputs):
replacements[v] = NominalVariable(n, v.type)
shared_vars = [
NominalVariable(n, var.type)
for n, var in enumerate(self.shared_inputs, start=len(inputs) + 1)
] ]
shared_vars = [var.type() for var in self.shared_inputs]
replacements.update(dict(zip(self.shared_inputs, shared_vars)))
new = rebuild_collect_shared( new = rebuild_collect_shared(
cast(Sequence[Variable], outputs), cast(Sequence[Variable], outputs),
inputs=inputs + shared_vars, inputs=inputs + shared_vars,
replace=dict(zip(self.shared_inputs, shared_vars)), replace=replacements,
copy_inputs_over=False, copy_inputs_over=False,
) )
( (
...@@ -374,10 +391,7 @@ class OpFromGraph(Op, HasInnerGraph): ...@@ -374,10 +391,7 @@ class OpFromGraph(Op, HasInnerGraph):
assert not update_expr assert not update_expr
assert not shared_inputs assert not shared_inputs
self._inner_inputs = local_inputs self.fgraph = FunctionGraph(local_inputs, local_outputs, clone=False)
self._inner_outputs = local_outputs
self.inputs = inputs
self.outputs = outputs
self.kwargs = kwargs self.kwargs = kwargs
self.input_types = [inp.type for inp in inputs] self.input_types = [inp.type for inp in inputs]
self.output_types = [out.type for out in outputs] self.output_types = [out.type for out in outputs]
...@@ -778,29 +792,23 @@ class OpFromGraph(Op, HasInnerGraph): ...@@ -778,29 +792,23 @@ class OpFromGraph(Op, HasInnerGraph):
# The shared variables are not equal to the original shared # The shared variables are not equal to the original shared
# variables, so we construct a new `Op` that uses the new shared # variables, so we construct a new `Op` that uses the new shared
# variables instead. # variables instead.
# All this is really doing is making the unused (internally, at replace = dict(
# least) `self.outputs` and `self.shared_inputs` consistent. zip(self.inner_inputs[num_expected_inps:], new_shared_inputs)
# We could just as easily `copy` this `Op`, update )
# `self.shared_inputs`, and avoid cloning anything, but this is a
# more "change-proof" approach, because it still work when/if those
# attributes end up being used.
replace = dict(inner_and_input_shareds)
# If the new shared variables are inconsistent with the inner-graph, # If the new shared variables are inconsistent with the inner-graph,
# such errors should arise in this step # such errors should arise in this step
new_inner_outputs = clone_replace( new_inner_outputs = clone_replace(
self.outputs, replace=replace, share_inputs=True self.inner_outputs, replace=replace, share_inputs=True
) )
# `self.inputs` should not contain any shared variables, so we know # It's possible that the new shared variable inputs aren't actually
# that those are inputs to `new_outputs`, because we chose not to # shared variables. When they aren't we need to add them as new
# clone inputs; however, it's possible that the new shared variable # inputs.
# inputs aren't actually shared variables. When they aren't we
# need to add them as new inputs.
unshared_inputs = [ unshared_inputs = [
inp for inp in new_shared_inputs if not isinstance(inp, SharedVariable) inp for inp in new_shared_inputs if not isinstance(inp, SharedVariable)
] ]
new_inner_inputs = self.inputs + unshared_inputs new_inner_inputs = self.inner_inputs[:num_expected_inps] + unshared_inputs
new_op = type(self)( new_op = type(self)(
inputs=new_inner_inputs, inputs=new_inner_inputs,
...@@ -901,11 +909,11 @@ class OpFromGraph(Op, HasInnerGraph): ...@@ -901,11 +909,11 @@ class OpFromGraph(Op, HasInnerGraph):
@property @property
def inner_inputs(self): def inner_inputs(self):
return self._inner_inputs return self.fgraph.inputs
@property @property
def inner_outputs(self): def inner_outputs(self):
return self._inner_outputs return self.fgraph.outputs
def perform(self, node, inputs, outputs): def perform(self, node, inputs, outputs):
variables = self.fn(*inputs) variables = self.fn(*inputs)
......
...@@ -13,7 +13,7 @@ import aesara ...@@ -13,7 +13,7 @@ import aesara
from aesara.compile.ops import ViewOp from aesara.compile.ops import ViewOp
from aesara.configdefaults import config from aesara.configdefaults import config
from aesara.graph import utils from aesara.graph import utils
from aesara.graph.basic import Variable from aesara.graph.basic import NominalVariable, Variable
from aesara.graph.null_type import NullType, null_type from aesara.graph.null_type import NullType, null_type
from aesara.graph.op import get_test_values from aesara.graph.op import get_test_values
from aesara.graph.type import Type from aesara.graph.type import Type
...@@ -1295,15 +1295,16 @@ def _populate_grad_dict(var_to_app_to_idx, grad_dict, wrt, cost_name=None): ...@@ -1295,15 +1295,16 @@ def _populate_grad_dict(var_to_app_to_idx, grad_dict, wrt, cost_name=None):
# has the right shape # has the right shape
if hasattr(term, "shape"): if hasattr(term, "shape"):
orig_ipt = inputs[i] orig_ipt = inputs[i]
for orig_ipt_v, term_v in get_test_values(orig_ipt, term): if not isinstance(orig_ipt, NominalVariable):
i_shape = orig_ipt_v.shape for orig_ipt_v, term_v in get_test_values(orig_ipt, term):
t_shape = term_v.shape i_shape = orig_ipt_v.shape
if i_shape != t_shape: t_shape = term_v.shape
raise ValueError( if i_shape != t_shape:
f"{node.op}.grad returned object of " raise ValueError(
f"shape {t_shape} as gradient term on input {int(i)} " f"{node.op}.grad returned object of "
f"of shape {i_shape}" f"shape {t_shape} as gradient term on input {int(i)} "
) f"of shape {i_shape}"
)
if not isinstance(term.type, (NullType, DisconnectedType)): if not isinstance(term.type, (NullType, DisconnectedType)):
if term.type.dtype not in aesara.tensor.type.float_dtypes: if term.type.dtype not in aesara.tensor.type.float_dtypes:
......
...@@ -1755,3 +1755,38 @@ def get_var_by_name( ...@@ -1755,3 +1755,38 @@ def get_var_by_name(
results += (var,) results += (var,)
return results return results
def replace_nominals_with_dummies(inputs, outputs):
    """Replace nominal inputs with dummy variables.

    When constructing a new graph with nominal inputs from an existing graph,
    pre-existing nominal inputs need to be replaced with dummy variables
    beforehand; otherwise, sequential ID ordering (i.e. when nominals are IDed
    based on the ordered inputs to which they correspond) of the nominals could
    be broken, and/or circular replacements could manifest.

    FYI: This function assumes that all the nominal variables in the subgraphs
    between `inputs` and `outputs` are present in `inputs`.
    """
    # Map each pre-existing nominal input to a fresh dummy of the same type.
    dummy_memo = {}
    for inp in inputs:
        if isinstance(inp, NominalVariable):
            dummy_memo[inp] = inp.type()

    if not dummy_memo:
        # No nominal inputs present; the graph can be used as-is.
        return inputs, outputs

    # Clone the graph so that the dummies take the place of the old nominals
    # everywhere between `inputs` and `outputs`; `dummy_memo` is populated
    # in-place with the mapping from each original variable to its clone.
    # This keeps nominal IDs free to be reassigned in input order later.
    clone_get_equiv(
        inputs,
        outputs,
        copy_inputs=False,
        copy_orphans=False,
        memo=dummy_memo,
    )

    new_inputs = [dummy_memo[inp] for inp in inputs]
    new_outputs = [dummy_memo[out] for out in outputs]

    return new_inputs, new_outputs
...@@ -615,6 +615,9 @@ class _NoPythonOp(Op): ...@@ -615,6 +615,9 @@ class _NoPythonOp(Op):
class HasInnerGraph: class HasInnerGraph:
r"""A mixin for an `Op` that contain an inner graph.""" r"""A mixin for an `Op` that contain an inner graph."""
fgraph: "FunctionGraph"
"""A `FunctionGraph` of the inner function."""
@property @property
@abstractmethod @abstractmethod
def fn(self) -> "Function": def fn(self) -> "Function":
......
...@@ -375,6 +375,7 @@ N.B.: ...@@ -375,6 +375,7 @@ N.B.:
print_op_info=print_op_info, print_op_info=print_op_info,
print_destroy_map=print_destroy_map, print_destroy_map=print_destroy_map,
print_view_map=print_view_map, print_view_map=print_view_map,
inner_graph_node=s.owner,
) )
if file is _file: if file is _file:
...@@ -407,6 +408,7 @@ def _debugprint( ...@@ -407,6 +408,7 @@ def _debugprint(
op_information: Optional[Dict[Apply, Dict[Variable, str]]] = None, op_information: Optional[Dict[Apply, Dict[Variable, str]]] = None,
parent_node: Optional[Apply] = None, parent_node: Optional[Apply] = None,
print_op_info: bool = False, print_op_info: bool = False,
inner_graph_node: Optional[Apply] = None,
) -> IOBase: ) -> IOBase:
r"""Print the graph leading to `r`. r"""Print the graph leading to `r`.
...@@ -459,6 +461,8 @@ def _debugprint( ...@@ -459,6 +461,8 @@ def _debugprint(
print_op_info print_op_info
Print extra information provided by the relevant `Op`\s. For example, Print extra information provided by the relevant `Op`\s. For example,
print the tap information for `Scan` inputs and outputs. print the tap information for `Scan` inputs and outputs.
inner_graph_node
The inner-graph node in which `r` is contained.
""" """
if depth == 0: if depth == 0:
return file return file
...@@ -615,6 +619,7 @@ def _debugprint( ...@@ -615,6 +619,7 @@ def _debugprint(
print_op_info=print_op_info, print_op_info=print_op_info,
print_destroy_map=print_destroy_map, print_destroy_map=print_destroy_map,
print_view_map=print_view_map, print_view_map=print_view_map,
inner_graph_node=inner_graph_node,
) )
else: else:
...@@ -644,14 +649,9 @@ def _debugprint( ...@@ -644,14 +649,9 @@ def _debugprint(
var_output = f"{var_output} -> {outer_id_str}" var_output = f"{var_output} -> {outer_id_str}"
# This is an inner-graph input, so we need to find the outer node node_info = op_information.get(inner_graph_node)
# it belongs to and get the extra information from that if node_info and r in node_info:
for inner_graph in inner_graph_ops: var_output = f"{var_output} ({node_info[r]})"
if outer_r in inner_graph.owner.inputs:
node_info = op_information.get(inner_graph.owner)
if node_info and r in node_info:
var_output = f"{var_output} ({node_info[r]})"
break
node_info = op_information.get(parent_node) or op_information.get(r.owner) node_info = op_information.get(parent_node) or op_information.get(r.owner)
if node_info and r in node_info: if node_info and r in node_info:
......
...@@ -54,6 +54,7 @@ import numpy as np ...@@ -54,6 +54,7 @@ import numpy as np
import aesara import aesara
from aesara import tensor as at from aesara import tensor as at
from aesara.compile import SharedVariable
from aesara.compile.builders import infer_shape from aesara.compile.builders import infer_shape
from aesara.compile.function import function from aesara.compile.function import function
from aesara.compile.io import In, Out from aesara.compile.io import In, Out
...@@ -64,13 +65,16 @@ from aesara.gradient import DisconnectedType, NullType, Rop, grad, grad_undefine ...@@ -64,13 +65,16 @@ from aesara.gradient import DisconnectedType, NullType, Rop, grad, grad_undefine
from aesara.graph.basic import ( from aesara.graph.basic import (
Apply, Apply,
Constant, Constant,
NominalVariable,
Variable, Variable,
clone_replace, clone_replace,
equal_computations, equal_computations,
graph_inputs, graph_inputs,
io_connection_pattern, io_connection_pattern,
replace_nominals_with_dummies,
) )
from aesara.graph.features import NoOutputFromInplace from aesara.graph.features import NoOutputFromInplace
from aesara.graph.fg import FunctionGraph
from aesara.graph.op import HasInnerGraph, Op from aesara.graph.op import HasInnerGraph, Op
from aesara.graph.utils import MissingInputError from aesara.graph.utils import MissingInputError
from aesara.link.c.basic import CLinker from aesara.link.c.basic import CLinker
...@@ -757,8 +761,27 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph): ...@@ -757,8 +761,27 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
If ``True``, all the shared variables used in the inner-graph must be provided. If ``True``, all the shared variables used in the inner-graph must be provided.
""" """
self.inputs = inputs inputs, outputs = replace_nominals_with_dummies(inputs, outputs)
self.outputs = outputs
input_replacements = []
for n, v in enumerate(inputs):
if not isinstance(v, (SharedVariable, Constant)):
input_replacements.append((v, NominalVariable(n, v.type)))
assert not isinstance(v, NominalVariable)
outputs = clone_replace(outputs, replace=input_replacements)
if input_replacements:
_, inputs_ = zip(*input_replacements)
inputs = list(inputs_)
else:
inputs = []
self.fgraph = FunctionGraph(inputs, outputs, clone=False)
self.inputs = self.fgraph.inputs
self.outputs = self.fgraph.outputs
self.info = info self.info = info
self.truncate_gradient = truncate_gradient self.truncate_gradient = truncate_gradient
self.name = name self.name = name
...@@ -3416,8 +3439,8 @@ def _op_debug_information_Scan(op, node): ...@@ -3416,8 +3439,8 @@ def _op_debug_information_Scan(op, node):
inner_inputs = inner_fn.maker.fgraph.inputs inner_inputs = inner_fn.maker.fgraph.inputs
inner_outputs = inner_fn.maker.fgraph.outputs inner_outputs = inner_fn.maker.fgraph.outputs
else: else:
inner_inputs = op.inputs inner_inputs = op.inner_inputs
inner_outputs = op.outputs inner_outputs = op.inner_outputs
scan_args = ScanArgs( scan_args = ScanArgs(
node.inputs, node.inputs,
......
...@@ -466,7 +466,6 @@ class TestOpFromGraph(unittest_tools.InferShapeTester): ...@@ -466,7 +466,6 @@ class TestOpFromGraph(unittest_tools.InferShapeTester):
y = shared(1.0, name="y") y = shared(1.0, name="y")
test_ofg = OpFromGraph([x], [x + y], on_unused_input="ignore") test_ofg = OpFromGraph([x], [x + y], on_unused_input="ignore")
assert test_ofg.inputs == [x]
assert test_ofg.shared_inputs == [y] assert test_ofg.shared_inputs == [y]
out = test_ofg(x) out = test_ofg(x)
...@@ -478,7 +477,6 @@ class TestOpFromGraph(unittest_tools.InferShapeTester): ...@@ -478,7 +477,6 @@ class TestOpFromGraph(unittest_tools.InferShapeTester):
out_new = test_ofg.make_node(*(out.owner.inputs[:1] + [y_clone])).outputs[0] out_new = test_ofg.make_node(*(out.owner.inputs[:1] + [y_clone])).outputs[0]
assert "on_unused_input" in out_new.owner.op.kwargs assert "on_unused_input" in out_new.owner.op.kwargs
assert out_new.owner.op.inputs == [x]
assert out_new.owner.op.shared_inputs == [y_clone] assert out_new.owner.op.shared_inputs == [y_clone]
out_fn = function([x], out_new) out_fn = function([x], out_new)
...@@ -497,7 +495,6 @@ class TestOpFromGraph(unittest_tools.InferShapeTester): ...@@ -497,7 +495,6 @@ class TestOpFromGraph(unittest_tools.InferShapeTester):
y = shared(1.0, name="y") y = shared(1.0, name="y")
test_ofg = OpFromGraph([x], [x + y]) test_ofg = OpFromGraph([x], [x + y])
assert test_ofg.inputs == [x]
assert test_ofg.shared_inputs == [y] assert test_ofg.shared_inputs == [y]
out = test_ofg(at.as_tensor(1.0, dtype=config.floatX)) out = test_ofg(at.as_tensor(1.0, dtype=config.floatX))
...@@ -517,7 +514,6 @@ class TestOpFromGraph(unittest_tools.InferShapeTester): ...@@ -517,7 +514,6 @@ class TestOpFromGraph(unittest_tools.InferShapeTester):
y = shared(1.0, name="y") y = shared(1.0, name="y")
test_ofg = OpFromGraph([], [y]) test_ofg = OpFromGraph([], [y])
assert test_ofg.inputs == []
assert test_ofg.shared_inputs == [y] assert test_ofg.shared_inputs == [y]
out_1_fn = function([], test_ofg()) out_1_fn = function([], test_ofg())
...@@ -526,7 +522,6 @@ class TestOpFromGraph(unittest_tools.InferShapeTester): ...@@ -526,7 +522,6 @@ class TestOpFromGraph(unittest_tools.InferShapeTester):
assert np.array_equal(res_1, 1.0) assert np.array_equal(res_1, 1.0)
test_ofg_new = test_ofg.make_node(x) test_ofg_new = test_ofg.make_node(x)
assert test_ofg_new.op.inputs == [x]
assert test_ofg_new.op.shared_inputs == [] assert test_ofg_new.op.shared_inputs == []
out_2_fn = function([x], test_ofg_new.outputs[0]) out_2_fn = function([x], test_ofg_new.outputs[0])
...@@ -535,6 +530,7 @@ class TestOpFromGraph(unittest_tools.InferShapeTester): ...@@ -535,6 +530,7 @@ class TestOpFromGraph(unittest_tools.InferShapeTester):
assert np.array_equal(res_2, 1.0) assert np.array_equal(res_2, 1.0)
@config.change_flags(floatX="float64")
def test_debugprint(): def test_debugprint():
x, y, z = matrices("xyz") x, y, z = matrices("xyz")
e = x + y * z e = x + y * z
...@@ -553,10 +549,10 @@ Inner graphs: ...@@ -553,10 +549,10 @@ Inner graphs:
OpFromGraph{inline=False} [id A] OpFromGraph{inline=False} [id A]
>Elemwise{add,no_inplace} [id E] >Elemwise{add,no_inplace} [id E]
> |x [id F] > |*0-<TensorType(float64, (None, None))> [id F]
> |Elemwise{mul,no_inplace} [id G] > |Elemwise{mul,no_inplace} [id G]
> |y [id H] > |*1-<TensorType(float64, (None, None))> [id H]
> |z [id I] > |*2-<TensorType(float64, (None, None))> [id I]
""" """
for truth, out in zip(exp_res.split("\n"), lines): for truth, out in zip(exp_res.split("\n"), lines):
......
...@@ -2355,9 +2355,11 @@ def test_compute_test_values(): ...@@ -2355,9 +2355,11 @@ def test_compute_test_values():
assert np.array_equal(z_grad.tag.test_value, np.r_[9.0, 9.0, 9.0]) assert np.array_equal(z_grad.tag.test_value, np.r_[9.0, 9.0, 9.0])
@pytest.mark.xfail(reason="NominalVariables don't support test values")
def test_compute_test_value_grad(): def test_compute_test_value_grad():
# Test case originally reported by Bitton Tenessi """
# https://groups.google.com/d/msg/theano-users/fAP3i2CbskQ/3OgBf4yjqiQJ See https://groups.google.com/d/msg/theano-users/fAP3i2CbskQ/3OgBf4yjqiQJ
"""
WEIGHT = np.array([1, 2, 1, 3, 4, 1, 5, 6, 1, 7, 8, 1], dtype="float32") WEIGHT = np.array([1, 2, 1, 3, 4, 1, 5, 6, 1, 7, 8, 1], dtype="float32")
with config.change_flags(compute_test_value="raise", exception_verbosity="high"): with config.change_flags(compute_test_value="raise", exception_verbosity="high"):
...@@ -2395,10 +2397,12 @@ def test_compute_test_value_grad(): ...@@ -2395,10 +2397,12 @@ def test_compute_test_value_grad():
grad(loss, W_flat) grad(loss, W_flat)
@pytest.mark.xfail(reason="NominalVariables don't support test values")
def test_compute_test_value_grad_cast(): def test_compute_test_value_grad_cast():
# Test for test values when variables have to be casted """Test for test values when variables have to be casted.
# Reported by Daniel Renshaw at
# https://groups.google.com/d/topic/theano-users/o4jK9xDe5WI/discussion See https://groups.google.com/d/topic/theano-users/o4jK9xDe5WI/discussion
"""
with config.change_flags(compute_test_value="raise"): with config.change_flags(compute_test_value="raise"):
h = matrix("h") h = matrix("h")
h.tag.test_value = np.array([[1, 2, 3, 4], [5, 6, 7, 8]], dtype=config.floatX) h.tag.test_value = np.array([[1, 2, 3, 4], [5, 6, 7, 8]], dtype=config.floatX)
...@@ -2434,7 +2438,7 @@ def test_constant_folding_n_steps(): ...@@ -2434,7 +2438,7 @@ def test_constant_folding_n_steps():
def test_outputs_taps_check(): def test_outputs_taps_check():
# Checks that errors are raised with bad output_info taps. """Checks that errors are raised with bad output_info taps."""
x = fvector("x") x = fvector("x")
y = fvector("y") y = fvector("y")
...@@ -2462,7 +2466,6 @@ def test_inconsistent_broadcast_error(): ...@@ -2462,7 +2466,6 @@ def test_inconsistent_broadcast_error():
grad(y.sum(), x) grad(y.sum(), x)
@pytest.mark.xfail(raises=MissingInputError)
def test_missing_input_error(): def test_missing_input_error():
c = shared(0.0) c = shared(0.0)
inc = scalar("inc") inc = scalar("inc")
...@@ -2470,8 +2473,8 @@ def test_missing_input_error(): ...@@ -2470,8 +2473,8 @@ def test_missing_input_error():
def count_up(): def count_up():
return at.zeros(()), {c: c + inc} return at.zeros(()), {c: c + inc}
_, updates = scan(count_up, n_steps=20) with pytest.raises(MissingInputError):
function(inputs=[inc], outputs=[], updates=updates) _, updates = scan(count_up, n_steps=20)
class TestGradUntil: class TestGradUntil:
......
...@@ -3,10 +3,12 @@ import pytest ...@@ -3,10 +3,12 @@ import pytest
import aesara import aesara
import aesara.tensor as at import aesara.tensor as at
from aesara.configdefaults import config
from aesara.printing import debugprint, pydot_imported, pydotprint from aesara.printing import debugprint, pydot_imported, pydotprint
from aesara.tensor.type import dvector, iscalar, scalar, vector from aesara.tensor.type import dvector, iscalar, scalar, vector
@config.change_flags(floatX="float64")
def test_debugprint_sitsot(): def test_debugprint_sitsot():
k = iscalar("k") k = iscalar("k")
A = dvector("A") A = dvector("A")
...@@ -55,8 +57,8 @@ def test_debugprint_sitsot(): ...@@ -55,8 +57,8 @@ def test_debugprint_sitsot():
for{cpu,scan_fn} [id C] (outer_out_sit_sot-0) for{cpu,scan_fn} [id C] (outer_out_sit_sot-0)
>Elemwise{mul,no_inplace} [id W] (inner_out_sit_sot-0) >Elemwise{mul,no_inplace} [id W] (inner_out_sit_sot-0)
> |<TensorType(float64, (None,))> [id X] -> [id E] (inner_in_sit_sot-0) > |*0-<TensorType(float64, (None,))> [id X] -> [id E] (inner_in_sit_sot-0)
> |A_copy [id Y] -> [id M] (inner_in_non_seqs-0)""" > |*1-<TensorType(float64, (None,))> [id Y] -> [id M] (inner_in_non_seqs-0)"""
for truth, out in zip(expected_output.split("\n"), lines): for truth, out in zip(expected_output.split("\n"), lines):
assert truth.strip() == out.strip() assert truth.strip() == out.strip()
...@@ -110,13 +112,14 @@ def test_debugprint_sitsot_no_extra_info(): ...@@ -110,13 +112,14 @@ def test_debugprint_sitsot_no_extra_info():
for{cpu,scan_fn} [id C] for{cpu,scan_fn} [id C]
>Elemwise{mul,no_inplace} [id W] >Elemwise{mul,no_inplace} [id W]
> |<TensorType(float64, (None,))> [id X] -> [id E] > |*0-<TensorType(float64, (None,))> [id X] -> [id E]
> |A_copy [id Y] -> [id M]""" > |*1-<TensorType(float64, (None,))> [id Y] -> [id M]"""
for truth, out in zip(expected_output.split("\n"), lines): for truth, out in zip(expected_output.split("\n"), lines):
assert truth.strip() == out.strip() assert truth.strip() == out.strip()
@config.change_flags(floatX="float64")
def test_debugprint_nitsot(): def test_debugprint_nitsot():
coefficients = vector("coefficients") coefficients = vector("coefficients")
x = scalar("x") x = scalar("x")
...@@ -170,15 +173,16 @@ def test_debugprint_nitsot(): ...@@ -170,15 +173,16 @@ def test_debugprint_nitsot():
for{cpu,scan_fn} [id B] (outer_out_nit_sot-0) for{cpu,scan_fn} [id B] (outer_out_nit_sot-0)
>Elemwise{mul,no_inplace} [id X] (inner_out_nit_sot-0) >Elemwise{mul,no_inplace} [id X] (inner_out_nit_sot-0)
> |coefficients[t] [id Y] -> [id S] (inner_in_seqs-0) > |*0-<TensorType(float64, ())> [id Y] -> [id S] (inner_in_seqs-0)
> |Elemwise{pow,no_inplace} [id Z] > |Elemwise{pow,no_inplace} [id Z]
> |x_copy [id BA] -> [id W] (inner_in_non_seqs-0) > |*2-<TensorType(float64, ())> [id BA] -> [id W] (inner_in_non_seqs-0)
> |<TensorType(int64, ())> [id BB] -> [id U] (inner_in_seqs-1)""" > |*1-<TensorType(int64, ())> [id BB] -> [id U] (inner_in_seqs-1)"""
for truth, out in zip(expected_output.split("\n"), lines): for truth, out in zip(expected_output.split("\n"), lines):
assert truth.strip() == out.strip() assert truth.strip() == out.strip()
@config.change_flags(floatX="float64")
def test_debugprint_nested_scans(): def test_debugprint_nested_scans():
coefficients = dvector("coefficients") coefficients = dvector("coefficients")
max_coefficients_supported = 10 max_coefficients_supported = 10
...@@ -251,22 +255,22 @@ def test_debugprint_nested_scans(): ...@@ -251,22 +255,22 @@ def test_debugprint_nested_scans():
for{cpu,scan_fn} [id B] (outer_out_nit_sot-0) for{cpu,scan_fn} [id B] (outer_out_nit_sot-0)
>Elemwise{mul,no_inplace} [id Y] (inner_out_nit_sot-0) >Elemwise{mul,no_inplace} [id Y] (inner_out_nit_sot-0)
> |InplaceDimShuffle{x} [id Z] > |InplaceDimShuffle{x} [id Z]
> | |coefficients[t] [id BA] -> [id S] (inner_in_seqs-0) > | |*0-<TensorType(float64, ())> [id BA] -> [id S] (inner_in_seqs-0)
> |Elemwise{pow,no_inplace} [id BB] > |Elemwise{pow,no_inplace} [id BB]
> |Subtensor{int64} [id BC] > |Subtensor{int64} [id BC]
> | |Subtensor{int64::} [id BD] > | |Subtensor{int64::} [id BD]
> | | |for{cpu,scan_fn} [id BE] (outer_out_sit_sot-0) > | | |for{cpu,scan_fn} [id BE] (outer_out_sit_sot-0)
> | | | |k_copy [id BF] -> [id X] (inner_in_non_seqs-1) (n_steps) > | | | |*3-<TensorType(int32, ())> [id BF] -> [id X] (inner_in_non_seqs-1) (n_steps)
> | | | |IncSubtensor{Set;:int64:} [id BG] (outer_in_sit_sot-0) > | | | |IncSubtensor{Set;:int64:} [id BG] (outer_in_sit_sot-0)
> | | | | |AllocEmpty{dtype='float64'} [id BH] > | | | | |AllocEmpty{dtype='float64'} [id BH]
> | | | | | |Elemwise{add,no_inplace} [id BI] > | | | | | |Elemwise{add,no_inplace} [id BI]
> | | | | | | |k_copy [id BF] -> [id X] (inner_in_non_seqs-1) > | | | | | | |*3-<TensorType(int32, ())> [id BF] -> [id X] (inner_in_non_seqs-1)
> | | | | | | |Subtensor{int64} [id BJ] > | | | | | | |Subtensor{int64} [id BJ]
> | | | | | | |Shape [id BK] > | | | | | | |Shape [id BK]
> | | | | | | | |Rebroadcast{(0, False)} [id BL] > | | | | | | | |Rebroadcast{(0, False)} [id BL]
> | | | | | | | |InplaceDimShuffle{x,0} [id BM] > | | | | | | | |InplaceDimShuffle{x,0} [id BM]
> | | | | | | | |Elemwise{second,no_inplace} [id BN] > | | | | | | | |Elemwise{second,no_inplace} [id BN]
> | | | | | | | |A_copy [id BO] -> [id W] (inner_in_non_seqs-0) > | | | | | | | |*2-<TensorType(float64, (None,))> [id BO] -> [id W] (inner_in_non_seqs-0)
> | | | | | | | |InplaceDimShuffle{x} [id BP] > | | | | | | | |InplaceDimShuffle{x} [id BP]
> | | | | | | | |TensorConstant{1.0} [id BQ] > | | | | | | | |TensorConstant{1.0} [id BQ]
> | | | | | | |ScalarConstant{0} [id BR] > | | | | | | |ScalarConstant{0} [id BR]
...@@ -277,21 +281,22 @@ def test_debugprint_nested_scans(): ...@@ -277,21 +281,22 @@ def test_debugprint_nested_scans():
> | | | | |Rebroadcast{(0, False)} [id BL] > | | | | |Rebroadcast{(0, False)} [id BL]
> | | | | |ScalarFromTensor [id BV] > | | | | |ScalarFromTensor [id BV]
> | | | | |Subtensor{int64} [id BJ] > | | | | |Subtensor{int64} [id BJ]
> | | | |A_copy [id BO] -> [id W] (inner_in_non_seqs-0) (outer_in_non_seqs-0) > | | | |*2-<TensorType(float64, (None,))> [id BO] -> [id W] (inner_in_non_seqs-0) (outer_in_non_seqs-0)
> | | |ScalarConstant{1} [id BW] > | | |ScalarConstant{1} [id BW]
> | |ScalarConstant{-1} [id BX] > | |ScalarConstant{-1} [id BX]
> |InplaceDimShuffle{x} [id BY] > |InplaceDimShuffle{x} [id BY]
> |<TensorType(int64, ())> [id BZ] -> [id U] (inner_in_seqs-1) > |*1-<TensorType(int64, ())> [id BZ] -> [id U] (inner_in_seqs-1)
for{cpu,scan_fn} [id BE] (outer_out_sit_sot-0) for{cpu,scan_fn} [id BE] (outer_out_sit_sot-0)
>Elemwise{mul,no_inplace} [id CA] (inner_out_sit_sot-0) >Elemwise{mul,no_inplace} [id CA] (inner_out_sit_sot-0)
> |<TensorType(float64, (None,))> [id CB] -> [id BG] (inner_in_sit_sot-0) > |*0-<TensorType(float64, (None,))> [id CB] -> [id BG] (inner_in_sit_sot-0)
> |A_copy [id CC] -> [id BO] (inner_in_non_seqs-0)""" > |*1-<TensorType(float64, (None,))> [id CC] -> [id BO] (inner_in_non_seqs-0)"""
for truth, out in zip(expected_output.split("\n"), lines): for truth, out in zip(expected_output.split("\n"), lines):
assert truth.strip() == out.strip() assert truth.strip() == out.strip()
@config.change_flags(floatX="float64")
def test_debugprint_mitsot(): def test_debugprint_mitsot():
def fn(a_m2, a_m1, b_m2, b_m1): def fn(a_m2, a_m1, b_m2, b_m1):
return a_m1 + a_m2, b_m1 + b_m2 return a_m1 + a_m2, b_m1 + b_m2
...@@ -351,11 +356,11 @@ def test_debugprint_mitsot(): ...@@ -351,11 +356,11 @@ def test_debugprint_mitsot():
for{cpu,scan_fn}.0 [id C] (outer_out_mit_sot-0) for{cpu,scan_fn}.0 [id C] (outer_out_mit_sot-0)
>Elemwise{add,no_inplace} [id BB] (inner_out_mit_sot-0) >Elemwise{add,no_inplace} [id BB] (inner_out_mit_sot-0)
> |<TensorType(int64, ())> [id BC] -> [id E] (inner_in_mit_sot-0-1) > |*1-<TensorType(int64, ())> [id BC] -> [id E] (inner_in_mit_sot-0-1)
> |<TensorType(int64, ())> [id BD] -> [id E] (inner_in_mit_sot-0-0) > |*0-<TensorType(int64, ())> [id BD] -> [id E] (inner_in_mit_sot-0-0)
>Elemwise{add,no_inplace} [id BE] (inner_out_mit_sot-1) >Elemwise{add,no_inplace} [id BE] (inner_out_mit_sot-1)
> |<TensorType(int64, ())> [id BF] -> [id O] (inner_in_mit_sot-1-1) > |*3-<TensorType(int64, ())> [id BF] -> [id O] (inner_in_mit_sot-1-1)
> |<TensorType(int64, ())> [id BG] -> [id O] (inner_in_mit_sot-1-0) > |*2-<TensorType(int64, ())> [id BG] -> [id O] (inner_in_mit_sot-1-0)
for{cpu,scan_fn}.1 [id C] (outer_out_mit_sot-1) for{cpu,scan_fn}.1 [id C] (outer_out_mit_sot-1)
>Elemwise{add,no_inplace} [id BB] (inner_out_mit_sot-0) >Elemwise{add,no_inplace} [id BB] (inner_out_mit_sot-0)
...@@ -365,6 +370,7 @@ def test_debugprint_mitsot(): ...@@ -365,6 +370,7 @@ def test_debugprint_mitsot():
assert truth.strip() == out.strip() assert truth.strip() == out.strip()
@config.change_flags(floatX="float64")
def test_debugprint_mitmot(): def test_debugprint_mitmot():
k = iscalar("k") k = iscalar("k")
...@@ -471,19 +477,19 @@ def test_debugprint_mitmot(): ...@@ -471,19 +477,19 @@ def test_debugprint_mitmot():
for{cpu,grad_of_scan_fn}.1 [id B] (outer_out_sit_sot-0) for{cpu,grad_of_scan_fn}.1 [id B] (outer_out_sit_sot-0)
>Elemwise{add,no_inplace} [id CM] (inner_out_mit_mot-0-0) >Elemwise{add,no_inplace} [id CM] (inner_out_mit_mot-0-0)
> |Elemwise{mul} [id CN] > |Elemwise{mul} [id CN]
> | |<TensorType(float64, (None,))> [id CO] -> [id BL] (inner_in_mit_mot-0-0) > | |*2-<TensorType(float64, (None,))> [id CO] -> [id BL] (inner_in_mit_mot-0-0)
> | |A_copy [id CP] -> [id P] (inner_in_non_seqs-0) > | |*5-<TensorType(float64, (None,))> [id CP] -> [id P] (inner_in_non_seqs-0)
> |<TensorType(float64, (None,))> [id CQ] -> [id BL] (inner_in_mit_mot-0-1) > |*3-<TensorType(float64, (None,))> [id CQ] -> [id BL] (inner_in_mit_mot-0-1)
>Elemwise{add,no_inplace} [id CR] (inner_out_sit_sot-0) >Elemwise{add,no_inplace} [id CR] (inner_out_sit_sot-0)
> |Elemwise{mul} [id CS] > |Elemwise{mul} [id CS]
> | |<TensorType(float64, (None,))> [id CO] -> [id BL] (inner_in_mit_mot-0-0) > | |*2-<TensorType(float64, (None,))> [id CO] -> [id BL] (inner_in_mit_mot-0-0)
> | |<TensorType(float64, (None,))> [id CT] -> [id Z] (inner_in_seqs-0) > | |*0-<TensorType(float64, (None,))> [id CT] -> [id Z] (inner_in_seqs-0)
> |<TensorType(float64, (None,))> [id CU] -> [id CE] (inner_in_sit_sot-0) > |*4-<TensorType(float64, (None,))> [id CU] -> [id CE] (inner_in_sit_sot-0)
for{cpu,scan_fn} [id F] (outer_out_sit_sot-0) for{cpu,scan_fn} [id F] (outer_out_sit_sot-0)
>Elemwise{mul,no_inplace} [id CV] (inner_out_sit_sot-0) >Elemwise{mul,no_inplace} [id CV] (inner_out_sit_sot-0)
> |<TensorType(float64, (None,))> [id CT] -> [id H] (inner_in_sit_sot-0) > |*0-<TensorType(float64, (None,))> [id CT] -> [id H] (inner_in_sit_sot-0)
> |A_copy [id CP] -> [id P] (inner_in_non_seqs-0) > |*1-<TensorType(float64, (None,))> [id CW] -> [id P] (inner_in_non_seqs-0)
for{cpu,scan_fn} [id F] (outer_out_sit_sot-0) for{cpu,scan_fn} [id F] (outer_out_sit_sot-0)
>Elemwise{mul,no_inplace} [id CV] (inner_out_sit_sot-0) >Elemwise{mul,no_inplace} [id CV] (inner_out_sit_sot-0)
...@@ -540,11 +546,11 @@ def test_debugprint_compiled_fn(): ...@@ -540,11 +546,11 @@ def test_debugprint_compiled_fn():
>Elemwise{Composite{Switch(LT(i0, i1), i2, i0)}} [id I] (inner_out_sit_sot-0) >Elemwise{Composite{Switch(LT(i0, i1), i2, i0)}} [id I] (inner_out_sit_sot-0)
> |TensorConstant{0} [id J] > |TensorConstant{0} [id J]
> |Subtensor{int64, int64, int64} [id K] > |Subtensor{int64, int64, int64} [id K]
> | |<TensorType(float64, (20000, 2, 2))> [id L] -> [id H] (inner_in_non_seqs-0) > | |*2-<TensorType(float64, (20000, 2, 2))> [id L] -> [id H] (inner_in_non_seqs-0)
> | |ScalarFromTensor [id M] > | |ScalarFromTensor [id M]
> | | |<TensorType(int64, ())> [id N] -> [id C] (inner_in_seqs-0) > | | |*0-<TensorType(int64, ())> [id N] -> [id C] (inner_in_seqs-0)
> | |ScalarFromTensor [id O] > | |ScalarFromTensor [id O]
> | | |<TensorType(int64, ())> [id P] -> [id D] (inner_in_sit_sot-0) > | | |*1-<TensorType(int64, ())> [id P] -> [id D] (inner_in_sit_sot-0)
> | |ScalarConstant{0} [id Q] > | |ScalarConstant{0} [id Q]
> |TensorConstant{1} [id R] > |TensorConstant{1} [id R]
""" """
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论