提交 a7840844 authored 作者: kc611's avatar kc611 提交者: Brandon T. Willard

Added ScanArgs

上级 45a2f784
......@@ -17,7 +17,7 @@ from aesara.link.utils import fgraph_to_python
from aesara.scalar import Softplus
from aesara.scalar.basic import Cast, Clip, Composite, Identity, ScalarOp, Second
from aesara.scan.op import Scan
from aesara.scan.utils import scan_args as ScanArgs
from aesara.scan.utils import ScanArgs
from aesara.tensor.basic import (
Alloc,
AllocDiag,
......
......@@ -80,11 +80,11 @@ from aesara.graph.opt import GlobalOptimizer, in2out, local_optimizer
from aesara.graph.optdb import EquilibriumDB, SequenceDB
from aesara.scan.op import Scan
from aesara.scan.utils import (
ScanArgs,
compress_outs,
expand_empty,
reconstruct_graph,
safe_new,
scan_args,
scan_can_remove_outs,
)
from aesara.tensor import basic_opt, math_opt
......@@ -751,9 +751,9 @@ class PushOutScanOutput(GlobalOptimizer):
op = node.op
# Use scan_args to parse the inputs and outputs of scan for ease of
# Use `ScanArgs` to parse the inputs and outputs of scan for ease of
# use
args = scan_args(node.inputs, node.outputs, op.inputs, op.outputs, op.info)
args = ScanArgs(node.inputs, node.outputs, op.inputs, op.outputs, op.info)
new_scan_node = None
clients = {}
......@@ -917,7 +917,7 @@ class PushOutScanOutput(GlobalOptimizer):
fgraph, old_scan_node, old_scan_args, add_as_nitsots
)
new_scan_args = scan_args(
new_scan_args = ScanArgs(
new_scan_node.inputs,
new_scan_node.outputs,
new_scan_node.op.inputs,
......@@ -944,13 +944,12 @@ class PushOutScanOutput(GlobalOptimizer):
old_scan_node.inputs[0] for i in range(nb_new_outs)
]
# Create the scan_args corresponding to the new scan op to
# create
# Create the `ScanArgs` corresponding to the new `Scan` `Op` to create
new_scan_args = copy.copy(old_scan_args)
new_scan_args.inner_out_nit_sot.extend(new_outputs_inner)
new_scan_args.outer_in_nit_sot.extend(new_nitsots_initial_value)
# Create the scan op from the scan_args
# Create the `Scan` `Op` from the `ScanArgs`
new_scan_op = Scan(
new_scan_args.inner_inputs, new_scan_args.inner_outputs, new_scan_args.info
)
......@@ -1970,7 +1969,7 @@ def scan_merge_inouts(fgraph, node):
# Do a first pass to merge identical external inputs.
# Equivalent inputs will be stored in inp_equiv, then a new
# scan node created without duplicates.
a = scan_args(
a = ScanArgs(
node.inputs, node.outputs, node.op.inputs, node.op.outputs, node.op.info
)
......@@ -2016,7 +2015,7 @@ def scan_merge_inouts(fgraph, node):
if not isinstance(outputs, (list, tuple)):
outputs = [outputs]
na = scan_args(outer_inputs, outputs, op.inputs, op.outputs, op.info)
na = ScanArgs(outer_inputs, outputs, op.inputs, op.outputs, op.info)
remove = [node]
else:
na = a
......
......@@ -16,7 +16,7 @@ __copyright__ = "(c) 2010, Universite de Montreal"
import copy
import logging
import warnings
from collections import OrderedDict
from collections import OrderedDict, namedtuple
import numpy as np
......@@ -872,11 +872,27 @@ def reconstruct_graph(inputs, outputs, tag=None):
return (nw_inputs, nw_outputs)
class scan_args:
"""
Parses the inputs and outputs of scan in an easy to manipulate format.
FieldInfo = namedtuple(
"FieldInfo", ("name", "agg_name", "index", "inner_index", "agg_index")
)
"""
def safe_index(lst, x):
try:
return lst.index(x)
except ValueError:
return None
def default_filter_scanargs(x):
return x.startswith("inner_") or x.startswith("outer_")
class ScanArgs:
"""Parses the inputs and outputs of `Scan` in an easy to manipulate format."""
default_filter = default_filter_scanargs
nested_list_fields = ("inner_in_mit_mot", "inner_in_mit_sot", "inner_out_mit_mot")
def __init__(
self, outer_inputs, outer_outputs, _inner_inputs, _inner_outputs, info
......@@ -1002,6 +1018,80 @@ class scan_args:
if k in info:
self.other_info[k] = info[k]
@staticmethod
def from_node(node):
from aesara.scan.op import Scan
if not isinstance(node.op, Scan):
raise TypeError("{} is not a Scan node".format(node))
return ScanArgs(
node.inputs, node.outputs, node.op.inputs, node.op.outputs, node.op.info
)
@classmethod
def create_empty(cls):
info = OrderedDict(
[
("n_seqs", 0),
("n_mit_mot", 0),
("n_mit_sot", 0),
("tap_array", []),
("n_sit_sot", 0),
("n_nit_sot", 0),
("n_shared_outs", 0),
("n_mit_mot_outs", 0),
("mit_mot_out_slices", []),
("truncate_gradient", -1),
("name", None),
("mode", None),
("destroy_map", OrderedDict()),
("gpua", False),
("as_while", False),
("profile", False),
("allow_gc", False),
]
)
res = cls([1], [], [], [], info)
res.n_steps = None
return res
@property
def n_nit_sot(self):
# This is just a hack that allows us to use `Scan.get_oinp_iinp_iout_oout_mappings`
return self.info["n_nit_sot"]
@property
def inputs(self):
# This is just a hack that allows us to use `Scan.get_oinp_iinp_iout_oout_mappings`
return self.inner_inputs
@property
def n_mit_mot(self):
# This is just a hack that allows us to use `Scan.get_oinp_iinp_iout_oout_mappings`
return self.info["n_mit_mot"]
@property
def var_mappings(self):
from aesara.scan.op import Scan
return Scan.get_oinp_iinp_iout_oout_mappings(self)
@property
def field_names(self):
res = ["mit_mot_out_slices", "mit_mot_in_slices", "mit_sot_in_slices"]
res.extend(
[
attr
for attr in self.__dict__
if attr.startswith("inner_in")
or attr.startswith("inner_out")
or attr.startswith("outer_in")
or attr.startswith("outer_out")
or attr == "n_steps"
]
)
return res
@property
def inner_inputs(self):
return (
......@@ -1066,6 +1156,179 @@ class scan_args:
**self.other_info,
)
def get_alt_field(self, var_info, alt_prefix):
"""Get the alternate input/output field for a given element of `ScanArgs`.
For example, if `var_info` is in `ScanArgs.outer_out_sit_sot`, then
`get_alt_field(var_info, "inner_out")` returns the element corresponding
`var_info` in `ScanArgs.inner_out_sit_sot`.
Parameters
----------
var_info: TensorVariable or FieldInfo
The element for which we want the alternate
alt_prefix: str
The string prefix for the alternate field type. It can be one of
the following: "inner_out", "inner_in", "outer_in", and "outer_out".
Outputs
-------
TensorVariable
Returns the alternate variable.
"""
if not isinstance(var_info, FieldInfo):
var_info = self.find_among_fields(var_info)
alt_type = var_info.name[(var_info.name.index("_", 6) + 1) :]
alt_var = getattr(self, "inner_out_{}".format(alt_type))[var_info.index]
return alt_var
def find_among_fields(self, i, field_filter=default_filter):
"""Find the type and indices of the field containing a given element.
NOTE: This only returns the *first* field containing the given element.
Parameters
----------
i: theano.gof.graph.Variable
The element to find among this object's fields.
field_filter: function
A function passed to `filter` that determines which fields to
consider. It must take a string field name and return a truthy
value.
Returns
-------
A tuple of length 4 containing the field name string, the first index,
the second index (for nested lists), and the "major" index (i.e. the
index within the aggregate lists like `self.inner_inputs`,
`self.outer_outputs`, etc.), or a triple of `None` when no match is
found.
"""
field_names = filter(field_filter, self.field_names)
for field_name in field_names:
lst = getattr(self, field_name)
field_prefix = field_name[:8]
if field_prefix.endswith("in"):
agg_field_name = "{}puts".format(field_prefix)
else:
agg_field_name = "{}tputs".format(field_prefix)
agg_list = getattr(self, agg_field_name)
if field_name in self.nested_list_fields:
for n, sub_lst in enumerate(lst):
idx = safe_index(sub_lst, i)
if idx is not None:
agg_idx = safe_index(agg_list, i)
return FieldInfo(field_name, agg_field_name, n, idx, agg_idx)
else:
idx = safe_index(lst, i)
if idx is not None:
agg_idx = safe_index(agg_list, i)
return FieldInfo(field_name, agg_field_name, idx, None, agg_idx)
return None
def _remove_from_fields(self, i, field_filter=default_filter):
field_info = self.find_among_fields(i, field_filter=field_filter)
if field_info is None:
return None
if field_info.inner_index is not None:
getattr(self, field_info.name)[field_info.index].remove(i)
else:
getattr(self, field_info.name).remove(i)
return field_info
def get_dependent_nodes(self, i, seen=None):
from aesara.graph import inputs as at_inputs
if seen is None:
seen = {i}
else:
seen.add(i)
var_mappings = self.var_mappings
field_info = self.find_among_fields(i)
if field_info is None:
raise ValueError("{} not found among fields.".format(i))
# Find the `var_mappings` key suffix that matches the field/set of
# arguments containing our source node
if field_info.name[:8].endswith("_in"):
map_key_suffix = "{}p".format(field_info.name[:8])
else:
map_key_suffix = field_info.name[:9]
dependent_nodes = set()
for k, v in var_mappings.items():
if not k.endswith(map_key_suffix):
continue
dependent_idx = v[field_info.agg_index]
dependent_idx = (
dependent_idx if isinstance(dependent_idx, list) else [dependent_idx]
)
# Get the `ScanArgs` field name for the aggregate list property
# corresponding to these dependent argument types (i.e. either
# "outer_inputs", "inner_inputs", "inner_outputs", or
# "outer_outputs").
# To do this, we need to parse the "shared" prefix of the
# current `var_mappings` key and append the missing parts so that
# it either forms `"*_inputs"` or `"*_outputs"`.
to_agg_field_prefix = k[:9]
if to_agg_field_prefix.endswith("p"):
to_agg_field_name = "{}uts".format(to_agg_field_prefix)
else:
to_agg_field_name = "{}puts".format(to_agg_field_prefix)
to_agg_field = getattr(self, to_agg_field_name)
for d_id in dependent_idx:
if d_id < 0:
continue
dependent_var = to_agg_field[d_id]
if dependent_var not in seen:
dependent_nodes.add(dependent_var)
if field_info.name.startswith("inner_in"):
# If starting from an inner-input, then we need to find any
# inner-outputs that depend on it.
for out_n in self.inner_outputs:
if i in at_inputs([out_n]):
if out_n not in seen:
dependent_nodes.add(out_n)
for n in tuple(dependent_nodes):
if n in seen:
continue
sub_dependent_nodes = self.get_dependent_nodes(n, seen=seen)
dependent_nodes |= sub_dependent_nodes
seen |= sub_dependent_nodes
return dependent_nodes
def remove_from_fields(self, i, rm_dependents=True):
if rm_dependents:
vars_to_remove = self.get_dependent_nodes(i) | {i}
else:
vars_to_remove = {i}
rm_info = []
for v in vars_to_remove:
dependent_rm_info = self._remove_from_fields(v)
rm_info.append((v, dependent_rm_info))
return rm_info
def __copy__(self):
res = object.__new__(type(self))
res.__dict__.update(self.__dict__)
......@@ -1101,6 +1364,52 @@ class scan_args:
getattr(res, attr).extend(getattr(other, attr))
return res
def __str__(self):
inner_arg_strs = [
"\t{}={}".format(p, getattr(self, p))
for p in self.field_names
if p.startswith("outer_in") or p == "n_steps"
]
inner_arg_strs += [
"\t{}={}".format(p, getattr(self, p))
for p in self.field_names
if p.startswith("inner_in")
]
inner_arg_strs += [
"\tmit_mot_in_slices={}".format(self.mit_mot_in_slices),
"\tmit_sot_in_slices={}".format(self.mit_sot_in_slices),
]
inner_arg_strs += [
"\t{}={}".format(p, getattr(self, p))
for p in self.field_names
if p.startswith("inner_out")
]
inner_arg_strs += [
"\tmit_mot_out_slices={}".format(self.mit_mot_out_slices),
]
inner_arg_strs += [
"\t{}={}".format(p, getattr(self, p))
for p in self.field_names
if p.startswith("outer_out")
]
res = "ScanArgs(\n{})".format(",\n".join(inner_arg_strs))
return res
def __repr__(self):
return self.__str__()
def __eq__(self, other):
if not isinstance(other, type(self)):
return NotImplemented
for field_name in self.field_names:
if not hasattr(other, field_name) or getattr(self, field_name) != getattr(
other, field_name
):
return False
return True
def forced_replace(out, x, y):
"""
......
......@@ -276,12 +276,12 @@ by usage.
Accessing/manipulating Scan's inputs and outputs by type
--------------------------------------------------------
Declared in ``utils.py``, the class ``scan_args`` handles the
Declared in ``utils.py``, the class ``ScanArgs`` handles the
parsing of the inputs and outputs (both inner and outer) to a format
that is easier to analyse and manipulate. Without this class,
analysing Scan's inputs and outputs often required convoluted logic
which make for code that is hard to read and to maintain. Because of
this, you should favor using ``scan_args`` when it is practical and
this, you should favor using ``ScanArgs`` when it is practical and
appropriate to do so.
The scan op also defines a few helper functions for this purpose, such as
......
......@@ -5,7 +5,7 @@ import pytest
import aesara
from aesara import tensor as aet
from aesara.scan.utils import map_variables
from aesara.scan.utils import ScanArgs, map_variables
from aesara.tensor.type import scalar, vector
......@@ -145,3 +145,15 @@ class TestMapVariables:
shared.update = shared + 1
with pytest.raises(NotImplementedError):
map_variables(self.replacer, [t])
def test_ScanArgs():
scan_args = ScanArgs.create_empty()
assert scan_args.n_steps is None
for name in scan_args.field_names:
if name == "n_steps":
continue
assert len(getattr(scan_args, name)) == 0
with pytest.raises(TypeError):
ScanArgs.from_node(aet.ones(2).owner)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论