提交 9dca9277 authored 作者: gdesjardins's avatar gdesjardins

* Changed compute_test_value options to ('off','ignore','warn','err')

* use warnings module to ... raise warnings!
上级 817d0cbd
...@@ -237,5 +237,5 @@ AddConfigVar('warn.sum_div_dimshuffle_bug', ...@@ -237,5 +237,5 @@ AddConfigVar('warn.sum_div_dimshuffle_bug',
AddConfigVar('compute_test_value', AddConfigVar('compute_test_value',
"If 'True', Theano will run each op at graph build time, using Constants, SharedVariables and the tag 'test_value' as inputs to the function. This helps the user track down problems in the graph before it gets optimized.", "If 'True', Theano will run each op at graph build time, using Constants, SharedVariables and the tag 'test_value' as inputs to the function. This helps the user track down problems in the graph before it gets optimized.",
EnumStr('False', 'True', 'warn', 'err'), EnumStr('off', 'ignore', 'warn', 'raise'),
in_c_key=False) in_c_key=False)
...@@ -10,6 +10,7 @@ from theano import config ...@@ -10,6 +10,7 @@ from theano import config
import graph import graph
import numpy import numpy
import utils import utils
import warnings
import logging import logging
from theano import config from theano import config
from env import Env from env import Env
...@@ -327,7 +328,7 @@ class PureOp(object): ...@@ -327,7 +328,7 @@ class PureOp(object):
node = self.make_node(*inputs, **kwargs) node = self.make_node(*inputs, **kwargs)
self.add_tag_trace(node) self.add_tag_trace(node)
if config.compute_test_value != 'False': if config.compute_test_value != 'off':
# avoid circular import # avoid circular import
from theano.compile.sharedvalue import SharedVariable from theano.compile.sharedvalue import SharedVariable
run_perform = True run_perform = True
...@@ -345,14 +346,15 @@ class PureOp(object): ...@@ -345,14 +346,15 @@ class PureOp(object):
else: else:
# no test-value was specified, act accordingly # no test-value was specified, act accordingly
if config.compute_test_value == 'warn': if config.compute_test_value == 'warn':
# TODO: use warnings.warn, http://docs.python.org/library/warnings.html#warnings.warn warnings.warn('Warning, Cannot compute test value: input %i (%s) of Op %s missing default value' % (i, ins, node), stacklevel=2)
print >>sys.stderr, ('Warning, Cannot compute test value: input %i (%s) of Op %s missing default value' % (i, ins, node))
run_perform = False run_perform = False
elif config.compute_test_value == 'err': elif config.compute_test_value == 'raise':
raise ValueError('Cannot compute test value: input %i (%s) of Op %s missing default value' % (i, ins, node)) raise ValueError('Cannot compute test value: input %i (%s) of Op %s missing default value' % (i, ins, node))
else: elif config.compute_test_value == 'ignore':
# silently skip test # silently skip test
run_perform = False run_perform = False
else:
raise ValueError('%s is invalid for option config.compute_Test_value' % config.compute_test_value)
# if all inputs have test-values, run the actual op # if all inputs have test-values, run the actual op
if run_perform: if run_perform:
...@@ -371,10 +373,8 @@ class PureOp(object): ...@@ -371,10 +373,8 @@ class PureOp(object):
# for a certain Op. # for a certain Op.
#TODO: use the c_thunk? #TODO: use the c_thunk?
if config.compute_test_value == 'warn': if config.compute_test_value == 'warn':
# TODO: use warnings.warn warnings.warn('Warning, in compute_test_value:' + type(e), stacklevel=2)
print >>sys.stderr, 'Warning, in compute_test_value:', type(e) elif config.compute_test_value == 'raise':
print >>sys.stderr, e
elif config.compute_test_value == 'err':
raise raise
if self.default_output is not None: if self.default_output is not None:
......
...@@ -2,9 +2,11 @@ import numpy ...@@ -2,9 +2,11 @@ import numpy
import unittest import unittest
import theano import theano
import warnings
from theano import config from theano import config
from theano import tensor as T from theano import tensor as T
from theano.tensor.basic import _allclose from theano.tensor.basic import _allclose
from theano.scan_module import scan
class TestComputeTestValue(unittest.TestCase): class TestComputeTestValue(unittest.TestCase):
...@@ -12,7 +14,7 @@ class TestComputeTestValue(unittest.TestCase): ...@@ -12,7 +14,7 @@ class TestComputeTestValue(unittest.TestCase):
def test_variable_only(self): def test_variable_only(self):
orig_compute_test_value = theano.config.compute_test_value orig_compute_test_value = theano.config.compute_test_value
try: try:
theano.config.compute_test_value = 'True' theano.config.compute_test_value = 'raise'
x = T.matrix('x') x = T.matrix('x')
x.tag.test_value = numpy.random.rand(3,4).astype(config.floatX) x.tag.test_value = numpy.random.rand(3,4).astype(config.floatX)
...@@ -41,22 +43,25 @@ class TestComputeTestValue(unittest.TestCase): ...@@ -41,22 +43,25 @@ class TestComputeTestValue(unittest.TestCase):
y.tag.test_value = numpy.random.rand(4,5).astype(config.floatX) y.tag.test_value = numpy.random.rand(4,5).astype(config.floatX)
# should skip computation of test value # should skip computation of test value
theano.config.compute_test_value = 'False' theano.config.compute_test_value = 'off'
z = T.dot(x,y) z = T.dot(x,y)
assert not hasattr(z.tag, 'test_value') assert not hasattr(z.tag, 'test_value')
# should fail one or another when flag is set # should fail when asked by user
theano.config.compute_test_value = 'warn' theano.config.compute_test_value = 'raise'
self.assertRaises(Warning, T.dot, x, y)
theano.config.compute_test_value = 'err'
self.assertRaises(ValueError, T.dot, x, y) self.assertRaises(ValueError, T.dot, x, y)
# test that a warning is raised if required
theano.config.compute_test_value = 'warn'
warnings.simplefilter('error', UserWarning)
self.assertRaises(UserWarning, T.dot, x, y)
finally: finally:
theano.config.compute_test_value = orig_compute_test_value theano.config.compute_test_value = orig_compute_test_value
def test_string_var(self): def test_string_var(self):
orig_compute_test_value = theano.config.compute_test_value orig_compute_test_value = theano.config.compute_test_value
try: try:
theano.config.compute_test_value = 'True' theano.config.compute_test_value = 'raise'
x = T.matrix('x') x = T.matrix('x')
x.tag.test_value = numpy.random.rand(3,4).astype(config.floatX) x.tag.test_value = numpy.random.rand(3,4).astype(config.floatX)
...@@ -85,7 +90,7 @@ class TestComputeTestValue(unittest.TestCase): ...@@ -85,7 +90,7 @@ class TestComputeTestValue(unittest.TestCase):
def test_shared(self): def test_shared(self):
orig_compute_test_value = theano.config.compute_test_value orig_compute_test_value = theano.config.compute_test_value
try: try:
theano.config.compute_test_value = 'True' theano.config.compute_test_value = 'raise'
x = T.matrix('x') x = T.matrix('x')
x.tag.test_value = numpy.random.rand(3,4).astype(config.floatX) x.tag.test_value = numpy.random.rand(3,4).astype(config.floatX)
...@@ -106,7 +111,7 @@ class TestComputeTestValue(unittest.TestCase): ...@@ -106,7 +111,7 @@ class TestComputeTestValue(unittest.TestCase):
def test_ndarray(self): def test_ndarray(self):
orig_compute_test_value = theano.config.compute_test_value orig_compute_test_value = theano.config.compute_test_value
try: try:
theano.config.compute_test_value = 'True' theano.config.compute_test_value = 'raise'
x = numpy.random.rand(2,3).astype(config.floatX) x = numpy.random.rand(2,3).astype(config.floatX)
y = theano.shared(numpy.random.rand(3,6).astype(config.floatX), 'y') y = theano.shared(numpy.random.rand(3,6).astype(config.floatX), 'y')
...@@ -126,7 +131,7 @@ class TestComputeTestValue(unittest.TestCase): ...@@ -126,7 +131,7 @@ class TestComputeTestValue(unittest.TestCase):
def test_constant(self): def test_constant(self):
orig_compute_test_value = theano.config.compute_test_value orig_compute_test_value = theano.config.compute_test_value
try: try:
theano.config.compute_test_value = 'True' theano.config.compute_test_value = 'raise'
x = T.constant(numpy.random.rand(2,3), dtype=config.floatX) x = T.constant(numpy.random.rand(2,3), dtype=config.floatX)
y = theano.shared(numpy.random.rand(3,6).astype(config.floatX), 'y') y = theano.shared(numpy.random.rand(3,6).astype(config.floatX), 'y')
...@@ -146,7 +151,7 @@ class TestComputeTestValue(unittest.TestCase): ...@@ -146,7 +151,7 @@ class TestComputeTestValue(unittest.TestCase):
def test_incorrect_type(self): def test_incorrect_type(self):
orig_compute_test_value = theano.config.compute_test_value orig_compute_test_value = theano.config.compute_test_value
try: try:
theano.config.compute_test_value = 'True' theano.config.compute_test_value = 'raise'
x = T.fmatrix('x') x = T.fmatrix('x')
# Incorrect dtype (float64) for test_value # Incorrect dtype (float64) for test_value
...@@ -157,3 +162,32 @@ class TestComputeTestValue(unittest.TestCase): ...@@ -157,3 +162,32 @@ class TestComputeTestValue(unittest.TestCase):
self.assertRaises(TypeError, T.dot, x, y) self.assertRaises(TypeError, T.dot, x, y)
finally: finally:
theano.config.compute_test_value = orig_compute_test_value theano.config.compute_test_value = orig_compute_test_value
def notest_scan(self):
"""
Do not run this test as the compute_test_value mechanism is known not to work with Scan.
TODO: fix scan to work with compute_test_value
"""
orig_compute_test_value = theano.config.compute_test_value
try:
theano.config.compute_test_value = 'raise'
k = T.iscalar("k")
A = T.vector("A")
k.tag.test_value = 3
A.tag.test_value = numpy.random.rand(5)
def fx(prior_result, A):
return prior_results * A
# Symbolic description of the result
result, updates = theano.scan(fn=lambda prior_result, A: prior_result * A,
outputs_info=T.ones_like(A),
non_sequences=A,
n_steps=k)
# We only care about A**k, but scan has provided us with A**1 through A**k.
# Discard the values that we don't care about. Scan is smart enough to
# notice this and not waste memory saving them.
final_result = result[-1]
assert hasattr(final_result.tag, 'test_value')
finally:
theano.config.compute_test_value = orig_compute_test_value
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论