提交 79ce5106 authored 作者: Virgile Andreani's avatar Virgile Andreani 提交者: Ricardo Vieira

Remove OrderedDict from tests/scan/test_basic

上级 643199e2
......@@ -13,7 +13,6 @@ import os
import pickle
import shutil
import sys
from collections import OrderedDict
from tempfile import mkdtemp
import numpy as np
......@@ -764,11 +763,9 @@ class TestScan:
b = shared(np.random.default_rng(utt.fetch_seed()).random((5, 4)))
def inner_func(a):
return a + 1, OrderedDict([(b, 2 * b)])
return a + 1, {b: 2 * b}
out, updates = scan(
inner_func, outputs_info=[OrderedDict([("initial", init_a)])], n_steps=1
)
out, updates = scan(inner_func, outputs_info=[{"initial": init_a}], n_steps=1)
out = out[-1]
assert out.type.ndim == a.type.ndim
assert updates[b].type.ndim == b.type.ndim
......@@ -934,7 +931,7 @@ class TestScan:
state = shared(v_state, "vstate")
def f_2():
return OrderedDict([(state, 2 * state)])
return {state: 2 * state}
n_steps = iscalar("nstep")
output, updates = scan(
......@@ -968,7 +965,7 @@ class TestScan:
X = shared(np.array(1))
out, updates = scan(
lambda: OrderedDict([(X, (X + 1))]),
lambda: {X: (X + 1)},
outputs_info=[],
non_sequences=[],
sequences=[],
......@@ -984,7 +981,7 @@ class TestScan:
y = shared(np.array(1))
out, updates = scan(
lambda: OrderedDict([(x, x + 1), (y, x)]),
lambda: {x: x + 1, y: x},
outputs_info=[],
non_sequences=[],
sequences=[],
......@@ -1914,7 +1911,7 @@ class TestScan:
shared_var = shared(np.float32(1.0))
def inner_fn():
return [], OrderedDict([(shared_var, shared_var + np.float32(1.0))])
return [], {shared_var: shared_var + np.float32(1.0)}
_, updates = scan(
inner_fn, n_steps=10, truncate_gradient=-1, go_backwards=False
......@@ -2746,7 +2743,7 @@ class TestExamples:
v1 = shared(np.ones(5, dtype=config.floatX))
v2 = shared(np.ones((5, 5), dtype=config.floatX))
shapef = function([W], expr, givens=OrderedDict([(initial, v1), (inpt, v2)]))
shapef = function([W], expr, givens={initial: v1, inpt: v2})
# First execution to cache n_steps
shapef(np.ones((5, 5), dtype=config.floatX))
......@@ -2755,7 +2752,7 @@ class TestExamples:
f = function(
[W, inpt],
d_cost_wrt_W,
givens=OrderedDict([(initial, shared(np.zeros(5)))]),
givens={initial: shared(np.zeros(5))},
)
rval = np.asarray([[5187989] * 5] * 5, dtype=config.floatX)
......@@ -2956,7 +2953,7 @@ class TestExamples:
seq = matrix()
initial_value = shared(np.zeros((4, 1), dtype=config.floatX))
outputs_info = [OrderedDict([("initial", initial_value), ("taps", [-4])]), None]
outputs_info = [{"initial": initial_value, "taps": [-4]}, None]
results, updates = scan(fn=onestep, sequences=seq, outputs_info=outputs_info)
f = function([seq], results[1])
......@@ -2979,10 +2976,10 @@ class TestExamples:
seq = matrix()
initial_value = shared(np.zeros((4, 1), dtype=config.floatX))
outputs_info = [OrderedDict([("initial", initial_value), ("taps", [-4])]), None]
outputs_info = [{"initial": initial_value, "taps": [-4]}, None]
results, _ = scan(fn=onestep, sequences=seq, outputs_info=outputs_info)
sharedvar = shared(np.zeros((1, 1), dtype=config.floatX))
updates = OrderedDict([(sharedvar, results[0][-1:])])
updates = {sharedvar: results[0][-1:]}
f = function([seq], results[1], updates=updates)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论