提交 1cfa8fbc authored 作者: James Bergstra's avatar James Bergstra

ConvOp - modified grad set unrolling params to None, and let __init__ set them.

上级 9e4f9e45
......@@ -10,13 +10,14 @@ See especially conv2d().
__docformat__ = "restructuredtext en"
import sys, logging
import numpy
import theano
import theano.tensor as tensor
from theano import gof, Op, tensor, config
from theano.tensor.tsor_apply import Apply
from theano.gof.python25 import any
import logging
_logger=logging.getLogger("theano.signal.conv")
def _debug(*msg):
_logger.debug(' '.join([ str(x) for x in msg]))
......@@ -508,8 +509,8 @@ class ConvOp(Op):
self.unroll_kern=self.speed_unroll_batch_kern[time_unroll_batch_kern_idx][1]
self.unroll_patch = False
print "AUTO FIND VERSION OF C_CODE OF CONV OP"
print self.unroll_batch, self.unroll_kern, self.unroll_patch, self.bsize, self.nkern, time_unroll_patch, time_unroll_batch_kern
print >> sys.stderr, "AUTO FIND VERSION OF C_CODE OF CONV OP"
print >> sys.stderr, self.unroll_batch, self.unroll_kern, self.unroll_patch, self.bsize, self.nkern, time_unroll_patch, time_unroll_batch_kern
self._rehash()
......@@ -772,30 +773,41 @@ class ConvOp(Op):
filters = filters[:,:,::-1,::-1] #flip them
#find good value for the unroll
if all_shape and un_b!=0 and bsize%un_b!=0:
if bsize<un_b:
un_b = bsize
else:
un_b = 1
_warn("OPTIMISATION WARNING: in ConvOp.grad() we can't determine "\
"a good unroll value for the batch. Maybe you can optimize this!")
if 0: #find good value for the unroll
if all_shape and un_b!=0 and bsize%un_b!=0:
if bsize<un_b:
un_b = bsize
else:
un_b = 1
_warn("OPTIMISATION WARNING: in ConvOp.grad() we can't determine "\
"a good unroll value for the batch. Maybe you can optimize this!")
if all_shape and un_k!=0 and nkern%un_k!=0:
if nkern<un_k:
un_k = nkern
else:
un_k = 1
_warn("OPTIMISATION WARNING: in ConvOp.grad() we can't determine "\
"a good unroll value for the kernel. Maybe you can optimize this!")
dw = ConvOp(imshp, kshp, nkern, bsize, 1,1, output_mode='valid',
unroll_batch=un_b, unroll_kern=un_k, unroll_patch=un_p,
imshp_logical=imshp_logical,
kshp_logical=kshp_logical,
kshp_logical_top_aligned=kshp_logical_top_aligned,
version=self.version,
verbose=self.verbose)
else: # let __init__ choose c params be chosen automatically from shapes
dw = ConvOp(imshp, kshp, nkern, bsize, 1,1, output_mode='valid',
unroll_batch=None, unroll_kern=None, unroll_patch=None,
imshp_logical=imshp_logical,
kshp_logical=kshp_logical,
kshp_logical_top_aligned=kshp_logical_top_aligned,
version=self.version,
verbose=self.verbose)
if all_shape and un_k!=0 and nkern%un_k!=0:
if nkern<un_k:
un_k = nkern
else:
un_k = 1
_warn("OPTIMISATION WARNING: in ConvOp.grad() we can't determine "\
"a good unroll value for the kernel. Maybe you can optimize this!")
dw = ConvOp(imshp, kshp, nkern, bsize, 1,1, output_mode='valid',
unroll_batch=un_b, unroll_kern=un_k, unroll_patch=un_p,
imshp_logical=imshp_logical,
kshp_logical=kshp_logical,
kshp_logical_top_aligned=kshp_logical_top_aligned,
version=self.version,
verbose=self.verbose)
if hasattr(self,'flops'):
dw.set_flops()
......@@ -826,13 +838,22 @@ class ConvOp(Op):
imshp = (self.nkern, self.outshp[0], self.outshp[1])
imshp_logical=(self.nkern, self.fulloutshp[0], self.fulloutshp[1])
din = ConvOp(imshp, self.kshp, nkern, self.bsize,
1,1, output_mode=mode,
unroll_batch=un_b, unroll_kern=un_k, unroll_patch=un_p,
imshp_logical=imshp_logical,
kshp_logical=None,
version=-1,#we we change the mode, we don't forward the version.
verbose=self.verbose)
if 0: # hard-code c generation parameters
din = ConvOp(imshp, self.kshp, nkern, self.bsize,
1,1, output_mode=mode,
unroll_batch=un_b, unroll_kern=un_k, unroll_patch=un_p,
imshp_logical=imshp_logical,
kshp_logical=None,
version=-1,#we we change the mode, we don't forward the version.
verbose=self.verbose)
else: # let __init__ figure out the unrolling / patch sizes
din = ConvOp(imshp, self.kshp, nkern, self.bsize,
1,1, output_mode=mode,
unroll_batch=None, unroll_kern=None, unroll_patch=None,
imshp_logical=imshp_logical,
kshp_logical=None,
version=-1,#we we change the mode, we don't forward the version.
verbose=self.verbose)
if hasattr(self,'flops'):
din.set_flops()
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论