提交 5fe9e8d5 authored 作者: Frederic Bastien's avatar Frederic Bastien

1 more assert and better error message.

上级 6a3595f2
...@@ -90,7 +90,8 @@ class ConvOp(Op): ...@@ -90,7 +90,8 @@ class ConvOp(Op):
outdim = kerns.ndim outdim = kerns.ndim
print inputs.type.dtype, kerns.type.dtype print inputs.type.dtype, kerns.type.dtype
if inputs.type.dtype != kerns.type.dtype: if inputs.type.dtype != kerns.type.dtype:
raise Exception("The image and the kernel must have the same type.") raise Exception("The image and the kernel must have the same type."
"inputs(%s), kerns(%s)"%(inputs.dtype, kerns.dtype))
output = tensor.tensor(dtype=inputs.type.dtype, output = tensor.tensor(dtype=inputs.type.dtype,
broadcastable=[False]*outdim, broadcastable=[False]*outdim,
name="ConvOp_Output"); name="ConvOp_Output");
...@@ -135,6 +136,7 @@ class ConvOp(Op): ...@@ -135,6 +136,7 @@ class ConvOp(Op):
####### Determine gradient on kernels ######## ####### Determine gradient on kernels ########
if inputs.ndim == 3: if inputs.ndim == 3:
inputs = tensor.shape_padleft(inputs,1) inputs = tensor.shape_padleft(inputs,1)
assert inputs.ndim==4 and kerns.ndim==4
newin = tensor.DimShuffle(inputs.broadcastable, (1,0,2,3))(inputs) newin = tensor.DimShuffle(inputs.broadcastable, (1,0,2,3))(inputs)
newgz = tensor.DimShuffle(gz.broadcastable, (1,0,2,3))(gz) newgz = tensor.DimShuffle(gz.broadcastable, (1,0,2,3))(gz)
...@@ -313,7 +315,11 @@ if(%(filtersflipped)s->nd==3){ ...@@ -313,7 +315,11 @@ if(%(filtersflipped)s->nd==3){
kerns_dim[1]=%(filtersflipped)s->dimensions[1]; kerns_dim[1]=%(filtersflipped)s->dimensions[1];
kerns_dim[0]=%(filtersflipped)s->dimensions[0]; kerns_dim[0]=%(filtersflipped)s->dimensions[0];
}else{ }else{
PyErr_SetString(PyExc_ValueError, "kernel don't have a good shape"); std:stringstream temp;
temp << "nddim="<<%(filtersflipped)s->nd;
std::string param = temp.str();
PyErr_SetString(PyExc_ValueError,
("kernel don't have a good shape. " + param).c_str());
%(fail)s; %(fail)s;
} }
...@@ -519,7 +525,11 @@ if(%(filtersflipped)s->nd==3){ ...@@ -519,7 +525,11 @@ if(%(filtersflipped)s->nd==3){
kerns_dim[1]=%(filtersflipped)s->dimensions[1]; kerns_dim[1]=%(filtersflipped)s->dimensions[1];
kerns_dim[0]=%(filtersflipped)s->dimensions[0]; kerns_dim[0]=%(filtersflipped)s->dimensions[0];
}else{ }else{
PyErr_SetString(PyExc_ValueError, "kernel don't have a good shape"); std:stringstream temp;
temp << "nddim="<<%(filtersflipped)s->nd;
std::string param = temp.str();
PyErr_SetString(PyExc_ValueError,
("kernel don't have a good shape. " + param).c_str());
%(fail)s; %(fail)s;
} }
if (NKERN != kerns_dim[0]) if (NKERN != kerns_dim[0])
...@@ -761,7 +771,11 @@ if(%(filtersflipped)s->nd==3){ ...@@ -761,7 +771,11 @@ if(%(filtersflipped)s->nd==3){
kerns_dim[1]=%(filtersflipped)s->dimensions[1]; kerns_dim[1]=%(filtersflipped)s->dimensions[1];
kerns_dim[0]=%(filtersflipped)s->dimensions[0]; kerns_dim[0]=%(filtersflipped)s->dimensions[0];
}else{ }else{
PyErr_SetString(PyExc_ValueError, "kernel don't have a good shape"); std:stringstream temp;
temp << "nddim="<<%(filtersflipped)s->nd;
std::string param = temp.str();
PyErr_SetString(PyExc_ValueError,
("kernel don't have a good shape. " + param).c_str());
%(fail)s; %(fail)s;
} }
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论