提交 1bd371ce authored 作者: Frédéric Bastien's avatar Frédéric Bastien

Merge pull request #4378 from shabanian/bn

Disable test values in BNComposite and get_substream_rstates
...@@ -13,7 +13,6 @@ import numpy ...@@ -13,7 +13,6 @@ import numpy
import os import os
import re import re
import sys import sys
import traceback
import warnings import warnings
import theano import theano
...@@ -21,7 +20,6 @@ from theano import config ...@@ -21,7 +20,6 @@ from theano import config
import theano.gof.cc import theano.gof.cc
from six import itervalues from six import itervalues
from six.moves import StringIO
from theano.gof import graph from theano.gof import graph
from theano.gof import utils from theano.gof import utils
from theano.gof.cmodule import GCC_compiler from theano.gof.cmodule import GCC_compiler
...@@ -554,16 +552,7 @@ class PureOp(object): ...@@ -554,16 +552,7 @@ class PureOp(object):
detailed_err_msg = ( detailed_err_msg = (
"For compute_test_value, one input test value does not" "For compute_test_value, one input test value does not"
" have the requested type.\n") " have the requested type.\n")
tr = getattr(v.tag, 'trace', []) detailed_err_msg += utils.get_variable_trace_string(v)
if isinstance(tr, list) and len(tr) > 0:
detailed_err_msg += (
" \nBacktrace when that variable is created:\n")
# Print separate message for each element in the list
# of batcktraces
sio = StringIO()
for subtr in tr:
traceback.print_list(subtr, sio)
detailed_err_msg += str(sio.getvalue())
detailed_err_msg += ( detailed_err_msg += (
"\nThe error when converting the test value to that" "\nThe error when converting the test value to that"
...@@ -575,8 +564,8 @@ class PureOp(object): ...@@ -575,8 +564,8 @@ class PureOp(object):
e.args = ("\n".join(args),) e.args = ("\n".join(args),)
raise raise
return ret return ret
detailed_err_msg = utils.get_variable_trace_string(v)
raise AttributeError('%s has no test value' % v) raise AttributeError('%s has no test value %s' % (v, detailed_err_msg))
def __call__(self, *inputs, **kwargs): def __call__(self, *inputs, **kwargs):
""" """
...@@ -630,9 +619,11 @@ class PureOp(object): ...@@ -630,9 +619,11 @@ class PureOp(object):
(i, ins, node), stacklevel=2) (i, ins, node), stacklevel=2)
run_perform = False run_perform = False
elif config.compute_test_value == 'raise': elif config.compute_test_value == 'raise':
detailed_err_msg = utils.get_variable_trace_string(ins)
raise ValueError( raise ValueError(
'Cannot compute test value: input %i (%s) of Op %s missing default value' % 'Cannot compute test value: input %i (%s) of Op %s missing default value. %s' %
(i, ins, node)) (i, ins, node, detailed_err_msg))
elif config.compute_test_value == 'ignore': elif config.compute_test_value == 'ignore':
# silently skip test # silently skip test
run_perform = False run_perform = False
......
from __future__ import absolute_import, print_function, division from __future__ import absolute_import, print_function, division
import linecache import linecache
import sys import sys
import traceback
import numpy import numpy
from six import iteritems, integer_types, string_types from six import iteritems, integer_types, string_types
from six.moves import StringIO
from theano import config from theano import config
from theano.compat import OrderedDict, PY3 from theano.compat import OrderedDict, PY3
...@@ -99,6 +101,7 @@ def add_tag_trace(thing, user_line=None): ...@@ -99,6 +101,7 @@ def add_tag_trace(thing, user_line=None):
"theano/scan_module/", "theano\\scan_module\\", "theano/scan_module/", "theano\\scan_module\\",
"theano/sparse/", "theano\\sparse\\", "theano/sparse/", "theano\\sparse\\",
"theano/typed_list/", "theano\\typed_list\\"] "theano/typed_list/", "theano\\typed_list\\"]
tr = simple_extract_stack(limit=user_line, skips=skips) tr = simple_extract_stack(limit=user_line, skips=skips)
# Different python version use different sementic for # Different python version use different sementic for
# limit. python 2.7 include the call to extrack_stack. The -1 get # limit. python 2.7 include the call to extrack_stack. The -1 get
...@@ -111,6 +114,23 @@ def add_tag_trace(thing, user_line=None): ...@@ -111,6 +114,23 @@ def add_tag_trace(thing, user_line=None):
return thing return thing
def get_variable_trace_string(v):
sio = StringIO()
# For backward compatibility with old trace
tr = getattr(v.tag, 'trace', [])
if isinstance(tr, list) and len(tr) > 0:
print(" \nBacktrace when that variable is created:\n", file=sio)
# The isinstance is needed to handle old pickled trace
if isinstance(tr[0], tuple):
traceback.print_list(v.tag.trace, sio)
else:
# Print separate message for each element in the list of
# batcktraces
for subtr in tr:
traceback.print_list(subtr, sio)
return sio.getvalue()
def hashtype(self): def hashtype(self):
t = type(self) t = type(self)
return hash(t.__name__) ^ hash(t.__module__) return hash(t.__name__) ^ hash(t.__module__)
......
...@@ -3,17 +3,15 @@ from __future__ import absolute_import, print_function, division ...@@ -3,17 +3,15 @@ from __future__ import absolute_import, print_function, division
import six.moves.builtins as builtins import six.moves.builtins as builtins
import logging import logging
import time import time
import traceback
import warnings import warnings
import numpy # for numeric_grad import numpy # for numeric_grad
from six import itervalues from six import itervalues
from six.moves import StringIO
import theano import theano
from theano import gof from theano import gof
from theano.gof import Variable from theano.gof import utils, Variable
from theano.compat import OrderedDict, izip from theano.compat import OrderedDict, izip
from six.moves import xrange, reduce from six.moves import xrange, reduce
from theano.gof.null_type import NullType, null_type from theano.gof.null_type import NullType, null_type
...@@ -518,17 +516,7 @@ def grad(cost, wrt, consider_constant=None, ...@@ -518,17 +516,7 @@ def grad(cost, wrt, consider_constant=None,
elif disconnected_inputs == 'warn': elif disconnected_inputs == 'warn':
warnings.warn(message, stacklevel=2) warnings.warn(message, stacklevel=2)
elif disconnected_inputs == 'raise': elif disconnected_inputs == 'raise':
# Add the var trace message = utils.get_variable_trace_string(var)
tr = getattr(var.tag, 'trace', [])
if len(tr) > 0:
message += "\nBacktrace when the node is created:\n"
# Print separate message for each element in the list of batcktraces
sio = StringIO()
for subtr in tr:
traceback.print_list(subtr, sio)
message += str(sio.getvalue())
raise DisconnectedInputError(message) raise DisconnectedInputError(message)
else: else:
raise ValueError("Invalid value for keyword " raise ValueError("Invalid value for keyword "
......
...@@ -1205,6 +1205,7 @@ class MRG_RandomStreams(object): ...@@ -1205,6 +1205,7 @@ class MRG_RandomStreams(object):
self.rstate = multMatVect(self.rstate, A1p134, M1, A2p134, M2) self.rstate = multMatVect(self.rstate, A1p134, M1, A2p134, M2)
assert self.rstate.dtype == numpy.int32 assert self.rstate.dtype == numpy.int32
@theano.configparser.change_flags(compute_test_value='off')
def get_substream_rstates(self, n_streams, dtype, inc_rstate=True): def get_substream_rstates(self, n_streams, dtype, inc_rstate=True):
# TODO : need description for parameter and return # TODO : need description for parameter and return
""" """
......
...@@ -109,6 +109,21 @@ def test_consistency_randomstreams(): ...@@ -109,6 +109,21 @@ def test_consistency_randomstreams():
assert(numpy.allclose(samples, java_samples)) assert(numpy.allclose(samples, java_samples))
def test_get_substream_rstates():
try:
orig = theano.config.compute_test_value
theano.config.compute_test_value = 'raise'
n_streams = 100
dtype = 'float32'
rng = MRG_RandomStreams(numpy.random.randint(2147462579))
rng.get_substream_rstates(n_streams, dtype)
finally:
theano.config.compute_test_value = orig
def test_consistency_cpu_serial(): def test_consistency_cpu_serial():
""" """
Verify that the random numbers generated by mrg_uniform, serially, Verify that the random numbers generated by mrg_uniform, serially,
......
...@@ -7,6 +7,7 @@ from theano.scalar import add, sub, true_div, mul ...@@ -7,6 +7,7 @@ from theano.scalar import add, sub, true_div, mul
class BNComposite(Composite): class BNComposite(Composite):
init_param = ('dtype',) init_param = ('dtype',)
@theano.configparser.change_flags(compute_test_value='off')
def __init__(self, dtype): def __init__(self, dtype):
self.dtype = dtype self.dtype = dtype
x = theano.scalar.Scalar(dtype=dtype).make_variable() x = theano.scalar.Scalar(dtype=dtype).make_variable()
......
...@@ -6,6 +6,47 @@ import numpy ...@@ -6,6 +6,47 @@ import numpy
from theano.tensor.nnet.bn import batch_normalization from theano.tensor.nnet.bn import batch_normalization
def test_BNComposite():
try:
orig = theano.config.compute_test_value
theano.config.compute_test_value = 'raise'
def bn_ref(x, G, B, M, V):
n = (x - M) / V
return n * G + B
numpy.random.seed(1234)
X = 1 + numpy.random.random([10, 20]).astype('float32')
B = 1 + numpy.random.random([20]).astype('float32')
G = 1 + numpy.random.random([20]).astype('float32')
M = 1 + numpy.random.random([20]).astype('float32')
V = 1 + numpy.random.random([20]).astype('float32')
x = theano.tensor.matrix('x')
b = theano.tensor.vector('b')
g = theano.tensor.vector('g')
m = theano.tensor.vector('m')
v = theano.tensor.vector('v')
x.tag.test_value = numpy.random.rand(2, 2).astype(theano.config.floatX)
b.tag.test_value = numpy.random.rand(2).astype(theano.config.floatX)
g.tag.test_value = numpy.random.rand(2).astype(theano.config.floatX)
m.tag.test_value = numpy.random.rand(2).astype(theano.config.floatX)
v.tag.test_value = numpy.random.rand(2).astype(theano.config.floatX)
bn_ref_op = bn_ref(x, g, b, m, v)
f_ref = theano.function([x, b, g, m, v], [bn_ref_op])
res_ref = f_ref(X, G, B, M, V)
for mode in ['low_mem', 'high_mem']:
bn_op = batch_normalization(x, g, b, m, v, mode=mode)
f = theano.function([x, b, g, m, v], [bn_op])
res = f(X, G, B, M, V)
utt.assert_allclose(res_ref, res)
finally:
theano.config.compute_test_value = orig
def test_bn(): def test_bn():
def bn_ref(x, G, B, M, V): def bn_ref(x, G, B, M, V):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论