提交 d0e68f0c authored 作者: Gijs van Tulder's avatar Gijs van Tulder

Use shorter dimension references in conv3d2d.

上级 d6c3505c
......@@ -221,18 +221,11 @@ def conv3d(signals, filters,
else:
_filters_shape_5d = filters_shape
_signals_shape_4d = (
_signals_shape_5d[0] * _signals_shape_5d[1],
_signals_shape_5d[2],
_signals_shape_5d[3],
_signals_shape_5d[4],
)
_filters_shape_4d = (
_filters_shape_5d[0] * _filters_shape_5d[1],
_filters_shape_5d[2],
_filters_shape_5d[3],
_filters_shape_5d[4],
)
Ns, Ts, C, Hs, Ws = _signals_shape_5d
Nf, Tf, C, Hf, Wf = _filters_shape_5d
_signals_shape_4d = (Ns * Ts, C, Hs, Ws)
_filters_shape_4d = (Nf * Tf, C, Hf, Wf)
if border_mode[1] != border_mode[2]:
raise NotImplementedError('height and width bordermodes must match')
......@@ -251,25 +244,10 @@ def conv3d(signals, filters,
border_mode=border_mode[1]) # ignoring border_mode[2]
# reshape the output to restore its original size
# shape = Ns, Ts, Nf, Tf, W-Wf+1, H-Hf+1
if border_mode[1] == 'valid':
out_tmp = out_4d.reshape((
_signals_shape_5d[0], # Ns
_signals_shape_5d[1], # Ts
_filters_shape_5d[0], # Nf
_filters_shape_5d[1], # Tf
_signals_shape_5d[3] - _filters_shape_5d[3] + 1,
_signals_shape_5d[4] - _filters_shape_5d[4] + 1,
))
out_tmp = out_4d.reshape((Ns, Ts, Nf, Tf, Hs - Hf + 1, Ws - Wf + 1))
elif border_mode[1] == 'full':
out_tmp = out_4d.reshape((
_signals_shape_5d[0], # Ns
_signals_shape_5d[1], # Ts
_filters_shape_5d[0], # Nf
_filters_shape_5d[1], # Tf
_signals_shape_5d[3] + _filters_shape_5d[3] - 1,
_signals_shape_5d[4] + _filters_shape_5d[4] - 1,
))
out_tmp = out_4d.reshape((Ns, Ts, Nf, Tf, Hs + Hf - 1, Ws + Wf - 1))
elif border_mode[1] == 'same':
raise NotImplementedError()
else:
......@@ -281,38 +259,19 @@ def conv3d(signals, filters,
if _filters_shape_5d[1] != 1:
out_5d = diagonal_subtensor(out_tmp, 1, 3).sum(axis=3)
else: # for Tf==1, no sum along Tf, the Ts-axis of the output is unchanged!
out_5d = out_tmp.reshape((
_signals_shape_5d[0],
_signals_shape_5d[1],
_filters_shape_5d[0],
_signals_shape_5d[3] - _filters_shape_5d[3] + 1,
_signals_shape_5d[4] - _filters_shape_5d[4] + 1,
))
out_5d = out_tmp.reshape((Ns, Ts, Nf, Hs - Hf + 1, Ws - Wf + 1))
elif border_mode[0] == 'full':
if _filters_shape_5d[1] != 1:
# pad out_tmp with zeros to have full convolution
out_tmp_padded = tensor.zeros(dtype=out_tmp.dtype, shape=(
_signals_shape_5d[0], # Ns
_signals_shape_5d[1] + 2 * (_filters_shape_5d[1] - 1), # Ts
_filters_shape_5d[0], # Nf
_filters_shape_5d[1], # Tf
_signals_shape_5d[3] + _filters_shape_5d[3] - 1,
_signals_shape_5d[4] + _filters_shape_5d[4] - 1,
Ns, Ts + 2 * (Tf - 1), Nf, Tf, Hs + Hf - 1, Ws + Wf - 1
))
out_tmp_padded = tensor.set_subtensor(
out_tmp_padded[:,
(_filters_shape_5d[1] - 1):(_signals_shape_5d[1] + _filters_shape_5d[1] - 1),
:, :, :, :],
out_tmp_padded[:, (Tf - 1):(Ts + Tf - 1), :, :, :, :],
out_tmp)
out_5d = diagonal_subtensor(out_tmp_padded, 1, 3).sum(axis=3)
else: # for tf==1, no sum along tf, the ts-axis of the output is unchanged!
out_5d = out_tmp.reshape((
_signals_shape_5d[0],
_signals_shape_5d[1],
_filters_shape_5d[0],
_signals_shape_5d[3] + _filters_shape_5d[3] - 1,
_signals_shape_5d[4] + _filters_shape_5d[4] - 1,
))
out_5d = out_tmp.reshape((Ns, Ts, Nf, Hs + Hf - 1, Ws + Wf - 1))
elif border_mode[0] == 'same':
raise NotImplementedError('sequence border mode', border_mode[0])
else:
......
Markdown 格式
0%
您将 0 人添加到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 登录 后发表评论