提交 4c70c203 authored 作者: Arnaud Bergeron's avatar Arnaud Bergeron

Make the default for workmen a theano flag.

上级 612d2772
import os import os
import theano import theano
from theano import Apply, gof, tensor from theano import Apply, gof, tensor, config
from theano.scalar import as_scalar from theano.scalar import as_scalar
from theano.gradient import DisconnectedType from theano.gradient import DisconnectedType
from theano.gof import Optimizer, local_optimizer, COp from theano.gof import Optimizer, local_optimizer, COp
from theano.gof.type import CDataType, Generic from theano.gof.type import CDataType, Generic
from theano.compat import PY3 from theano.compat import PY3
from theano.compile.ops import shape_i from theano.compile.ops import shape_i
from theano.configparser import AddConfigVar, EnumStr
from theano.tensor.nnet import SoftmaxGrad from theano.tensor.nnet import SoftmaxGrad
from theano.tensor.basic import ShapeError from theano.tensor.basic import ShapeError
from theano.sandbox.cuda.type import CudaNdarrayType from theano.sandbox.cuda.type import CudaNdarrayType
...@@ -328,6 +329,11 @@ class GpuDnnConvDesc(GpuOp): ...@@ -328,6 +329,11 @@ class GpuDnnConvDesc(GpuOp):
return (2, version()) return (2, version())
AddConfigVar('dnn.conv.workmem',
"Default value for the workmem attribute of cudnn convolutions.",
EnumStr('small', 'none', 'large'),
in_c_key=False)
class GpuDnnConv(DnnBase, COp): class GpuDnnConv(DnnBase, COp):
""" """
The forward convolution. The forward convolution.
...@@ -338,12 +344,14 @@ class GpuDnnConv(DnnBase, COp): ...@@ -338,12 +344,14 @@ class GpuDnnConv(DnnBase, COp):
""" """
__props__ = ('workmem',) __props__ = ('workmem',)
def __init__(self, workmem='small'): def __init__(self, workmem=None):
""" """
:param workmem: either 'none', 'small' or 'large'. Default is 'small'. :param workmem: either 'none', 'small' or 'large'. Default is 'small'.
""" """
COp.__init__(self, ["dnn_base.c", "dnn_conv_base.c", "dnn_fwd.c"], COp.__init__(self, ["dnn_base.c", "dnn_conv_base.c", "dnn_fwd.c"],
"APPLY_SPECIFIC(conv_fwd)") "APPLY_SPECIFIC(conv_fwd)")
if workmem is None:
workmem = config.dnn.conv.workmem
self.workmem = workmem self.workmem = workmem
assert self.workmem in ['none', 'small', 'large'] assert self.workmem in ['none', 'small', 'large']
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论