提交 acad83a3 authored 作者: Arnaud Bergeron's avatar Arnaud Bergeron

Make sure not to use the descriptor op if not available.

上级 171aa1d7
...@@ -1406,11 +1406,12 @@ class GpuDnnPool(DnnBase): ...@@ -1406,11 +1406,12 @@ class GpuDnnPool(DnnBase):
or desc.type.ctype != 'cudnnPoolingDescriptor_t': or desc.type.ctype != 'cudnnPoolingDescriptor_t':
raise TypeError('desc must be cudnnPoolingDescriptor_t') raise TypeError('desc must be cudnnPoolingDescriptor_t')
dop = desc.owner.op if desc.owner is not None:
e_ndim = dop.get_ndim() + 2 # 4 or 5 dop = desc.owner.op
e_ndim = dop.get_ndim() + 2 # 4 or 5
if img.type.ndim != e_ndim: if img.type.ndim != e_ndim:
raise TypeError('img must be %dD tensor' % e_ndim) raise TypeError('img must be %dD tensor' % e_ndim)
return Apply(self, [img, desc], [img.type()]) return Apply(self, [img, desc], [img.type()])
...@@ -1572,19 +1573,21 @@ class GpuDnnPoolGrad(DnnBase): ...@@ -1572,19 +1573,21 @@ class GpuDnnPoolGrad(DnnBase):
or desc.type.ctype != 'cudnnPoolingDescriptor_t': or desc.type.ctype != 'cudnnPoolingDescriptor_t':
raise TypeError('desc must be cudnnPoolingDescriptor_t') raise TypeError('desc must be cudnnPoolingDescriptor_t')
nd = desc.owner.op.get_ndim() + 2 # 4 or 5
inp = as_cuda_ndarray_variable(inp) inp = as_cuda_ndarray_variable(inp)
if inp.type.ndim != nd:
raise TypeError('inp must be %dD tensor' % (nd,))
inp_grad = as_cuda_ndarray_variable(inp_grad) inp_grad = as_cuda_ndarray_variable(inp_grad)
if inp_grad.type.ndim != nd:
raise TypeError('inp_grad must be %dD tensor' % (nd,))
out = as_cuda_ndarray_variable(out) out = as_cuda_ndarray_variable(out)
if out.type.ndim != nd:
raise TypeError('out must be %dD tensor' % (nd,)) if desc.owner is not None:
nd = desc.owner.op.get_ndim() + 2 # 4 or 5
if inp.type.ndim != nd:
raise TypeError('inp must be %dD tensor' % (nd,))
if inp_grad.type.ndim != nd:
raise TypeError('inp_grad must be %dD tensor' % (nd,))
if out.type.ndim != nd:
raise TypeError('out must be %dD tensor' % (nd,))
return Apply(self, [inp, out, inp_grad, desc], return Apply(self, [inp, out, inp_grad, desc],
[inp.type()]) [inp.type()])
......
...@@ -945,10 +945,11 @@ class GpuDnnPool(DnnBase): ...@@ -945,10 +945,11 @@ class GpuDnnPool(DnnBase):
def make_node(self, img, desc): def make_node(self, img, desc):
img = as_gpuarray_variable(img) img = as_gpuarray_variable(img)
e_ndim = desc.owner.op.get_ndim() + 2 if desc.owner is not None:
e_ndim = desc.owner.op.get_ndim() + 2
if img.type.ndim != e_ndim: if img.type.ndim != e_ndim:
raise TypeError('img must be %dD tensor' % (e_ndim,)) raise TypeError('img must be %dD tensor' % (e_ndim,))
if (not isinstance(desc.type, CDataType) or if (not isinstance(desc.type, CDataType) or
desc.type.ctype != 'cudnnPoolingDescriptor_t'): desc.type.ctype != 'cudnnPoolingDescriptor_t'):
...@@ -1010,19 +1011,21 @@ class GpuDnnPoolGrad(DnnBase): ...@@ -1010,19 +1011,21 @@ class GpuDnnPoolGrad(DnnBase):
"APPLY_SPECIFIC(dnn_pool_grad)") "APPLY_SPECIFIC(dnn_pool_grad)")
def make_node(self, inp, out, out_grad, desc): def make_node(self, inp, out, out_grad, desc):
nd = desc.owner.op.get_ndim() + 2
inp = as_gpuarray_variable(inp) inp = as_gpuarray_variable(inp)
if inp.type.ndim != nd:
raise TypeError('inp must be %dD tensor' % (nd,))
out_grad = as_gpuarray_variable(out_grad) out_grad = as_gpuarray_variable(out_grad)
if out_grad.type.ndim != nd:
raise TypeError('out_grad must be %dD tensor' % (nd,))
out = as_gpuarray_variable(out) out = as_gpuarray_variable(out)
if out.type.ndim != nd:
raise TypeError('out must be %dD tensor' % (nd,)) if desc.owner is not None:
nd = desc.owner.op.get_ndim() + 2
if inp.type.ndim != nd:
raise TypeError('inp must be %dD tensor' % (nd,))
if out_grad.type.ndim != nd:
raise TypeError('out_grad must be %dD tensor' % (nd,))
if out.type.ndim != nd:
raise TypeError('out must be %dD tensor' % (nd,))
if (not isinstance(desc.type, CDataType) or if (not isinstance(desc.type, CDataType) or
desc.type.ctype != 'cudnnPoolingDescriptor_t'): desc.type.ctype != 'cudnnPoolingDescriptor_t'):
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论