提交 5ae86f7b authored 作者: Marius Killinger's avatar Marius Killinger 提交者: Frederic

fixed one typo and added test case to cover the changed case

上级 4c973fcf
......@@ -254,7 +254,7 @@ def conv3d(signals, filters,
# now sum out along the Tf to get the output
# but we have to sum on a diagonal through the Tf and Ts submatrix.
if border_mode[0] == 'valid':
if filters_shape[1]!=1:
if _filters_shape_5d[1]!=1:
out_5d = diagonal_subtensor(out_tmp, 1, 3).sum(axis=3)
else: # for Tf==1, no sum along Tf, the Ts-axis of the output is unchanged!
out_5d = out_tmp.reshape((
......
......@@ -119,3 +119,48 @@ def test_conv3d(mode=mode_without_gpu, shared=theano.tensor._shared):
signals = numpy.random.rand(Ns, Ts, C, Hs, Ws).astype('float32')
filters = numpy.random.rand(Nf, Tf, C, Hf, Wf).astype('float32')
utt.verify_grad(conv3d, [signals, filters], eps=1e-1)
### Additional Test that covers the case of patched implementation for filter with Tf=1
Ns, Ts, C, Hs, Ws = 3, 10, 3, 32, 32
Nf, Tf, C, Hf, Wf = 32, 1 , 3, 5 , 5
signals = numpy.arange(Ns*Ts*C*Hs*Ws).reshape(Ns, Ts, C, Hs, Ws).astype('float32')
filters = numpy.arange(Nf*Tf*C*Hf*Wf).reshape(Nf, Tf, C, Hf, Wf).astype('float32')
t0 = time.time()
pyres = pyconv3d(signals, filters)
print(time.time() - t0)
s_signals = shared(signals)
s_filters = shared(filters)
s_output = shared(signals*0)
out = conv3d(s_signals, s_filters,
signals_shape=signals.shape,
filters_shape=filters.shape)
newconv3d = theano.function([], [],
updates={s_output: out},
mode=mode)
t0 = time.time()
newconv3d()
print(time.time() - t0)
utt.assert_allclose(pyres, s_output.get_value(borrow=True))
gsignals, gfilters = theano.grad(out.sum(), [s_signals, s_filters])
gnewconv3d = theano.function([], [],
updates=[(s_filters, gfilters),
(s_signals, gsignals)],
mode=mode,
name='grad')
t0 = time.time()
gnewconv3d()
print('grad', time.time() - t0)
Ns, Ts, C, Hs, Ws = 3, 3, 3, 5, 5
Nf, Tf, C, Hf, Wf = 4, 1, 3, 2, 2
signals = numpy.random.rand(Ns, Ts, C, Hs, Ws).astype('float32')
filters = numpy.random.rand(Nf, Tf, C, Hf, Wf).astype('float32')
utt.verify_grad(conv3d, [signals, filters], eps=1e-1)
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论