提交 bb964aaf authored 作者: Nicolas Ballas's avatar Nicolas Ballas

Add optimization that verifies that no AbstractConv are present in the graph

上级 4022b2ae
......@@ -4,6 +4,7 @@ Optimizations addressing the ops in nnet root directory
import theano
from theano import compile, gof
from theano.compile import optdb
from theano.gof import local_optimizer
from theano.tensor.nnet.corr import (
......@@ -20,6 +21,7 @@ from theano.tensor.nnet.abstract_conv import get_conv_output_shape
from theano.tensor.opt import (copy_stack_trace,
register_specialize_device)
from theano.tensor import TensorType
from theano.tensor import opt
# Cpu implementation
from theano.tensor.nnet.conv import conv2d, ConvOp
......@@ -379,3 +381,30 @@ conv_groupopt.register('local_conv2d_gradweight_cpu',
conv_groupopt.register('local_conv2d_gradinputs_cpu',
local_conv2d_gradinputs_cpu, 40,
'fast_compile', 'fast_run')
# Verify that no AbstractConv are present in the graph
@local_optimizer([AbstractConv2d,
AbstractConv2d_gradWeights,
AbstractConv2d_gradInputs])
def local_abstractconv_check(node):
if isinstance(node.op, AbstractConv2d):
raise AssertionError(
'AbstractConv2d theano optimization failed. '
'Did you exclude both "conv_dnn" and "conv_gemm" from '
'the optimizer? Is cudnn available and does the GPU support it?')
elif isinstance(node.op, AbstractConv2d_gradWeights):
raise AssertionError(
'AbstractConv2d_gradWeights theano optimization failed. '
'Did you exclude both "conv_dnn" and "conv_gemm" from '
'the optimizer? Is cudnn available and does the GPU support it?')
elif isinstance(node.op, AbstractConv2d_gradInputs):
raise AssertionError(
'AbstractConv2d_gradInputs theano optimization failed. '
'Did you exclude both "conv_dnn" and "conv_gemm" from '
'the optimizer? Is cudnn available and does the GPU support it?')
optdb.register('AbstracConvCheck',
opt.in2out(local_abstractconv_check,
name="AbstractConvCheck"),
48.7, 'fast_compile', 'fast_run')
......@@ -312,7 +312,7 @@ class TestCpuConv2d(BaseTestConv2d):
mode=mode, provide_shape=provide_shape,
border_mode=b, filter_flip=flip, target_op=ConvOp)
else:
self.assertRaises(NotImplementedError,
self.assertRaises(AssertionError,
self.run_fwd,
inputs_shape=i,
filters_shape=f,
......@@ -331,7 +331,7 @@ class TestCpuConv2d(BaseTestConv2d):
filter_flip=flip,
target_op=(ConvOp, ConvGrad3D))
else:
self.assertRaises(NotImplementedError,
self.assertRaises(AssertionError,
self.run_gradweight,
inputs_shape=i,
filters_shape=f,
......@@ -351,7 +351,7 @@ class TestCpuConv2d(BaseTestConv2d):
filter_flip=flip,
target_op=(ConvOp, ConvTransp3D))
else:
self.assertRaises(NotImplementedError,
self.assertRaises(AssertionError,
self.run_gradinput,
inputs_shape=i,
filters_shape=f,
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论