提交 937e5140 authored 作者: James Bergstra's avatar James Bergstra

more test cases for function with state, kw inputs.

上级 5b6707c7
......@@ -130,6 +130,7 @@ class T_fast_compute(unittest.TestCase):
import tensor as T
import random
import numpy as N
class T_OpFromGraph(unittest.TestCase):
......@@ -169,21 +170,52 @@ class T_OpFromGraph(unittest.TestCase):
class T_state(unittest.TestCase):
def test_accumulator():
def test_accumulator(self):
"""Test low-level interface with state."""
x = T.scalar('x')
s = T.scalar('s')
fn, states = theano.function_states(inputs = [x], outputs = [], states = [(s, 0, s+x)])
fn, states = program_states(inputs = [x], outputs = [], states = [(s, 0, s+x)])
sum = 0
for inc in [1, 4, 5,23, -324]:
sum += inc
fn(inc)
fn.run([inc], states)
assert sum == states[0].value
def test_perceptron():
def test_misc0(self):
fn_inc, states_inc = function_states(\
inputs = [x], outputs = [], states = [(s, 0, s+x)])
fn_inc2, states_inc2 = function_states(\
inputs = [x], outputs = [], states = [(s, 0, s+x)])
fn_inc_copy = copy.copy(fn_inc) #USE fn copy
# run() is like __call__, but requires an explicit state argument
fn_inc.run([5], states_inc) #run on own state object
fn_inc2.run([3], states_inc) #run on compatible state object
assert states_inc[0].value == 8
states_inc_copy = copy.copy(states_inc) #USE state copy
fn_inc_copy.run([2], states_inc_copy)
assert states_inc[0].value == 10 #compatible
fn_dec, states_dec = function_states(\
inputs = [x], outputs = [], states = [(s, states_inc[0], s-x)])
try:
fn_inc.run([5], states_dec) # wrong kind of state for given program
self.fail("fn accepted an invalid state argument")
except SpecificException:
raise NotImplementedError() #TODO
except Exception:
self.fail("fn accepted an invalid state argument")
def test_perceptron(self):
"""Test high-level state interface."""
mu0 = numpy.array([1.0,0.0])
......@@ -192,11 +224,11 @@ class T_state(unittest.TestCase):
si1 = numpy.ones_like(mu1) #unit variance
#implicit internal state
label = T.random.bernoulli(0.5)
label = random.bernoulli(0.5)
#implicit internal state for each DiagGaussian
x = label * T.random.DiagGaussian(mu0, si0) \
+ (1 - label) * T.random.DiagGaussian(mu1,si1)
x = label * random.DiagGaussian(mu0, si0) \
+ (1 - label) * random.DiagGaussian(mu1,si1)
w = T.tensor.dvector()
b = T.tensor.dscalar()
......@@ -236,8 +268,106 @@ class T_state(unittest.TestCase):
print d0
print 'errs =', errs
def test_shared(self):
"""Test shared r/w state."""
if __name__ == '__main__':
x = T.scalar('x')
s = T.scalar('s')
unittest.main()
fn_inc, states_inc = function_states(\
inputs = [x], outputs = [], states = [(s, 0, s+x)])
fn_dec, states_dec = function_states(\
inputs = [x], outputs = [], states = [(s, states_inc[0], s-x)])
sum = 0
for inc in [1, 4, 5,23, -324]:
sum += inc
fn_inc.run([inc], states_inc)
assert sum == states_inc[0].value
a = sum
for inc in [1, 4, 5,23, -324]:
sum -= inc
fn_dec(inc)
assert sum == 0
assert states_inc[0].value == sum
for inc in [1, 4, 5,23, -324]:
sum -= inc
fn_dec(inc)
assert sum == -a
assert states_inc[0].value == sum
class T_dict_interface(unittest.TestCase):
def test_keyword(self):
x = T.scalar('x')
y = T.scalar('y')
s = T.scalar('s')
fn = function(input_kw = {'a':x, 'b':y}, outputs = [], state = {'s':(s, 0, s+x/y)})
try:
fn(1, 1)
self.fail("non-keyword call accepted!")
except SpecificException:
raise NotImplementedError()
except Exception:
self.fail("non-keyword call accepted!")
try:
fn(a=1)
self.fail("incomplete call accepted!")
except SpecificException:
raise NotImplementedError()
except Exception:
self.fail("incomplete call accepted!")
try:
fn(a=1, b=1, c=1)
self.fail("overcomplete call accepted!")
except SpecificException:
raise NotImplementedError()
except Exception:
self.fail("overcomplete call accepted!")
def test_aliased_state(self):
"""Test keyword input and copy."""
x = T.scalar('x')
y = T.scalar('y')
s = T.scalar('s')
fn = function(input_kw = {'a':x, 'b':y}, outputs = [], state = {'s':(s, 0, s+x/y)})
fn2 = fn.copy()
fn3 = fn.copy()
fn(a=2, b=5)
fn2(a=5, b=2)
fn3(b=2, a=5)
assert fn.state['s'] == 2.0/5
assert fn2.state['s'] == 5.0/2
assert fn3.state['s'] == 5.0/2
#fn and fn3 use the same sort of state, so this is OK.
fn3.state = fn.state
fn.state['s'] = 0
fn(a=1, b=1) #increment the shared state
assert fn3.state['s'] == 1
fn3(a=-1, b=1) #decrement the shared state
assert fn.state['s'] == 0
if __name__ == '__main__':
if 1:
unittest.main()
else:
testcases = [T_dict_interface, T_state]
#<testsuite boilerplate>
testloader = unittest.TestLoader()
suite = unittest.TestSuite()
for testcase in testcases:
suite.addTest(testloader.loadTestsFromTestCase(testcase))
unittest.TextTestRunner(verbosity=2).run(suite)
#</boilerplate>
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论