提交 ff8f57f1 authored 作者: abergeron's avatar abergeron

Merge pull request #1858 from nouiz/faster_test

Faster test
from nose.plugins.attrib import attr
import unittest
from nose.plugins.attrib import attr
import numpy
import theano
from theano import tensor, function
import unittest
# this tests other ops to ensure they keep the dimensions of their
# inputs correctly
class TestKeepDims(unittest.TestCase):
def makeKeepDims_local(self, x, y, axis):
x = tensor.as_tensor_variable(x)
y = tensor.as_tensor_variable(y)
if axis is None:
axis = numpy.arange(x.ndim)
newaxis = range(x.ndim)
elif isinstance(axis, int):
axis = [axis]
if axis < 0:
newaxis = [axis + x.type.ndim]
else:
newaxis = [axis]
else:
newaxis = []
for a in axis:
if a < 0:
a += x.type.ndim
newaxis.append(a)
i = 0
newaxis = []
for a in axis:
if a < 0:
a += x.type.ndim
newaxis.append(a)
new_dims = []
for j, _ in enumerate(x.shape):
if j in newaxis:
......@@ -38,6 +41,9 @@ class TestKeepDims(unittest.TestCase):
x = tensor.dtensor3()
a = numpy.random.rand(3, 2, 4)
# We don't need to test all opt and C code, as this is tested
# by the ops tests.
mode = theano.compile.Mode(optimizer="fast_compile", linker="py")
# 'max_and_argmax' has two outputs and can be specified with either
# a single or every axis:
......@@ -46,19 +52,23 @@ class TestKeepDims(unittest.TestCase):
[-2, -3, 2]]:
op = tensor.max_and_argmax
keep_param = function([x], op(x, axis=axis, keepdims=True)[0])
keep_synth = function([x], self.makeKeepDims_local(x,
op(x, axis=axis, keepdims=False)[0], axis))
assert numpy.allclose(keep_param(a), keep_synth(a))
assert keep_param(a).shape == keep_synth(a).shape
keep_param = function([x], op(x, axis=axis, keepdims=True)[1])
keep_synth = function([x], self.makeKeepDims_local(x,
op(x, axis=axis, keepdims=False)[1], axis))
assert numpy.allclose(keep_param(a), keep_synth(a))
assert keep_param(a).shape == keep_synth(a).shape
f = function([x], [op(x, axis=axis, keepdims=True)[0],
self.makeKeepDims_local(
x, op(x, axis=axis, keepdims=False)[0],
axis)],
mode=mode)
ans1, ans2 = f(a)
assert numpy.allclose(ans1, ans2)
assert ans1.shape == ans2.shape
f = function([x], [op(x, axis=axis, keepdims=True)[1],
self.makeKeepDims_local(
x, op(x, axis=axis, keepdims=False)[1],
axis)],
mode=mode)
ans1, ans2 = f(a)
assert numpy.allclose(ans1, ans2)
assert ans1.shape == ans2.shape
# the following ops can be specified with either a single axis or every
# axis:
......@@ -66,38 +76,30 @@ class TestKeepDims(unittest.TestCase):
for axis in [0, 1, 2, [0], [1], [2], None, [0, 1, 2],
[-1], [-2], [-3], [-1, -2, -3], [0, -2, 2]]:
keep_param = function([x], op(x, axis=axis, keepdims=True))
keep_synth = function([x], self.makeKeepDims_local(x,
op(x, axis=axis, keepdims=False), axis))
assert numpy.allclose(keep_param(a), keep_synth(a))
assert keep_param(a).shape == keep_synth(a).shape
keep_param = function([x], op(x, axis=None, keepdims=True))
keep_synth = function([x], self.makeKeepDims_local(x,
op(x, axis=None, keepdims=False), None))
assert numpy.allclose(keep_param(a), keep_synth(a))
assert keep_param(a).shape == keep_synth(a).shape
f = function([x], [op(x, axis=axis, keepdims=True),
self.makeKeepDims_local(
x, op(x, axis=axis, keepdims=False),
axis)],
mode=mode)
ans1, ans2 = f(a)
assert numpy.allclose(ans1, ans2)
assert ans1.shape == ans2.shape
# the following ops can be specified with a freely specified axis
# parameter
for op in ([tensor.sum, tensor.prod, tensor.mean, tensor.var,
tensor.std, tensor.all, tensor.any,
tensor.max, tensor.min]):
for axis in [0, 1, 2, [0], [1], [2], [0, 1], [1, 2], [0, 1, 2],
for axis in [0, 1, 2, [0], [1], [2], None,
[0, 1], [1, 2], [0, 1, 2],
[-1], [-2], [-3], [-1, -2], [-1, -2, -3], [0, -2, 2]]:
keep_param = function([x], op(x, axis=axis, keepdims=True))
keep_synth = function([x], self.makeKeepDims_local(x,
op(x, axis=axis, keepdims=False), axis))
assert numpy.allclose(keep_param(a), keep_synth(a))
assert keep_param(a).shape == keep_synth(a).shape
keep_param = function([x], op(x, axis=None, keepdims=True))
keep_synth = function([x], self.makeKeepDims_local(x,
op(x, axis=None, keepdims=False), None))
f = function([x], [op(x, axis=axis, keepdims=True),
self.makeKeepDims_local(
x, op(x, axis=axis, keepdims=False),
axis)],
mode=mode)
assert numpy.allclose(keep_param(a), keep_synth(a))
assert keep_param(a).shape == keep_synth(a).shape
ans1, ans2 = f(a)
assert numpy.allclose(ans1, ans2)
assert ans1.shape == ans2.shape
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论