提交 6eedb364 authored 作者: Benjamin Scellier's avatar Benjamin Scellier 提交者: Nicolas Ballas

file theano/gof/tests/test_cc.py

上级 9c60aa57
......@@ -2,7 +2,7 @@ from __future__ import absolute_import, print_function, division
from nose.plugins.skip import SkipTest
import numpy
import numpy as np
import theano
from theano.gof.link import PerformLinker
......@@ -211,16 +211,16 @@ def test_clinker_literal_cache():
A = theano.tensor.matrix()
input1 = theano.tensor.vector()
normal_svd = numpy.array([[5.936276e+01, -4.664007e-07, -2.56265e-06],
[-4.664007e-07, 9.468691e-01, -3.18862e-02],
[-2.562651e-06, -3.188625e-02, 1.05226e+00]],
dtype=theano.config.floatX)
normal_svd = np.array([[5.936276e+01, -4.664007e-07, -2.56265e-06],
[-4.664007e-07, 9.468691e-01, -3.18862e-02],
[-2.562651e-06, -3.188625e-02, 1.05226e+00]],
dtype=theano.config.floatX)
orientationi = numpy.array([59.36276866, 1.06116353, 0.93797339],
dtype=theano.config.floatX)
orientationi = np.array([59.36276866, 1.06116353, 0.93797339],
dtype=theano.config.floatX)
for out1 in [A - input1[0] * numpy.identity(3),
input1[0] * numpy.identity(3)]:
for out1 in [A - input1[0] * np.identity(3),
input1[0] * np.identity(3)]:
benchmark = theano.function(
inputs=[A, input1],
outputs=[out1],
......@@ -421,7 +421,7 @@ def test_shared_input_output():
g0 = g(0)
assert f0 == g0 == 5, (f0, g0)
vstate = theano.shared(numpy.zeros(3, dtype='int32'))
vstate = theano.shared(np.zeros(3, dtype='int32'))
vstate.name = 'vstate'
fv = theano.function([inc], vstate, updates=[(vstate, vstate + inc)],
mode=mode)
......@@ -430,21 +430,21 @@ def test_shared_input_output():
# Initial value
fv0 = fv(0)
gv0 = gv(0)
assert numpy.all(fv0 == 0), fv0
assert numpy.all(gv0 == 0), gv0
assert np.all(fv0 == 0), fv0
assert np.all(gv0 == 0), gv0
# Increment state via f, returns the previous value.
fv2 = fv(2)
assert numpy.all(fv2 == fv0), (fv2, fv0)
assert np.all(fv2 == fv0), (fv2, fv0)
fv0 = fv(0)
gv0 = gv(0)
assert numpy.all(fv0 == 2), fv0
assert numpy.all(gv0 == 2), gv0
assert np.all(fv0 == 2), fv0
assert np.all(gv0 == 2), gv0
# Increment state via g, returns the previous value
gv3 = gv(3)
assert numpy.all(gv3 == gv0), (gv3, gv0)
assert np.all(gv3 == gv0), (gv3, gv0)
fv0 = fv(0)
gv0 = gv(0)
assert numpy.all(fv0 == 5), fv0
assert numpy.all(gv0 == 5), gv0
assert np.all(fv0 == 5), fv0
assert np.all(gv0 == 5), gv0
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论