提交 79ce5106 authored 作者: Virgile Andreani's avatar Virgile Andreani 提交者: Ricardo Vieira

Remove OrderedDict from tests/scan/test_basic

上级 643199e2
...@@ -13,7 +13,6 @@ import os ...@@ -13,7 +13,6 @@ import os
import pickle import pickle
import shutil import shutil
import sys import sys
from collections import OrderedDict
from tempfile import mkdtemp from tempfile import mkdtemp
import numpy as np import numpy as np
...@@ -764,11 +763,9 @@ class TestScan: ...@@ -764,11 +763,9 @@ class TestScan:
b = shared(np.random.default_rng(utt.fetch_seed()).random((5, 4))) b = shared(np.random.default_rng(utt.fetch_seed()).random((5, 4)))
def inner_func(a): def inner_func(a):
return a + 1, OrderedDict([(b, 2 * b)]) return a + 1, {b: 2 * b}
out, updates = scan( out, updates = scan(inner_func, outputs_info=[{"initial": init_a}], n_steps=1)
inner_func, outputs_info=[OrderedDict([("initial", init_a)])], n_steps=1
)
out = out[-1] out = out[-1]
assert out.type.ndim == a.type.ndim assert out.type.ndim == a.type.ndim
assert updates[b].type.ndim == b.type.ndim assert updates[b].type.ndim == b.type.ndim
...@@ -934,7 +931,7 @@ class TestScan: ...@@ -934,7 +931,7 @@ class TestScan:
state = shared(v_state, "vstate") state = shared(v_state, "vstate")
def f_2(): def f_2():
return OrderedDict([(state, 2 * state)]) return {state: 2 * state}
n_steps = iscalar("nstep") n_steps = iscalar("nstep")
output, updates = scan( output, updates = scan(
...@@ -968,7 +965,7 @@ class TestScan: ...@@ -968,7 +965,7 @@ class TestScan:
X = shared(np.array(1)) X = shared(np.array(1))
out, updates = scan( out, updates = scan(
lambda: OrderedDict([(X, (X + 1))]), lambda: {X: (X + 1)},
outputs_info=[], outputs_info=[],
non_sequences=[], non_sequences=[],
sequences=[], sequences=[],
...@@ -984,7 +981,7 @@ class TestScan: ...@@ -984,7 +981,7 @@ class TestScan:
y = shared(np.array(1)) y = shared(np.array(1))
out, updates = scan( out, updates = scan(
lambda: OrderedDict([(x, x + 1), (y, x)]), lambda: {x: x + 1, y: x},
outputs_info=[], outputs_info=[],
non_sequences=[], non_sequences=[],
sequences=[], sequences=[],
...@@ -1914,7 +1911,7 @@ class TestScan: ...@@ -1914,7 +1911,7 @@ class TestScan:
shared_var = shared(np.float32(1.0)) shared_var = shared(np.float32(1.0))
def inner_fn(): def inner_fn():
return [], OrderedDict([(shared_var, shared_var + np.float32(1.0))]) return [], {shared_var: shared_var + np.float32(1.0)}
_, updates = scan( _, updates = scan(
inner_fn, n_steps=10, truncate_gradient=-1, go_backwards=False inner_fn, n_steps=10, truncate_gradient=-1, go_backwards=False
...@@ -2746,7 +2743,7 @@ class TestExamples: ...@@ -2746,7 +2743,7 @@ class TestExamples:
v1 = shared(np.ones(5, dtype=config.floatX)) v1 = shared(np.ones(5, dtype=config.floatX))
v2 = shared(np.ones((5, 5), dtype=config.floatX)) v2 = shared(np.ones((5, 5), dtype=config.floatX))
shapef = function([W], expr, givens=OrderedDict([(initial, v1), (inpt, v2)])) shapef = function([W], expr, givens={initial: v1, inpt: v2})
# First execution to cache n_steps # First execution to cache n_steps
shapef(np.ones((5, 5), dtype=config.floatX)) shapef(np.ones((5, 5), dtype=config.floatX))
...@@ -2755,7 +2752,7 @@ class TestExamples: ...@@ -2755,7 +2752,7 @@ class TestExamples:
f = function( f = function(
[W, inpt], [W, inpt],
d_cost_wrt_W, d_cost_wrt_W,
givens=OrderedDict([(initial, shared(np.zeros(5)))]), givens={initial: shared(np.zeros(5))},
) )
rval = np.asarray([[5187989] * 5] * 5, dtype=config.floatX) rval = np.asarray([[5187989] * 5] * 5, dtype=config.floatX)
...@@ -2956,7 +2953,7 @@ class TestExamples: ...@@ -2956,7 +2953,7 @@ class TestExamples:
seq = matrix() seq = matrix()
initial_value = shared(np.zeros((4, 1), dtype=config.floatX)) initial_value = shared(np.zeros((4, 1), dtype=config.floatX))
outputs_info = [OrderedDict([("initial", initial_value), ("taps", [-4])]), None] outputs_info = [{"initial": initial_value, "taps": [-4]}, None]
results, updates = scan(fn=onestep, sequences=seq, outputs_info=outputs_info) results, updates = scan(fn=onestep, sequences=seq, outputs_info=outputs_info)
f = function([seq], results[1]) f = function([seq], results[1])
...@@ -2979,10 +2976,10 @@ class TestExamples: ...@@ -2979,10 +2976,10 @@ class TestExamples:
seq = matrix() seq = matrix()
initial_value = shared(np.zeros((4, 1), dtype=config.floatX)) initial_value = shared(np.zeros((4, 1), dtype=config.floatX))
outputs_info = [OrderedDict([("initial", initial_value), ("taps", [-4])]), None] outputs_info = [{"initial": initial_value, "taps": [-4]}, None]
results, _ = scan(fn=onestep, sequences=seq, outputs_info=outputs_info) results, _ = scan(fn=onestep, sequences=seq, outputs_info=outputs_info)
sharedvar = shared(np.zeros((1, 1), dtype=config.floatX)) sharedvar = shared(np.zeros((1, 1), dtype=config.floatX))
updates = OrderedDict([(sharedvar, results[0][-1:])]) updates = {sharedvar: results[0][-1:]}
f = function([seq], results[1], updates=updates) f = function([seq], results[1], updates=updates)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论