提交 3bbc2f2c authored 作者: Frederic's avatar Frederic

small test speed up

上级 eb4e7715
...@@ -5,24 +5,26 @@ from theano import tensor, function ...@@ -5,24 +5,26 @@ from theano import tensor, function
import unittest import unittest
# this tests other ops to ensure they keep the dimensions of their # this tests other ops to ensure they keep the dimensions of their
# inputs correctly # inputs correctly
class TestKeepDims(unittest.TestCase): class TestKeepDims(unittest.TestCase):
def makeKeepDims_local(self, x, y, axis): def makeKeepDims_local(self, x, y, axis):
x = tensor.as_tensor_variable(x)
y = tensor.as_tensor_variable(y)
if axis is None: if axis is None:
axis = numpy.arange(x.ndim) newaxis = range(x.ndim)
elif isinstance(axis, int): elif isinstance(axis, int):
axis = [axis] if axis < 0:
newaxis = [axis + x.type.ndim]
else:
newaxis = [axis]
else:
newaxis = []
for a in axis:
if a < 0:
a += x.type.ndim
newaxis.append(a)
i = 0 i = 0
newaxis = []
for a in axis:
if a < 0:
a += x.type.ndim
newaxis.append(a)
new_dims = [] new_dims = []
for j, _ in enumerate(x.shape): for j, _ in enumerate(x.shape):
if j in newaxis: if j in newaxis:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论