提交 7bbbe727 authored 作者: Arnaud Bergeron's avatar Arnaud Bergeron

Fix some things in code review.

上级 fae01b90
#section support_code_apply #section support_code_apply
int conv_desc(PyArrayObject *filt_shp, int APPLY_SPECIFIC(conv_desc)(PyArrayObject *filt_shp,
cudnnConvolutionDescriptor_t *desc) { cudnnConvolutionDescriptor_t *desc) {
cudnnStatus_t err; cudnnStatus_t err;
int pad[3] = {PAD_0, PAD_1, PAD_2}; int pad[3] = {PAD_0, PAD_1, PAD_2};
int strides[3] = {SUB_0, SUB_1, SUB_2}; int strides[3] = {SUB_0, SUB_1, SUB_2};
......
import os import os
import numpy import numpy
import warnings
import theano import theano
from theano import Op, Apply, tensor, config, Variable from theano import Op, Apply, tensor, config, Variable
...@@ -226,7 +227,7 @@ class GpuDnnConvDesc(COp): ...@@ -226,7 +227,7 @@ class GpuDnnConvDesc(COp):
return False return False
def __init__(self, border_mode, subsample=(1, 1), conv_mode='conv'): def __init__(self, border_mode, subsample=(1, 1), conv_mode='conv'):
COp.__init__(self, ["conv_desc.c"], "conv_desc") COp.__init__(self, ["conv_desc.c"], "APPLY_SPECIFIC(conv_desc)")
if isinstance(border_mode, int): if isinstance(border_mode, int):
border_mode = (border_mode,) * len(subsample) border_mode = (border_mode,) * len(subsample)
...@@ -764,6 +765,11 @@ def dnn_conv(img, kerns, border_mode='valid', subsample=(1, 1), ...@@ -764,6 +765,11 @@ def dnn_conv(img, kerns, border_mode='valid', subsample=(1, 1),
work with this Op. work with this Op.
""" """
if workmem is not None:
if algo is not None:
raise ValueError("You can't use both algo and workmem")
warnings.warn("workmem is deprecated, use algo instead", stacklevel=2)
algo = workmem
fgraph = getattr(img, 'fgraph', None) or getattr(kerns, 'fgraph', None) fgraph = getattr(img, 'fgraph', None) or getattr(kerns, 'fgraph', None)
if (border_mode == 'valid' and subsample == (1, 1) and if (border_mode == 'valid' and subsample == (1, 1) and
direction_hint == 'bprop weights'): direction_hint == 'bprop weights'):
...@@ -821,18 +827,18 @@ class GpuDnnPoolDesc(Op): ...@@ -821,18 +827,18 @@ class GpuDnnPoolDesc(Op):
This Op builds a pooling descriptor for use in the other This Op builds a pooling descriptor for use in the other
pooling operations. pooling operations.
`ws`, `stride` and `pad` must have the same length.
Parameters Parameters
---------- ----------
ws ws : tuple
Windows size. Window size.
stride stride : tuple
(dx, dy). (dx, dy) or (dx, dy, dz).
mode : {'max', 'average_inc_pad', 'average_exc_pad'} mode : {'max', 'average_inc_pad', 'average_exc_pad'}
The old deprecated name 'average' corresponds to 'average_inc_pad'. The old deprecated name 'average' corresponds to 'average_inc_pad'.
pad pad : tuple
(padX, padY) padding information. (padX, padY) or (padX, padY, padZ)
padX is the size of the left and right borders,
padY is the size of the top and bottom borders.
""" """
...@@ -1044,19 +1050,20 @@ def dnn_pool(img, ws, stride=(1, 1), mode='max', pad=(0, 0)): ...@@ -1044,19 +1050,20 @@ def dnn_pool(img, ws, stride=(1, 1), mode='max', pad=(0, 0)):
The memory layout to use is 'bc01', that is 'batch', 'channel', The memory layout to use is 'bc01', that is 'batch', 'channel',
'first dim', 'second dim' in that order. 'first dim', 'second dim' in that order.
`ws`, `stride` and `pad` must have the same length.
Parameters Parameters
---------- ----------
img img
Images to do the pooling over. Images to do the pooling over.
ws ws : tuple
Subsampling window size. Subsampling window size.
stride stride : tuple
Subsampling stride (default: (1, 1)). Subsampling stride (default: (1, 1)).
mode : {'max', 'average_inc_pad', 'average_exc_pad'} mode : {'max', 'average_inc_pad', 'average_exc_pad'}
pad pad : tuple
(padX, padY) padding information. (padX, padY) or (padX, padY, padZ)
padX is the size of the left and right borders, default: (0, 0)
padY is the size of the top and bottom borders.
.. warning:: The cuDNN library only works with GPU that have a compute .. warning:: The cuDNN library only works with GPU that have a compute
capability of 3.0 or higer. This means that older GPU will not capability of 3.0 or higer. This means that older GPU will not
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论