提交 b7bcc91a authored 作者: Frédéric Bastien's avatar Frédéric Bastien

Merge pull request #1882 from abergeron/cuda_fftconv

Cuda fftconv assert
...@@ -453,6 +453,9 @@ def conv2d_fft(input, filters, image_shape=None, filter_shape=None, ...@@ -453,6 +453,9 @@ def conv2d_fft(input, filters, image_shape=None, filter_shape=None,
else: else:
raise ValueError('invalid mode') raise ValueError('invalid mode')
input_padded = T.opt.Assert("in conv2d_fft: width is not even")(
input_padded, T.eq(o1 % 2, 0))
# reshape for FFT # reshape for FFT
input_flat = input_padded.reshape((b * ic, o0, o1)) input_flat = input_padded.reshape((b * ic, o0, o1))
filters_flat = filters_padded.reshape((oc * ic, o0, o1)) filters_flat = filters_padded.reshape((oc * ic, o0, o1))
......
...@@ -1363,6 +1363,9 @@ class Assert(T.Op): ...@@ -1363,6 +1363,9 @@ class Assert(T.Op):
""" """
view_map = {0: [0]} view_map = {0: [0]}
def __init__(self, msg="Theano Assert failed!"):
self.msg = msg
def make_node(self, value, *conds): def make_node(self, value, *conds):
cond = [T.as_tensor_variable(c) for c in conds] cond = [T.as_tensor_variable(c) for c in conds]
assert numpy.all([c.type.ndim == 0 for c in cond]) assert numpy.all([c.type.ndim == 0 for c in cond])
...@@ -1375,13 +1378,13 @@ class Assert(T.Op): ...@@ -1375,13 +1378,13 @@ class Assert(T.Op):
out, = out_ out, = out_
v = inputs[0] v = inputs[0]
out[0] = v out[0] = v
assert numpy.all(inputs[1:]) assert numpy.all(inputs[1:]), self.msg
def __eq__(self, other): def __eq__(self, other):
return type(self) == type(other) return type(self) == type(other) and self.msg == other.msg
def __hash__(self): def __hash__(self):
return hash(type(self)) return hash(type(self)) ^ hash(self.msg)
def grad(self, input, output_gradients): def grad(self, input, output_gradients):
return output_gradients return output_gradients
...@@ -1391,12 +1394,13 @@ class Assert(T.Op): ...@@ -1391,12 +1394,13 @@ class Assert(T.Op):
out = onames[0] out = onames[0]
check = [] check = []
fail = sub['fail'] fail = sub['fail']
msg = self.msg.replace('"', '\\"').replace('\n', '\\n')
for idx in xrange(len(inames) - 1): for idx in xrange(len(inames) - 1):
i = inames[idx + 1] i = inames[idx + 1]
dtype = node.inputs[idx + 1].dtype dtype = node.inputs[idx + 1].dtype
check.append('if(!((npy_%(dtype)s*)PyArray_DATA(%(i)s))[0])' check.append('if(!((npy_%(dtype)s*)PyArray_DATA(%(i)s))[0])'
'{PyErr_SetString(PyExc_AssertionError,"Theano' '{PyErr_SetString(PyExc_AssertionError,"%(msg)s");'
' Assert failed!");%(fail)s}' % locals()) '%(fail)s}' % locals())
check = "\n".join(check) check = "\n".join(check)
return """ return """
%(check)s %(check)s
...@@ -1405,7 +1409,7 @@ class Assert(T.Op): ...@@ -1405,7 +1409,7 @@ class Assert(T.Op):
""" % locals() """ % locals()
def c_code_cache_version(self): def c_code_cache_version(self):
return (0, 1) return (1, 0)
def infer_shape(self, node, input_shapes): def infer_shape(self, node, input_shapes):
return [input_shapes[0]] return [input_shapes[0]]
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论