提交 c318153c authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Expand test

上级 a1530fe2
...@@ -5,6 +5,8 @@ import unittest ...@@ -5,6 +5,8 @@ import unittest
from nose.plugins.skip import SkipTest from nose.plugins.skip import SkipTest
from nose.plugins.attrib import attr from nose.plugins.attrib import attr
import numpy
import theano import theano
from theano.gof.link import PerformLinker from theano.gof.link import PerformLinker
from theano.gof.cc import CLinker, DualLinker, OpWiseCLinker from theano.gof.cc import CLinker, DualLinker, OpWiseCLinker
...@@ -364,28 +366,62 @@ def test_c_fail_error(): ...@@ -364,28 +366,62 @@ def test_c_fail_error():
assert 0 # test failed assert 0 # test failed
# Test bug reported on the mailing list by Alberto Orlandi
# https://groups.google.com/d/topic/theano-users/6dLaEqc2R6g/discussion
def test_shared_input_output(): def test_shared_input_output():
# Not sure yet why this case is special. # Test bug reported on the mailing list by Alberto Orlandi
# It may be that a shared variable is both an implicit input # https://groups.google.com/d/topic/theano-users/6dLaEqc2R6g/discussion
# and an output. # The shared variable is both an input and an output of the function.
inc = theano.tensor.iscalar('inc') inc = theano.tensor.iscalar('inc')
state = theano.shared(0) state = theano.shared(0)
state.name = 'state'
linker = theano.gof.CLinker() linker = theano.gof.CLinker()
mode = theano.Mode(linker=linker) mode = theano.Mode(linker=linker)
f = theano.function([inc], state, updates=[(state, state + inc)], f = theano.function([inc], state, updates=[(state, state + inc)],
mode=mode) mode=mode)
g = theano.function([inc], state, updates=[(state, state + inc)]) g = theano.function([inc], state, updates=[(state, state + inc)])
theano.printing.debugprint(f)
theano.printing.debugprint(g)
assert f(0) == g(0) == 0
import numpy # Initial value
f0 = f(0)
g0 = g(0)
assert f0 == g0 == 0, (f0, g0)
# Increment state via f, returns the previous value.
f2 = f(2)
assert f2 == f0, (f2, f0)
f0 = f(0)
g0 = g(0)
assert f0 == g0 == 2, (f0, g0)
# Increment state via g, returns the previous value
g3 = g(3)
assert g3 == g0, (g3, g0)
f0 = f(0)
g0 = g(0)
assert f0 == g0 == 5, (f0, g0)
vstate = theano.shared(numpy.zeros(3, dtype='int32')) vstate = theano.shared(numpy.zeros(3, dtype='int32'))
vstate.name = 'vstate'
fv = theano.function([inc], vstate, updates=[(vstate, vstate + inc)], fv = theano.function([inc], vstate, updates=[(vstate, vstate + inc)],
mode=mode) mode=mode)
gv = theano.function([inc], vstate, updates=[(vstate, vstate + inc)]) gv = theano.function([inc], vstate, updates=[(vstate, vstate + inc)])
theano.printing.debugprint(fv)
theano.printing.debugprint(gv) # Initial value
assert numpy.all(fv(0) == gv(0) == 0), (fv(0), gv(0)) fv0 = fv(0)
gv0 = gv(0)
assert numpy.all(fv0 == 0), fv0
assert numpy.all(gv0 == 0), gv0
# Increment state via f, returns the previous value.
fv2 = fv(2)
assert numpy.all(fv2 == fv0), (fv2, fv0)
fv0 = fv(0)
gv0 = gv(0)
assert numpy.all(fv0 == 2), fv0
assert numpy.all(gv0 == 2), gv0
# Increment state via g, returns the previous value
gv3 = gv(3)
assert numpy.all(gv3 == gv0), (gv3, gv0)
fv0 = fv(0)
gv0 = gv(0)
assert numpy.all(fv0 == 5), fv0
assert numpy.all(gv0 == 5), gv0
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论