提交 6b422288 authored 作者: Nicolas Ballas's avatar Nicolas Ballas 提交者: Pascal Lamblin

fix forward conv

上级 a9f805bd
...@@ -103,8 +103,10 @@ def conv2d(img, ...@@ -103,8 +103,10 @@ def conv2d(img,
if (filter_flip): if (filter_flip):
filters = filters[:, :, ::-1, ::-1] filters = filters[:, :, ::-1, ::-1]
### FIXME input shape/kernel shape ### FIXME input shape/kernel shape
conv_op = Conv2d(imshp=input_shape, kshp=filter_shape, bsize=batch_size, conv_op = AbstractConv2d(imshp=input_shape, kshp=filter_shape,
border_mode=border_mode, subsample=subsample) bsize=batch_size,
border_mode=border_mode,
subsample=subsample)
return conv_op(img, filters) return conv_op(img, filters)
...@@ -195,7 +197,7 @@ class AbstractConv2d(BaseAbstractConv2d): ...@@ -195,7 +197,7 @@ class AbstractConv2d(BaseAbstractConv2d):
kern.broadcastable[0], kern.broadcastable[0],
False, False] False, False]
output = img.type.__class__(dtype=img.type.dtype, output = img.type.__class__(dtype=img.type.dtype,
broadcastable=broadcastable) broadcastable=broadcastable)()
return Apply(self, [img, kern], [output]) return Apply(self, [img, kern], [output])
def perform(self, node, inp, out_): def perform(self, node, inp, out_):
...@@ -420,8 +422,8 @@ def local_conv2d_cudnn(node): ...@@ -420,8 +422,8 @@ def local_conv2d_cudnn(node):
inp1, inp2 = node.inputs inp1, inp2 = node.inputs
shape = None shape = None
if not isinstance(inp1, CudaNdarrayType) or \ if not isinstance(inp1.type, CudaNdarrayType) or \
isinstance(inp2, CudaNdarrayType): not isinstance(inp2.type, CudaNdarrayType):
return None return None
if not dnn_available(): if not dnn_available():
return None return None
...@@ -430,28 +432,30 @@ def local_conv2d_cudnn(node): ...@@ -430,28 +432,30 @@ def local_conv2d_cudnn(node):
border_mode=node.op.border_mode, border_mode=node.op.border_mode,
subsample=node.op.subsample, subsample=node.op.subsample,
direction_hint='forward') direction_hint='forward')
return rval return [rval]
if (isinstance(node.op, AbstractConv2d_gradWeights)): if (isinstance(node.op, AbstractConv2d_gradWeights)):
rval = dnn_conv(inp1.dimshuffle(1, 0, 2, 3), inp2, rval = dnn_conv(inp1.dimshuffle(1, 0, 2, 3), inp2,
border_mode=node.op.border_mode, border_mode=node.op.border_mode,
subsample=node.op.subsample, subsample=node.op.subsample,
direction_hint='bprop weights') direction_hint='bprop weights')
return rval return [rval]
if (isinstance(node.op, AbstractConv2d_gradInputs)): if (isinstance(node.op, AbstractConv2d_gradInputs)):
rval = dnn_conv(inp1, inp2, rval = dnn_conv(inp1, inp2,
border_mode=node.op.border_mode, border_mode=node.op.border_mode,
subsample=node.op.subsample, subsample=node.op.subsample,
direction_hint='bprop inputs') direction_hint='bprop inputs')
return rval return [rval]
register_specialize_device(local_conv2d_cudnn) register_specialize_device(local_conv2d_cudnn)
@local_optimizer([AbstractConv2d]) @local_optimizer([AbstractConv2d])
def local_conv2d_corrmm(convop, inputs): def local_conv2d_corrmm(node):
img, kern = node.inputs img, kern = node.inputs
if not isinstance(img, CudaNdarrayType) or \ if not isinstance(img.type, CudaNdarrayType) or \
isinstance(kern, CudaNdarrayType): not isinstance(kern.type, CudaNdarrayType):
print 'here', img.type, kern.type
print isinstance(img, CudaNdarrayType), isinstance(kern, CudaNdarrayType)
return None return None
if node.op.border_mode in ['full', 'valid']: if node.op.border_mode in ['full', 'valid']:
...@@ -500,33 +504,33 @@ def local_conv2d_corrmm(convop, inputs): ...@@ -500,33 +504,33 @@ def local_conv2d_corrmm(convop, inputs):
# call GpuCorrMM_gradInputs # call GpuCorrMM_gradInputs
rval = GpuCorrMM_gradInputs('valid', subsample)( rval = GpuCorrMM_gradInputs('valid', subsample)(
gpu_contiguous(kern), gpu_contiguous(img)) gpu_contiguous(kern), gpu_contiguous(img))
return rval return [rval]
register_specialize_device(local_conv2d_corrmm) register_specialize_device(local_conv2d_corrmm)
@local_optimizer([AbstractConv2d_gradWeights])
def local_conv2d_gradweight_corrmm(node):
    """Swap an AbstractConv2d_gradWeights node for GpuCorrMM_gradWeights.

    The node's inputs are unpacked as ``(img, topgrad, shape)``.  The
    rewrite is only attempted when the types of both ``img`` and
    ``topgrad`` are ``CudaNdarrayType``.

    Returns ``None`` when the rewrite does not apply, otherwise a
    one-element list holding the replacement output.
    """
    img, topgrad, shape = node.inputs

    # Both tensor inputs must have a CudaNdarrayType type.
    both_on_gpu = (isinstance(img.type, CudaNdarrayType) and
                   isinstance(topgrad.type, CudaNdarrayType))
    if not both_on_gpu:
        return None

    # Reuse the border mode and subsampling of the op being replaced.
    corr_op = GpuCorrMM_gradWeights(border_mode=node.op.border_mode,
                                    subsample=node.op.subsample)
    # The optimizer result is returned as a list of replacement outputs.
    return [corr_op(gpu_contiguous(img), gpu_contiguous(topgrad), shape)]
register_specialize_device(local_conv2d_gradweight_corrmm) register_specialize_device(local_conv2d_gradweight_corrmm)
@local_optimizer([AbstractConv2d_gradInputs])
def local_conv2d_gradinputs_corrmm(node):
    """Swap an AbstractConv2d_gradInputs node for GpuCorrMM_gradInputs.

    The node's inputs are unpacked as ``(kern, topgrad, shape)``.  The
    rewrite is only attempted when the types of both ``kern`` and
    ``topgrad`` are ``CudaNdarrayType``.

    Returns ``None`` when the rewrite does not apply, otherwise a
    one-element list holding the replacement output.
    """
    kern, topgrad, shape = node.inputs

    # The guard must test `kern`: this function never binds a name `img`,
    # so checking `img.type` raised NameError on every invocation.
    if not isinstance(kern.type, CudaNdarrayType) or \
            not isinstance(topgrad.type, CudaNdarrayType):
        return None

    # Reuse the border mode and subsampling of the op being replaced.
    rval = GpuCorrMM_gradInputs(border_mode=node.op.border_mode,
                                subsample=node.op.subsample)(
        gpu_contiguous(kern), gpu_contiguous(topgrad), shape)
    # The optimizer result is returned as a list of replacement outputs.
    return [rval]
register_specialize_device(local_conv2d_gradinputs_corrmm) register_specialize_device(local_conv2d_gradinputs_corrmm)
......
...@@ -47,8 +47,9 @@ class TestConv2d(unittest.TestCase): ...@@ -47,8 +47,9 @@ class TestConv2d(unittest.TestCase):
print res_ref.shape, res.shape print res_ref.shape, res.shape
utt.assert_allclose(res_ref, res) utt.assert_allclose(res_ref, res)
if verify_grad: if verify_grad:
utt.verify_grad(conv.Conv2d(border_mode="valid", utt.verify_grad(conv.AbstractConv2d(border_mode="valid",
subsample=subsample), [inputs_val, filters_val]) subsample=subsample),
[inputs_val, filters_val])
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论