提交 02d11f7d authored 作者: Frédéric Bastien's avatar Frédéric Bastien 提交者: GitHub

Merge pull request #5635 from nouiz/comp_opt

[ENH] for rc3, Graph clean up, faster compilation, conv.assert_shape Theano flags, add unsafe in useless opt
......@@ -693,6 +693,13 @@ import theano and print the config variable, as in:
If ``'False'``, do not use cuDNN or check if it is available.
.. attribute:: config.conv.assert_shape
If False, AbstractConv* ops won't add asserts that verify that
the user-provided shapes are also the ones seen at run time.
This can speed up compilation time and/or execution time.
.. attribute:: config.dnn.conv.workmem
Deprecated, use :attr:`config.dnn.conv.algo_fwd`.
......
......@@ -126,6 +126,13 @@ AddConfigVar(
BoolParam(False, allow_override=False),
in_c_key=False)
AddConfigVar(
'conv.assert_shape',
"If False, AbstractConv* ops won't add assert that verify that"
" the user provided shapes are also the one at run time",
BoolParam(True),
in_c_key=False)
AddConfigVar(
'print_global_stats',
"Print some global statistics (time spent) at the end",
......
......@@ -607,11 +607,20 @@ class MergeFeature(object):
# properly.
# The clients should at least contain `node` itself!
if node.inputs:
assert len(node.inputs[0].clients) > 0
assert (node, 0) in node.inputs[0].clients
# Take the smallest clients list. Some ops like elemwise
# have an optimization that puts constants as the first inputs.
# As constants generally have more clients than other kinds of nodes,
# always using inputs[0] would make us look at more nodes.
# Always picking the smaller clients list between inputs 0
# and -1 speeds up optimization.
if len(node.inputs[0].clients) < len(node.inputs[-1].clients):
clients = node.inputs[0].clients
else:
clients = node.inputs[-1].clients
assert len(clients) > 0
merge_candidates = [c for (c, i) in node.inputs[0].clients
if c in self.nodes_seen]
merge_candidates = [c for c, i in clients if c in self.nodes_seen]
# Put all clients of Assert inputs (if exist) into merge_candidates
# TODO: Deactivated for now as this cause cycle in the graph.
......
......@@ -270,13 +270,14 @@ class Kernel(object):
def get_ctype(dtype):
    """Return the C type name used in kernel signatures for `dtype`.

    Parameters
    ----------
    dtype : gpuarray.GpuArray class, gpuarray.SIZE, gpuarray.SSIZE,
        np.dtype, or anything ``np.dtype()`` accepts (e.g. 'float32').

    Returns
    -------
    str
        The C type spelling: ``"gpudata *"`` for GpuArray arguments,
        ``"size_t"``/``"ssize_t"`` for the gpuarray size sentinels, and
        ``'npy_' + dtype.name`` for numeric dtypes.
    """
    if dtype is gpuarray.GpuArray:
        return "gpudata *"
    elif dtype == gpuarray.SIZE:
        return "size_t"
    elif dtype == gpuarray.SSIZE:
        return "ssize_t"
    else:
        # np.dtype() is idempotent on np.dtype inputs, so there is no
        # need to special-case arguments that are already dtypes.
        dtype = np.dtype(dtype)
    return 'npy_' + dtype.name
......
......@@ -1353,4 +1353,6 @@ def forced_replace(out, x, y):
elif graph.owner:
q.extendleft(graph.owner.inputs)
if len(to_replace) == 0:
return out
return clone(out, replace=to_replace)
......@@ -493,7 +493,7 @@ def assert_shape(x, expected_shape, msg='Unexpected shape.'):
will return `x` directly.
"""
if expected_shape is None:
if expected_shape is None or not theano.config.conv.assert_shape:
return x
shape = x.shape
tests = []
......@@ -1680,19 +1680,20 @@ class AbstractConv2d(AbstractConv):
def grad(self, inp, grads):
bottom, weights = inp
top, = grads
# Don't add the assert again, as it was already added in the forward.
d_bottom = AbstractConv2d_gradInputs(self.imshp, self.kshp,
self.border_mode,
self.subsample,
self.filter_flip,
self.filter_dilation)(
weights, top, bottom.shape[-2:])
weights, top, bottom.shape[-2:], add_assert_shape=False)
d_weights = AbstractConv2d_gradWeights(self.imshp, self.kshp,
self.border_mode,
self.subsample,
self.filter_flip,
self.filter_dilation)(
bottom, top, weights.shape[-2:])
bottom, top, weights.shape[-2:], add_assert_shape=False)
# Make sure that the broadcastable pattern of the inputs is used
# for the gradients, even if the grad opts are not able to infer
......@@ -1781,7 +1782,7 @@ class AbstractConv_gradWeights(BaseAbstractConv):
filter_dilation=filter_dilation)
# Update shape/height_width
def make_node(self, img, topgrad, shape):
def make_node(self, img, topgrad, shape, add_assert_shape=True):
# Make sure both inputs are Variables with the same Type
if not isinstance(img, theano.Variable):
img = as_tensor_variable(img)
......@@ -1795,10 +1796,10 @@ class AbstractConv_gradWeights(BaseAbstractConv):
raise TypeError('img must be %dD tensor' % (2 + self.convdim))
if topgrad.type.ndim != 2 + self.convdim:
raise TypeError('topgrad must be %dD tensor' % (2 + self.convdim))
img = assert_shape(img, self.imshp,
'AbstractConv_gradWeights shape mismatch: shape of '
'image does not match given imshp.')
if add_assert_shape:
img = assert_shape(img, self.imshp,
'AbstractConv_gradWeights shape mismatch: shape of '
'image does not match given imshp.')
shape = as_tensor_variable(shape)
broadcastable = [topgrad.broadcastable[1],
......@@ -2020,7 +2021,7 @@ class AbstractConv_gradInputs(BaseAbstractConv):
filter_dilation=filter_dilation)
# Update shape/height_width
def make_node(self, kern, topgrad, shape):
def make_node(self, kern, topgrad, shape, add_assert_shape=True):
# Make sure both inputs are Variables with the same Type
if not isinstance(kern, theano.Variable):
kern = as_tensor_variable(kern)
......@@ -2035,9 +2036,10 @@ class AbstractConv_gradInputs(BaseAbstractConv):
if topgrad.type.ndim != 2 + self.convdim:
raise TypeError('topgrad must be %dD tensor' % (2 + self.convdim))
kern = assert_shape(kern, self.kshp,
'AbstractConv_gradInputs shape mismatch: shape of '
'filters does not match given kshp.')
if add_assert_shape:
kern = assert_shape(kern, self.kshp,
'AbstractConv_gradInputs shape mismatch: shape of '
'filters does not match given kshp.')
shape = as_tensor_variable(shape)
broadcastable = [topgrad.type.broadcastable[0],
......@@ -2158,8 +2160,9 @@ class AbstractConv2d_gradInputs(AbstractConv_gradInputs):
self.border_mode,
self.subsample,
self.filter_flip,
self.filter_dilation)(bottom, top,
weights.shape[-2:])
self.filter_dilation)(
bottom, top,
weights.shape[-2:])
d_top = AbstractConv2d(self.imshp, self.kshp,
self.border_mode,
self.subsample,
......
......@@ -2418,6 +2418,10 @@ compile.optdb['specialize'].register('local_remove_all_assert',
local_remove_all_assert,
'unsafe',
use_db_name_as_tag=False)
# Register local_remove_all_assert in the 'useless' optimization db as
# well (it is also registered under 'specialize' just above), so the
# Assert-removal is available there too.  It keeps the 'unsafe' tag and
# use_db_name_as_tag=False, so it is still only enabled when the user
# explicitly opts into 'unsafe' optimizations.
compile.optdb['useless'].register('local_remove_all_assert',
local_remove_all_assert,
'unsafe',
use_db_name_as_tag=False)
#######################
# Constant Canonicalization
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论