提交 9dca9277 authored 作者: gdesjardins's avatar gdesjardins

* Changed compute_test_value options to ('off','ignore','warn','err')

* use warnings module to ... raise warnings!
上级 817d0cbd
......@@ -237,5 +237,5 @@ AddConfigVar('warn.sum_div_dimshuffle_bug',
AddConfigVar('compute_test_value',
"If 'True', Theano will run each op at graph build time, using Constants, SharedVariables and the tag 'test_value' as inputs to the function. This helps the user track down problems in the graph before it gets optimized.",
EnumStr('False', 'True', 'warn', 'err'),
EnumStr('off', 'ignore', 'warn', 'raise'),
in_c_key=False)
......@@ -10,6 +10,7 @@ from theano import config
import graph
import numpy
import utils
import warnings
import logging
from theano import config
from env import Env
......@@ -327,7 +328,7 @@ class PureOp(object):
node = self.make_node(*inputs, **kwargs)
self.add_tag_trace(node)
if config.compute_test_value != 'False':
if config.compute_test_value != 'off':
# avoid circular import
from theano.compile.sharedvalue import SharedVariable
run_perform = True
......@@ -345,14 +346,15 @@ class PureOp(object):
else:
# no test-value was specified, act accordingly
if config.compute_test_value == 'warn':
# TODO: use warnings.warn, http://docs.python.org/library/warnings.html#warnings.warn
print >>sys.stderr, ('Warning, Cannot compute test value: input %i (%s) of Op %s missing default value' % (i, ins, node))
warnings.warn('Warning, Cannot compute test value: input %i (%s) of Op %s missing default value' % (i, ins, node), stacklevel=2)
run_perform = False
elif config.compute_test_value == 'err':
elif config.compute_test_value == 'raise':
raise ValueError('Cannot compute test value: input %i (%s) of Op %s missing default value' % (i, ins, node))
else:
elif config.compute_test_value == 'ignore':
# silently skip test
run_perform = False
else:
raise ValueError('%s is invalid for option config.compute_Test_value' % config.compute_test_value)
# if all inputs have test-values, run the actual op
if run_perform:
......@@ -371,10 +373,8 @@ class PureOp(object):
# for a certain Op.
#TODO: use the c_thunk?
if config.compute_test_value == 'warn':
# TODO: use warnings.warn
print >>sys.stderr, 'Warning, in compute_test_value:', type(e)
print >>sys.stderr, e
elif config.compute_test_value == 'err':
warnings.warn('Warning, in compute_test_value:' + type(e), stacklevel=2)
elif config.compute_test_value == 'raise':
raise
if self.default_output is not None:
......
......@@ -2,9 +2,11 @@ import numpy
import unittest
import theano
import warnings
from theano import config
from theano import tensor as T
from theano.tensor.basic import _allclose
from theano.scan_module import scan
class TestComputeTestValue(unittest.TestCase):
......@@ -12,7 +14,7 @@ class TestComputeTestValue(unittest.TestCase):
def test_variable_only(self):
orig_compute_test_value = theano.config.compute_test_value
try:
theano.config.compute_test_value = 'True'
theano.config.compute_test_value = 'raise'
x = T.matrix('x')
x.tag.test_value = numpy.random.rand(3,4).astype(config.floatX)
......@@ -41,22 +43,25 @@ class TestComputeTestValue(unittest.TestCase):
y.tag.test_value = numpy.random.rand(4,5).astype(config.floatX)
# should skip computation of test value
theano.config.compute_test_value = 'False'
theano.config.compute_test_value = 'off'
z = T.dot(x,y)
assert not hasattr(z.tag, 'test_value')
# should fail one or another when flag is set
theano.config.compute_test_value = 'warn'
self.assertRaises(Warning, T.dot, x, y)
theano.config.compute_test_value = 'err'
# should fail when asked by user
theano.config.compute_test_value = 'raise'
self.assertRaises(ValueError, T.dot, x, y)
# test that a warning is raised if required
theano.config.compute_test_value = 'warn'
warnings.simplefilter('error', UserWarning)
self.assertRaises(UserWarning, T.dot, x, y)
finally:
theano.config.compute_test_value = orig_compute_test_value
def test_string_var(self):
orig_compute_test_value = theano.config.compute_test_value
try:
theano.config.compute_test_value = 'True'
theano.config.compute_test_value = 'raise'
x = T.matrix('x')
x.tag.test_value = numpy.random.rand(3,4).astype(config.floatX)
......@@ -85,7 +90,7 @@ class TestComputeTestValue(unittest.TestCase):
def test_shared(self):
orig_compute_test_value = theano.config.compute_test_value
try:
theano.config.compute_test_value = 'True'
theano.config.compute_test_value = 'raise'
x = T.matrix('x')
x.tag.test_value = numpy.random.rand(3,4).astype(config.floatX)
......@@ -106,7 +111,7 @@ class TestComputeTestValue(unittest.TestCase):
def test_ndarray(self):
orig_compute_test_value = theano.config.compute_test_value
try:
theano.config.compute_test_value = 'True'
theano.config.compute_test_value = 'raise'
x = numpy.random.rand(2,3).astype(config.floatX)
y = theano.shared(numpy.random.rand(3,6).astype(config.floatX), 'y')
......@@ -126,7 +131,7 @@ class TestComputeTestValue(unittest.TestCase):
def test_constant(self):
orig_compute_test_value = theano.config.compute_test_value
try:
theano.config.compute_test_value = 'True'
theano.config.compute_test_value = 'raise'
x = T.constant(numpy.random.rand(2,3), dtype=config.floatX)
y = theano.shared(numpy.random.rand(3,6).astype(config.floatX), 'y')
......@@ -146,7 +151,7 @@ class TestComputeTestValue(unittest.TestCase):
def test_incorrect_type(self):
orig_compute_test_value = theano.config.compute_test_value
try:
theano.config.compute_test_value = 'True'
theano.config.compute_test_value = 'raise'
x = T.fmatrix('x')
# Incorrect dtype (float64) for test_value
......@@ -157,3 +162,32 @@ class TestComputeTestValue(unittest.TestCase):
self.assertRaises(TypeError, T.dot, x, y)
finally:
theano.config.compute_test_value = orig_compute_test_value
def notest_scan(self):
"""
Do not run this test as the compute_test_value mechanism is known not to work with Scan.
TODO: fix scan to work with compute_test_value
"""
orig_compute_test_value = theano.config.compute_test_value
try:
theano.config.compute_test_value = 'raise'
k = T.iscalar("k")
A = T.vector("A")
k.tag.test_value = 3
A.tag.test_value = numpy.random.rand(5)
def fx(prior_result, A):
return prior_results * A
# Symbolic description of the result
result, updates = theano.scan(fn=lambda prior_result, A: prior_result * A,
outputs_info=T.ones_like(A),
non_sequences=A,
n_steps=k)
# We only care about A**k, but scan has provided us with A**1 through A**k.
# Discard the values that we don't care about. Scan is smart enough to
# notice this and not waste memory saving them.
final_result = result[-1]
assert hasattr(final_result.tag, 'test_value')
finally:
theano.config.compute_test_value = orig_compute_test_value
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论