提交 3e1516a9 authored 作者: Pascal Lamblin's avatar Pascal Lamblin

Change the values for config.compute_test_value to be all strings

上级 fde8856f
...@@ -202,7 +202,5 @@ AddConfigVar('warn.sum_div_dimshuffle_bug', ...@@ -202,7 +202,5 @@ AddConfigVar('warn.sum_div_dimshuffle_bug',
BoolParam(default_0_3)) BoolParam(default_0_3))
AddConfigVar('compute_test_value', AddConfigVar('compute_test_value',
"If True, Theano will run each op at graph build time, using Constants, SharedVariables and the tag 'test_value' as inputs to the function. This helps the user track down problems in the graph before it gets optimized.", "If 'True', Theano will run each op at graph build time, using Constants, SharedVariables and the tag 'test_value' as inputs to the function. This helps the user track down problems in the graph before it gets optimized.",
EnumStr(False, True, 'warn', 'err')) EnumStr('False', 'True', 'warn', 'err'))
...@@ -324,7 +324,7 @@ class PureOp(object): ...@@ -324,7 +324,7 @@ class PureOp(object):
node = self.make_node(*inputs, **kwargs) node = self.make_node(*inputs, **kwargs)
self.add_tag_trace(node) self.add_tag_trace(node)
if config.compute_test_value: if config.compute_test_value != 'False':
# avoid circular import # avoid circular import
from theano.compile.sharedvalue import SharedVariable from theano.compile.sharedvalue import SharedVariable
run_perform = True run_perform = True
......
...@@ -11,7 +11,7 @@ class TestComputeTestValue(unittest.TestCase): ...@@ -11,7 +11,7 @@ class TestComputeTestValue(unittest.TestCase):
def test_variable_only(self): def test_variable_only(self):
orig_compute_test_value = theano.config.compute_test_value orig_compute_test_value = theano.config.compute_test_value
try: try:
theano.config.compute_test_value = True theano.config.compute_test_value = 'True'
x = T.matrix('x') x = T.matrix('x')
x.tag.test_value = numpy.random.rand(3,4) x.tag.test_value = numpy.random.rand(3,4)
...@@ -40,7 +40,7 @@ class TestComputeTestValue(unittest.TestCase): ...@@ -40,7 +40,7 @@ class TestComputeTestValue(unittest.TestCase):
y.tag.test_value = numpy.random.rand(4,5) y.tag.test_value = numpy.random.rand(4,5)
# should skip computation of test value # should skip computation of test value
theano.config.compute_test_value = False theano.config.compute_test_value = 'False'
z = T.dot(x,y) z = T.dot(x,y)
assert not hasattr(z.tag, 'test_value') assert not hasattr(z.tag, 'test_value')
...@@ -55,7 +55,7 @@ class TestComputeTestValue(unittest.TestCase): ...@@ -55,7 +55,7 @@ class TestComputeTestValue(unittest.TestCase):
def test_string_var(self): def test_string_var(self):
orig_compute_test_value = theano.config.compute_test_value orig_compute_test_value = theano.config.compute_test_value
try: try:
theano.config.compute_test_value = True theano.config.compute_test_value = 'True'
x = T.matrix('x') x = T.matrix('x')
x.tag.test_value = numpy.random.rand(3,4) x.tag.test_value = numpy.random.rand(3,4)
...@@ -84,7 +84,7 @@ class TestComputeTestValue(unittest.TestCase): ...@@ -84,7 +84,7 @@ class TestComputeTestValue(unittest.TestCase):
def test_shared(self): def test_shared(self):
orig_compute_test_value = theano.config.compute_test_value orig_compute_test_value = theano.config.compute_test_value
try: try:
theano.config.compute_test_value = True theano.config.compute_test_value = 'True'
x = T.matrix('x') x = T.matrix('x')
x.tag.test_value = numpy.random.rand(3,4) x.tag.test_value = numpy.random.rand(3,4)
...@@ -105,7 +105,7 @@ class TestComputeTestValue(unittest.TestCase): ...@@ -105,7 +105,7 @@ class TestComputeTestValue(unittest.TestCase):
def test_ndarray(self): def test_ndarray(self):
orig_compute_test_value = theano.config.compute_test_value orig_compute_test_value = theano.config.compute_test_value
try: try:
theano.config.compute_test_value = True theano.config.compute_test_value = 'True'
x = numpy.random.rand(2,3) x = numpy.random.rand(2,3)
y = theano.shared(numpy.random.rand(3,6), 'y') y = theano.shared(numpy.random.rand(3,6), 'y')
...@@ -125,7 +125,7 @@ class TestComputeTestValue(unittest.TestCase): ...@@ -125,7 +125,7 @@ class TestComputeTestValue(unittest.TestCase):
def test_constant(self): def test_constant(self):
orig_compute_test_value = theano.config.compute_test_value orig_compute_test_value = theano.config.compute_test_value
try: try:
theano.config.compute_test_value = True theano.config.compute_test_value = 'True'
x = T.constant(numpy.random.rand(2,3)) x = T.constant(numpy.random.rand(2,3))
y = theano.shared(numpy.random.rand(3,6), 'y') y = theano.shared(numpy.random.rand(3,6), 'y')
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论