提交 6aeec3a3 authored 作者: Frederic's avatar Frederic

add support for not constant shape in conv3d2d.

This is needed for a later commit that add call to utt.verify_grad.
上级 1c8357ce
......@@ -125,9 +125,6 @@ def conv3d(signals, filters,
if isinstance(border_mode, str):
border_mode = (border_mode, border_mode, border_mode)
#TODO: support variables in the shape
if signals_shape is None or filters_shape is None:
raise NotImplementedError('need shapes for now')
_signals_shape_5d = signals.shape if signals_shape is None else signals_shape
_filters_shape_5d = filters.shape if filters_shape is None else filters_shape
......@@ -146,12 +143,18 @@ def conv3d(signals, filters,
if border_mode[1] != border_mode[2]:
raise NotImplementedError('height and width bordermodes must match')
conv2d_signal_shape = _signals_shape_4d
conv2d_filter_shape = _filters_shape_4d
if signals_shape is None:
conv2d_signal_shape = None
if filters_shape is None:
conv2d_filter_shape = None
out_4d = tensor.nnet.conv2d(
signals.reshape(_signals_shape_4d),
filters.reshape(_filters_shape_4d),
image_shape=_signals_shape_4d,
filter_shape=_filters_shape_4d,
image_shape=conv2d_signal_shape,
filter_shape=conv2d_filter_shape,
border_mode = border_mode[1]) # ignoring border_mode[2]
# reshape the output to restore its original size
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论