提交 3a6cb4bf authored 作者: Arnaud Bergeron's avatar Arnaud Bergeron

Add a test of a broken case when non-contiguous.

上级 62412499
......@@ -1532,6 +1532,23 @@ def test_dnn_reduction_strides():
yield dnn_reduction_strides, (2, 3, 2), (1, 0, 2), slice(None, None, None)
yield dnn_reduction_strides, (2, 3, 2), (0, 1, 2), slice(None, None, -1)
def test_dnn_reduction_error():
nLoops = 5
vec = np.arange(0, 10, dtype=np.float32)
slow_output = np.zeros((5, 10))
for i in range(nLoops):
slow_output[i, :] = 2.0 * vec
slow_output = np.sum(slow_output.transpose(), axis=1)
vecT = T.vector(dtype=theano.config.floatX)
outputT = T.alloc(2.0 * vecT, 5, vecT.shape[0])
outputSummedT = T.sum(T.transpose(outputT), axis=1)
f3 = theano.function(inputs=[vecT], outputs=outputSummedT)
output = f3(vec)
utt.assert_allclose(slow_output, output)
def dnn_maxargmax(nd, idtype, axis):
inp = T.TensorType(idtype, (False,) * nd)()
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论