提交 177124ba authored 作者: Virgile Andreani's avatar Virgile Andreani 提交者: Ricardo Vieira

Remove OrderedDict from scan/op

上级 c7a99b60
...@@ -46,7 +46,6 @@ relies on the following elements to work properly : ...@@ -46,7 +46,6 @@ relies on the following elements to work properly :
import dataclasses import dataclasses
import logging import logging
import time import time
from collections import OrderedDict
from collections.abc import Callable, Iterable from collections.abc import Callable, Iterable
from copy import copy from copy import copy
from itertools import chain, product from itertools import chain, product
...@@ -2188,7 +2187,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph): ...@@ -2188,7 +2187,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
# corresponding outer inputs that the Scan would use as input for # corresponding outer inputs that the Scan would use as input for
# any given iteration. For simplicity, we use iteration 0. # any given iteration. For simplicity, we use iteration 0.
inner_ins_shapes = [] inner_ins_shapes = []
out_equivalent = OrderedDict() out_equivalent = {}
# The two following blocks are commented as it cause in some # The two following blocks are commented as it cause in some
# cases extra scans in the graph. See gh-XXX for the # cases extra scans in the graph. See gh-XXX for the
...@@ -2469,7 +2468,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph): ...@@ -2469,7 +2468,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
if (x in diff_inputs) if (x in diff_inputs)
and get_inp_idx(self_inputs.index(x)) in connected_inputs and get_inp_idx(self_inputs.index(x)) in connected_inputs
] ]
gmp = OrderedDict() gmp = {}
# Required in case there is a pair of variables X and Y, with X # Required in case there is a pair of variables X and Y, with X
# used to compute Y, for both of which there is an external # used to compute Y, for both of which there is an external
...@@ -2478,7 +2477,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph): ...@@ -2478,7 +2477,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
# it will be the sum of the external gradient signal and the # it will be the sum of the external gradient signal and the
# gradient obtained by propagating Y's external gradient signal # gradient obtained by propagating Y's external gradient signal
# to X. # to X.
known_grads = OrderedDict([(k.copy(), v) for (k, v) in known_grads.items()]) known_grads = {k.copy(): v for (k, v) in known_grads.items()}
grads = grad( grads = grad(
cost=None, cost=None,
...@@ -2548,7 +2547,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph): ...@@ -2548,7 +2547,7 @@ class Scan(Op, ScanMethodsMixin, HasInnerGraph):
dC_dXt = safe_new(dC_douts[idx][0]) dC_dXt = safe_new(dC_douts[idx][0])
dC_dXts.append(dC_dXt) dC_dXts.append(dC_dXt)
known_grads = OrderedDict() known_grads = {}
dc_dxts_idx = 0 dc_dxts_idx = 0
for i in range(len(diff_outputs)): for i in range(len(diff_outputs)):
if i < idx_nitsot_start or i >= idx_nitsot_end: if i < idx_nitsot_start or i >= idx_nitsot_end:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论