提交 b7bcc91a authored 作者: Frédéric Bastien's avatar Frédéric Bastien

Merge pull request #1882 from abergeron/cuda_fftconv

Cuda fftconv assert
......@@ -453,6 +453,9 @@ def conv2d_fft(input, filters, image_shape=None, filter_shape=None,
else:
raise ValueError('invalid mode')
input_padded = T.opt.Assert("in conv2d_fft: width is not even")(
input_padded, T.eq(o1 % 2, 0))
# reshape for FFT
input_flat = input_padded.reshape((b * ic, o0, o1))
filters_flat = filters_padded.reshape((oc * ic, o0, o1))
......
......@@ -1363,6 +1363,9 @@ class Assert(T.Op):
"""
view_map = {0: [0]}
def __init__(self, msg="Theano Assert failed!"):
self.msg = msg
def make_node(self, value, *conds):
cond = [T.as_tensor_variable(c) for c in conds]
assert numpy.all([c.type.ndim == 0 for c in cond])
......@@ -1375,13 +1378,13 @@ class Assert(T.Op):
out, = out_
v = inputs[0]
out[0] = v
assert numpy.all(inputs[1:])
assert numpy.all(inputs[1:]), self.msg
def __eq__(self, other):
return type(self) == type(other)
return type(self) == type(other) and self.msg == other.msg
def __hash__(self):
return hash(type(self))
return hash(type(self)) ^ hash(self.msg)
def grad(self, input, output_gradients):
return output_gradients
......@@ -1391,12 +1394,13 @@ class Assert(T.Op):
out = onames[0]
check = []
fail = sub['fail']
msg = self.msg.replace('"', '\\"').replace('\n', '\\n')
for idx in xrange(len(inames) - 1):
i = inames[idx + 1]
dtype = node.inputs[idx + 1].dtype
check.append('if(!((npy_%(dtype)s*)PyArray_DATA(%(i)s))[0])'
'{PyErr_SetString(PyExc_AssertionError,"Theano'
' Assert failed!");%(fail)s}' % locals())
'{PyErr_SetString(PyExc_AssertionError,"%(msg)s");'
'%(fail)s}' % locals())
check = "\n".join(check)
return """
%(check)s
......@@ -1405,7 +1409,7 @@ class Assert(T.Op):
""" % locals()
def c_code_cache_version(self):
return (0, 1)
return (1, 0)
def infer_shape(self, node, input_shapes):
return [input_shapes[0]]
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论