提交 3b075b19 authored 作者: Arnaud Bergeron's avatar Arnaud Bergeron

Add assert in the conv2d_fft graph.

上级 6aa4a034
...@@ -453,6 +453,9 @@ def conv2d_fft(input, filters, image_shape=None, filter_shape=None, ...@@ -453,6 +453,9 @@ def conv2d_fft(input, filters, image_shape=None, filter_shape=None,
else: else:
raise ValueError('invalid mode') raise ValueError('invalid mode')
input_padded = T.opt.Assert("in conv2d_fft: width is not even")(
input_padded, o1 % 2 == 0)
# reshape for FFT # reshape for FFT
input_flat = input_padded.reshape((b * ic, o0, o1)) input_flat = input_padded.reshape((b * ic, o0, o1))
filters_flat = filters_padded.reshape((oc * ic, o0, o1)) filters_flat = filters_padded.reshape((oc * ic, o0, o1))
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论