提交 fd8323fe authored 作者: Brandon T. Willard's avatar Brandon T. Willard 提交者: Brandon T. Willard

Remove use of OrderedDict in aesara.scan.basic

上级 c70ef127
import logging
from collections import OrderedDict
import numpy as np
......@@ -351,11 +350,11 @@ def scan(
n_seqs = len(seqs)
n_outs = len(outs_info)
return_steps = OrderedDict()
return_steps = {}
# wrap sequences in a dictionary if they are not already dictionaries
for i in range(n_seqs):
if not isinstance(seqs[i], dict):
seqs[i] = OrderedDict([("input", seqs[i]), ("taps", [0])])
seqs[i] = dict([("input", seqs[i]), ("taps", [0])])
elif seqs[i].get("taps", None) is not None:
seqs[i]["taps"] = wrap_into_list(seqs[i]["taps"])
elif seqs[i].get("taps", None) is None:
......@@ -377,7 +376,7 @@ def scan(
if not isinstance(outs_info[i], dict):
# by default any output has a tap value of -1
outs_info[i] = OrderedDict([("initial", outs_info[i]), ("taps", [-1])])
outs_info[i] = dict([("initial", outs_info[i]), ("taps", [-1])])
elif (
outs_info[i].get("initial", None) is None
and outs_info[i].get("taps", None) is not None
......@@ -417,7 +416,7 @@ def scan(
else:
# if a None is provided as the output info we replace it
# with an empty OrdereDict() to simplify handling
outs_info[i] = OrderedDict()
outs_info[i] = {}
##
# Step 2. Generate inputs and outputs of the inner functions
......@@ -564,7 +563,7 @@ def scan(
mit_sot_inner_inputs = []
mit_sot_inner_slices = []
mit_sot_inner_outputs = []
mit_sot_return_steps = OrderedDict()
mit_sot_return_steps = {}
mit_sot_tap_array = []
mit_sot_rightOrder = []
......@@ -573,7 +572,7 @@ def scan(
sit_sot_inner_inputs = []
sit_sot_inner_slices = []
sit_sot_inner_outputs = []
sit_sot_return_steps = OrderedDict()
sit_sot_return_steps = {}
sit_sot_rightOrder = []
# go through outputs picking up time slices as needed
......@@ -764,9 +763,7 @@ def scan(
if condition is not None:
outputs.append(condition)
fake_nonseqs = [x.type() for x in non_seqs]
fake_outputs = clone_replace(
outputs, replace=OrderedDict(zip(non_seqs, fake_nonseqs))
)
fake_outputs = clone_replace(outputs, replace=dict(zip(non_seqs, fake_nonseqs)))
all_inputs = filter(
lambda x: (
isinstance(x, Variable)
......@@ -820,7 +817,7 @@ def scan(
n_outs = len(dummy_outputs)
if as_while:
n_outs = n_outs - 1
outs_info = [OrderedDict() for x in range(n_outs)]
outs_info = [{} for x in range(n_outs)]
# Step 5.1 Outputs with taps different then -1
......@@ -834,7 +831,7 @@ def scan(
sit_sot_inner_outputs.append(outputs[i])
# Step 5.3 Outputs that correspond to update rules of shared variables
givens = OrderedDict()
inner_replacements = {}
n_shared_outs = 0
shared_scan_inputs = []
shared_inner_inputs = []
......@@ -843,8 +840,10 @@ def scan(
for input in dummy_inputs:
if isinstance(input.variable, SharedVariable) and input.update:
new_var = safe_new(input.variable)
if getattr(input.variable, "name", None) is not None:
new_var.name = input.variable.name + "_copy"
if isinstance(new_var.type, TensorType):
sit_sot_inner_inputs.append(new_var)
sit_sot_scan_inputs.append(
......@@ -864,13 +863,13 @@ def scan(
# refers to the update rule with index `-1 - pos`.
sit_sot_rightOrder.append(-1 - len(sit_sot_shared))
sit_sot_shared.append(input.variable)
givens[input.variable] = new_var
inner_replacements[input.variable] = new_var
else:
shared_inner_inputs.append(new_var)
shared_scan_inputs.append(input.variable)
shared_inner_outputs.append(input.update)
givens[input.variable] = new_var
inner_replacements[input.variable] = new_var
n_shared_outs += 1
n_sit_sot = len(sit_sot_inner_inputs)
......@@ -878,7 +877,7 @@ def scan(
# Step 5.4 Outputs with no taps used in the input
n_nit_sot = 0
nit_sot_inner_outputs = []
nit_sot_return_steps = OrderedDict()
nit_sot_return_steps = {}
nit_sot_rightOrder = []
for i, out in enumerate(outs_info):
if "taps" not in out:
......@@ -905,7 +904,7 @@ def scan(
if (not isinstance(arg, SharedVariable) and not isinstance(arg, Constant))
]
givens.update(OrderedDict(zip(other_scan_args, other_inner_args)))
inner_replacements.update(dict(zip(other_scan_args, other_inner_args)))
if strict:
non_seqs_set = set(non_sequences if non_sequences is not None else [])
......@@ -939,11 +938,13 @@ def scan(
for arg in dummy_inputs
if (isinstance(arg.variable, SharedVariable) and not arg.update)
]
givens.update(OrderedDict(zip(other_shared_scan_args, other_shared_inner_args)))
inner_replacements.update(
dict(zip(other_shared_scan_args, other_shared_inner_args))
)
##
# Step 6. Re-order the outputs and clone them replacing things
# using the givens
# using `inner_replacements`
##
inner_inputs = (
inner_seqs
......@@ -964,10 +965,8 @@ def scan(
)
if condition is not None:
inner_outs.append(condition)
# NOTE: legacy code traversed GPU types
new_givens = givens
new_outs = clone_replace(inner_outs, replace=new_givens)
new_outs = clone_replace(inner_outs, replace=inner_replacements)
##
# Step 7. Create the Scan Op
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论