提交 037fd298 authored 作者: Samira Shabanian's avatar Samira Shabanian

Disable test values in BNComposite and get_substream_rstates

上级 a536464a
......@@ -576,7 +576,11 @@ class PureOp(object):
raise
return ret
raise AttributeError('%s has no test value' % v)
sio = StringIO()
for subtr in getattr(v.tag, "trace", []):
traceback.print_list(subtr, sio)
detailed_err_msg = sio.getvalue()
raise AttributeError('%s has no test value %s' % (v, detailed_err_msg))
def __call__(self, *inputs, **kwargs):
"""
......@@ -630,9 +634,14 @@ class PureOp(object):
(i, ins, node), stacklevel=2)
run_perform = False
elif config.compute_test_value == 'raise':
sio = StringIO()
for subtr in getattr(ins.tag, "trace", []):
traceback.print_list(subtr, sio)
detailed_err_msg = sio.getvalue()
raise ValueError(
'Cannot compute test value: input %i (%s) of Op %s missing default value' %
(i, ins, node))
'Cannot compute test value: input %i (%s) of Op %s missing default value. %s' %
(i, ins, node, detailed_err_msg))
elif config.compute_test_value == 'ignore':
# silently skip test
run_perform = False
......
......@@ -99,6 +99,7 @@ def add_tag_trace(thing, user_line=None):
"theano/scan_module/", "theano\\scan_module\\",
"theano/sparse/", "theano\\sparse\\",
"theano/typed_list/", "theano\\typed_list\\"]
skips = []
tr = simple_extract_stack(limit=user_line, skips=skips)
# Different python version use different sementic for
# limit. python 2.7 include the call to extrack_stack. The -1 get
......
......@@ -1205,6 +1205,7 @@ class MRG_RandomStreams(object):
self.rstate = multMatVect(self.rstate, A1p134, M1, A2p134, M2)
assert self.rstate.dtype == numpy.int32
@theano.configparser.change_flags(compute_test_value='off')
def get_substream_rstates(self, n_streams, dtype, inc_rstate=True):
# TODO : need description for parameter and return
"""
......
......@@ -109,6 +109,21 @@ def test_consistency_randomstreams():
assert(numpy.allclose(samples, java_samples))
def test_get_substream_rstates():
try:
orig = theano.config.compute_test_value
theano.config.compute_test_value = 'raise'
n_streams = 100
dtype = 'float32'
rng = MRG_RandomStreams(numpy.random.randint(2147462579))
rng.get_substream_rstates(n_streams, dtype)
finally:
theano.config.compute_test_value = orig
def test_consistency_cpu_serial():
"""
Verify that the random numbers generated by mrg_uniform, serially,
......
......@@ -7,6 +7,7 @@ from theano.scalar import add, sub, true_div, mul
class BNComposite(Composite):
init_param = ('dtype',)
@theano.configparser.change_flags(compute_test_value='off')
def __init__(self, dtype):
self.dtype = dtype
x = theano.scalar.Scalar(dtype=dtype).make_variable()
......
......@@ -6,6 +6,47 @@ import numpy
from theano.tensor.nnet.bn import batch_normalization
def test_BNComposite():
try:
orig = theano.config.compute_test_value
theano.config.compute_test_value = 'raise'
def bn_ref(x, G, B, M, V):
n = (x - M) / V
return n * G + B
numpy.random.seed(1234)
X = 1 + numpy.random.random([10, 20]).astype('float32')
B = 1 + numpy.random.random([20]).astype('float32')
G = 1 + numpy.random.random([20]).astype('float32')
M = 1 + numpy.random.random([20]).astype('float32')
V = 1 + numpy.random.random([20]).astype('float32')
x = theano.tensor.matrix('x')
b = theano.tensor.vector('b')
g = theano.tensor.vector('g')
m = theano.tensor.vector('m')
v = theano.tensor.vector('v')
x.tag.test_value = numpy.random.rand(2, 2)
b.tag.test_value = numpy.random.rand(2)
g.tag.test_value = numpy.random.rand(2)
m.tag.test_value = numpy.random.rand(2)
v.tag.test_value = numpy.random.rand(2)
bn_ref_op = bn_ref(x, g, b, m, v)
f_ref = theano.function([x, b, g, m, v], [bn_ref_op])
res_ref = f_ref(X, G, B, M, V)
for mode in ['low_mem', 'high_mem']:
bn_op = batch_normalization(x, g, b, m, v, mode=mode)
f = theano.function([x, b, g, m, v], [bn_op])
res = f(X, G, B, M, V)
utt.assert_allclose(res_ref, res)
finally:
theano.config.compute_test_value = orig
def test_bn():
def bn_ref(x, G, B, M, V):
......@@ -25,6 +66,12 @@ def test_bn():
m = theano.tensor.vector('m')
v = theano.tensor.vector('v')
x.tag.test_value = numpy.random.rand(2, 2)
b.tag.test_value = numpy.random.rand(2)
g.tag.test_value = numpy.random.rand(2)
m.tag.test_value = numpy.random.rand(2)
v.tag.test_value = numpy.random.rand(2)
bn_ref_op = bn_ref(x, g, b, m, v)
f_ref = theano.function([x, b, g, m, v], [bn_ref_op])
res_ref = f_ref(X, G, B, M, V)
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论