提交 acd4a90f authored 作者: Frédéric Bastien's avatar Frédéric Bastien 提交者: GitHub

Merge pull request #5643 from nouiz/assert_shape_default

Change the default of conv.assert_shape flags to speed up compilation time
...@@ -695,10 +695,9 @@ import theano and print the config variable, as in: ...@@ -695,10 +695,9 @@ import theano and print the config variable, as in:
.. attribute:: config.conv.assert_shape .. attribute:: config.conv.assert_shape
If False, AbstractConv* ops won't add assert that verify that If True, AbstractConv* ops will verify that user-provided
the user provided shapes are also the one at run time. shapes match the runtime shapes (debugging option,
may slow down compilation)
This can speed up compilation time and/or execution time.
.. attribute:: config.dnn.conv.workmem .. attribute:: config.dnn.conv.workmem
......
...@@ -128,9 +128,10 @@ AddConfigVar( ...@@ -128,9 +128,10 @@ AddConfigVar(
AddConfigVar( AddConfigVar(
'conv.assert_shape', 'conv.assert_shape',
"If False, AbstractConv* ops won't add assert that verify that" "If True, AbstractConv* ops will verify that user-provided"
" the user provided shapes are also the one at run time", " shapes match the runtime shapes (debugging option,"
BoolParam(True), " may slow down compilation)",
BoolParam(False),
in_c_key=False) in_c_key=False)
AddConfigVar( AddConfigVar(
......
...@@ -98,15 +98,17 @@ class change_flags(object): ...@@ -98,15 +98,17 @@ class change_flags(object):
Useful during tests. Useful during tests.
""" """
def __init__(self, **kwargs): def __init__(self, args=(), **kwargs):
confs = dict() confs = dict()
for k in kwargs: args = dict(args)
args.update(kwargs)
for k in args:
l = [v for v in theano.configparser._config_var_list l = [v for v in theano.configparser._config_var_list
if v.fullname == k] if v.fullname == k]
assert len(l) == 1 assert len(l) == 1
confs[k] = l[0] confs[k] = l[0]
self.confs = confs self.confs = confs
self.new_vals = kwargs self.new_vals = args
def __call__(self, f): def __call__(self, f):
@wraps(f) @wraps(f)
......
...@@ -217,7 +217,8 @@ def op_lifter(OP, cuda_only=False): ...@@ -217,7 +217,8 @@ def op_lifter(OP, cuda_only=False):
if (not replace or if (not replace or
(cuda_only and (cuda_only and
get_context(context_name).kind != b'cuda') or get_context(context_name).kind != b'cuda') or
any(["complex" in i.dtype for i in node.inputs])): any(["complex" in getattr(i, 'dtype', "")
for i in node.inputs])):
return False return False
# tag the inputs with the context in case # tag the inputs with the context in case
......
...@@ -7,6 +7,7 @@ from nose.tools import assert_raises, assert_true ...@@ -7,6 +7,7 @@ from nose.tools import assert_raises, assert_true
import theano import theano
from theano import tensor from theano import tensor
from theano.configparser import change_flags
from theano.gof.opt import check_stack_trace from theano.gof.opt import check_stack_trace
from theano.tests import unittest_tools as utt from theano.tests import unittest_tools as utt
from theano.tensor.nnet import (corr, corr3d, conv2d_transpose, from theano.tensor.nnet import (corr, corr3d, conv2d_transpose,
...@@ -229,6 +230,7 @@ class TestAssertConvShape(unittest.TestCase): ...@@ -229,6 +230,7 @@ class TestAssertConvShape(unittest.TestCase):
class TestAssertShape(unittest.TestCase): class TestAssertShape(unittest.TestCase):
@change_flags([("conv.assert_shape", True)])
def test_basic(self): def test_basic(self):
x = tensor.tensor4() x = tensor.tensor4()
s1 = tensor.iscalar() s1 = tensor.iscalar()
...@@ -244,6 +246,7 @@ class TestAssertShape(unittest.TestCase): ...@@ -244,6 +246,7 @@ class TestAssertShape(unittest.TestCase):
assert_raises(AssertionError, f, v, 0, 7) assert_raises(AssertionError, f, v, 0, 7)
assert_raises(AssertionError, f, v, 7, 7) assert_raises(AssertionError, f, v, 7, 7)
@change_flags([("conv.assert_shape", True)])
def test_shape_check_conv2d(self): def test_shape_check_conv2d(self):
input = tensor.tensor4() input = tensor.tensor4()
filters = tensor.tensor4() filters = tensor.tensor4()
...@@ -261,6 +264,7 @@ class TestAssertShape(unittest.TestCase): ...@@ -261,6 +264,7 @@ class TestAssertShape(unittest.TestCase):
numpy.zeros((3, 5, 7, 11), dtype='float32'), numpy.zeros((3, 5, 7, 11), dtype='float32'),
numpy.zeros((7, 5, 2, 2), dtype='float32')) numpy.zeros((7, 5, 2, 2), dtype='float32'))
@change_flags([("conv.assert_shape", True)])
def test_shape_check_conv3d(self): def test_shape_check_conv3d(self):
input = tensor.tensor5() input = tensor.tensor5()
filters = tensor.tensor5() filters = tensor.tensor5()
...@@ -278,6 +282,7 @@ class TestAssertShape(unittest.TestCase): ...@@ -278,6 +282,7 @@ class TestAssertShape(unittest.TestCase):
numpy.zeros((3, 5, 7, 11, 13), dtype='float32'), numpy.zeros((3, 5, 7, 11, 13), dtype='float32'),
numpy.zeros((7, 5, 2, 2, 2), dtype='float32')) numpy.zeros((7, 5, 2, 2, 2), dtype='float32'))
@change_flags([("conv.assert_shape", True)])
def test_shape_check_conv2d_grad_wrt_inputs(self): def test_shape_check_conv2d_grad_wrt_inputs(self):
output_grad = tensor.tensor4() output_grad = tensor.tensor4()
filters = tensor.tensor4() filters = tensor.tensor4()
...@@ -291,6 +296,7 @@ class TestAssertShape(unittest.TestCase): ...@@ -291,6 +296,7 @@ class TestAssertShape(unittest.TestCase):
numpy.zeros((3, 6, 5, 9), dtype='float32'), numpy.zeros((3, 6, 5, 9), dtype='float32'),
numpy.zeros((7, 6, 3, 3), dtype='float32')) numpy.zeros((7, 6, 3, 3), dtype='float32'))
@change_flags([("conv.assert_shape", True)])
def test_shape_check_conv3d_grad_wrt_inputs(self): def test_shape_check_conv3d_grad_wrt_inputs(self):
output_grad = tensor.tensor5() output_grad = tensor.tensor5()
filters = tensor.tensor5() filters = tensor.tensor5()
...@@ -304,6 +310,7 @@ class TestAssertShape(unittest.TestCase): ...@@ -304,6 +310,7 @@ class TestAssertShape(unittest.TestCase):
numpy.zeros((3, 6, 5, 9, 11), dtype='float32'), numpy.zeros((3, 6, 5, 9, 11), dtype='float32'),
numpy.zeros((7, 6, 3, 3, 3), dtype='float32')) numpy.zeros((7, 6, 3, 3, 3), dtype='float32'))
@change_flags([("conv.assert_shape", True)])
def test_shape_check_conv2d_grad_wrt_weights(self): def test_shape_check_conv2d_grad_wrt_weights(self):
input = tensor.tensor4() input = tensor.tensor4()
output_grad = tensor.tensor4() output_grad = tensor.tensor4()
...@@ -317,6 +324,7 @@ class TestAssertShape(unittest.TestCase): ...@@ -317,6 +324,7 @@ class TestAssertShape(unittest.TestCase):
numpy.zeros((3, 6, 7, 11), dtype='float32'), numpy.zeros((3, 6, 7, 11), dtype='float32'),
numpy.zeros((3, 7, 5, 9), dtype='float32')) numpy.zeros((3, 7, 5, 9), dtype='float32'))
@change_flags([("conv.assert_shape", True)])
def test_shape_check_conv3d_grad_wrt_weights(self): def test_shape_check_conv3d_grad_wrt_weights(self):
input = tensor.tensor5() input = tensor.tensor5()
output_grad = tensor.tensor5() output_grad = tensor.tensor5()
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论