Commit 1bd371ce, authored by Frédéric Bastien

Merge pull request #4378 from shabanian/bn

Disable test values in BNComposite and get_substream_rstates
......@@ -13,7 +13,6 @@ import numpy
import os
import re
import sys
import traceback
import warnings
import theano
......@@ -21,7 +20,6 @@ from theano import config
import theano.gof.cc
from six import itervalues
from six.moves import StringIO
from theano.gof import graph
from theano.gof import utils
from theano.gof.cmodule import GCC_compiler
......@@ -554,16 +552,7 @@ class PureOp(object):
detailed_err_msg = (
"For compute_test_value, one input test value does not"
" have the requested type.\n")
tr = getattr(v.tag, 'trace', [])
if isinstance(tr, list) and len(tr) > 0:
detailed_err_msg += (
" \nBacktrace when that variable is created:\n")
# Print separate message for each element in the list
# of batcktraces
sio = StringIO()
for subtr in tr:
traceback.print_list(subtr, sio)
detailed_err_msg += str(sio.getvalue())
detailed_err_msg += utils.get_variable_trace_string(v)
detailed_err_msg += (
"\nThe error when converting the test value to that"
......@@ -575,8 +564,8 @@ class PureOp(object):
e.args = ("\n".join(args),)
raise
return ret
raise AttributeError('%s has no test value' % v)
detailed_err_msg = utils.get_variable_trace_string(v)
raise AttributeError('%s has no test value %s' % (v, detailed_err_msg))
def __call__(self, *inputs, **kwargs):
"""
......@@ -630,9 +619,11 @@ class PureOp(object):
(i, ins, node), stacklevel=2)
run_perform = False
elif config.compute_test_value == 'raise':
detailed_err_msg = utils.get_variable_trace_string(ins)
raise ValueError(
'Cannot compute test value: input %i (%s) of Op %s missing default value' %
(i, ins, node))
'Cannot compute test value: input %i (%s) of Op %s missing default value. %s' %
(i, ins, node, detailed_err_msg))
elif config.compute_test_value == 'ignore':
# silently skip test
run_perform = False
......
from __future__ import absolute_import, print_function, division
import linecache
import sys
import traceback
import numpy
from six import iteritems, integer_types, string_types
from six.moves import StringIO
from theano import config
from theano.compat import OrderedDict, PY3
......@@ -99,6 +101,7 @@ def add_tag_trace(thing, user_line=None):
"theano/scan_module/", "theano\\scan_module\\",
"theano/sparse/", "theano\\sparse\\",
"theano/typed_list/", "theano\\typed_list\\"]
tr = simple_extract_stack(limit=user_line, skips=skips)
# Different python version use different sementic for
# limit. python 2.7 include the call to extrack_stack. The -1 get
......@@ -111,6 +114,23 @@ def add_tag_trace(thing, user_line=None):
return thing
def get_variable_trace_string(v):
    """Return a formatted backtrace string showing where variable `v` was created.

    Handles both trace layouts found on ``v.tag.trace``: the old pickled
    format (a flat list of frame tuples) and the newer format (a list of
    backtraces, each itself a list of frame tuples).  Returns an empty
    string when no trace information is attached.
    """
    trace = getattr(v.tag, 'trace', [])
    if not (isinstance(trace, list) and trace):
        # No usable trace recorded on this variable.
        return ""
    buf = StringIO()
    print(" \nBacktrace when that variable is created:\n", file=buf)
    # The isinstance is needed to handle old pickled traces, which store
    # a single flat list of frame tuples rather than a list of backtraces.
    if isinstance(trace[0], tuple):
        traceback.print_list(trace, buf)
    else:
        # Newer format: format each recorded backtrace separately.
        for frames in trace:
            traceback.print_list(frames, buf)
    return buf.getvalue()
def hashtype(self):
    """Return a hash derived purely from the receiver's class identity
    (class name XOR module name); instance state is ignored."""
    cls = type(self)
    return hash(cls.__module__) ^ hash(cls.__name__)
......
......@@ -3,17 +3,15 @@ from __future__ import absolute_import, print_function, division
import six.moves.builtins as builtins
import logging
import time
import traceback
import warnings
import numpy # for numeric_grad
from six import itervalues
from six.moves import StringIO
import theano
from theano import gof
from theano.gof import Variable
from theano.gof import utils, Variable
from theano.compat import OrderedDict, izip
from six.moves import xrange, reduce
from theano.gof.null_type import NullType, null_type
......@@ -518,17 +516,7 @@ def grad(cost, wrt, consider_constant=None,
elif disconnected_inputs == 'warn':
warnings.warn(message, stacklevel=2)
elif disconnected_inputs == 'raise':
# Add the var trace
tr = getattr(var.tag, 'trace', [])
if len(tr) > 0:
message += "\nBacktrace when the node is created:\n"
# Print separate message for each element in the list of batcktraces
sio = StringIO()
for subtr in tr:
traceback.print_list(subtr, sio)
message += str(sio.getvalue())
message = utils.get_variable_trace_string(var)
raise DisconnectedInputError(message)
else:
raise ValueError("Invalid value for keyword "
......
......@@ -1205,6 +1205,7 @@ class MRG_RandomStreams(object):
self.rstate = multMatVect(self.rstate, A1p134, M1, A2p134, M2)
assert self.rstate.dtype == numpy.int32
@theano.configparser.change_flags(compute_test_value='off')
def get_substream_rstates(self, n_streams, dtype, inc_rstate=True):
# TODO : need description for parameter and return
"""
......
......@@ -109,6 +109,21 @@ def test_consistency_randomstreams():
assert(numpy.allclose(samples, java_samples))
def test_get_substream_rstates():
    """Regression test: get_substream_rstates must succeed even when
    compute_test_value is 'raise' (the method disables test values
    internally via a change_flags decorator)."""
    # Read the original flag BEFORE entering the try block: if this read
    # were inside the try and raised, the finally clause would hit a
    # NameError on `orig`, masking the real failure.
    orig = theano.config.compute_test_value
    try:
        theano.config.compute_test_value = 'raise'
        n_streams = 100
        dtype = 'float32'
        rng = MRG_RandomStreams(numpy.random.randint(2147462579))
        rng.get_substream_rstates(n_streams, dtype)
    finally:
        # Always restore the global flag so other tests are unaffected.
        theano.config.compute_test_value = orig
def test_consistency_cpu_serial():
"""
Verify that the random numbers generated by mrg_uniform, serially,
......
......@@ -7,6 +7,7 @@ from theano.scalar import add, sub, true_div, mul
class BNComposite(Composite):
init_param = ('dtype',)
@theano.configparser.change_flags(compute_test_value='off')
def __init__(self, dtype):
self.dtype = dtype
x = theano.scalar.Scalar(dtype=dtype).make_variable()
......
......@@ -6,6 +6,47 @@ import numpy
from theano.tensor.nnet.bn import batch_normalization
def test_BNComposite():
    """Regression test: batch_normalization must build and run correctly
    with compute_test_value='raise' (BNComposite disables test values
    internally via a change_flags decorator)."""
    def bn_ref(x, G, B, M, V):
        # Pure-numpy reference implementation of batch normalization.
        n = (x - M) / V
        return n * G + B

    # Read the original flag BEFORE entering the try block: if this read
    # were inside the try and raised, the finally clause would hit a
    # NameError on `orig`, masking the real failure.
    orig = theano.config.compute_test_value
    try:
        theano.config.compute_test_value = 'raise'
        numpy.random.seed(1234)
        X = 1 + numpy.random.random([10, 20]).astype('float32')
        B = 1 + numpy.random.random([20]).astype('float32')
        G = 1 + numpy.random.random([20]).astype('float32')
        M = 1 + numpy.random.random([20]).astype('float32')
        V = 1 + numpy.random.random([20]).astype('float32')
        x = theano.tensor.matrix('x')
        b = theano.tensor.vector('b')
        g = theano.tensor.vector('g')
        m = theano.tensor.vector('m')
        v = theano.tensor.vector('v')
        # Test values are deliberately shaped (2, 2)/(2,) — smaller than the
        # real inputs — since only graph construction consumes them.
        x.tag.test_value = numpy.random.rand(2, 2).astype(theano.config.floatX)
        b.tag.test_value = numpy.random.rand(2).astype(theano.config.floatX)
        g.tag.test_value = numpy.random.rand(2).astype(theano.config.floatX)
        m.tag.test_value = numpy.random.rand(2).astype(theano.config.floatX)
        v.tag.test_value = numpy.random.rand(2).astype(theano.config.floatX)
        bn_ref_op = bn_ref(x, g, b, m, v)
        f_ref = theano.function([x, b, g, m, v], [bn_ref_op])
        res_ref = f_ref(X, G, B, M, V)
        # Both memory modes must agree with the reference implementation.
        for mode in ['low_mem', 'high_mem']:
            bn_op = batch_normalization(x, g, b, m, v, mode=mode)
            f = theano.function([x, b, g, m, v], [bn_op])
            res = f(X, G, B, M, V)
            utt.assert_allclose(res_ref, res)
    finally:
        # Always restore the global flag so other tests are unaffected.
        theano.config.compute_test_value = orig
def test_bn():
def bn_ref(x, G, B, M, V):
......
Markdown is supported
0%
You are about to add 0 people to the discussion. Please proceed with caution.
Finish editing this message first!
Please register or sign in to comment