提交 3bbc2f2c authored 作者: Frederic's avatar Frederic

small test speed up

上级 eb4e7715
......@@ -5,24 +5,26 @@ from theano import tensor, function
import unittest
# this tests other ops to ensure they keep the dimensions of their
# inputs correctly
class TestKeepDims(unittest.TestCase):
def makeKeepDims_local(self, x, y, axis):
x = tensor.as_tensor_variable(x)
y = tensor.as_tensor_variable(y)
if axis is None:
axis = numpy.arange(x.ndim)
newaxis = range(x.ndim)
elif isinstance(axis, int):
axis = [axis]
if axis < 0:
newaxis = [axis + x.type.ndim]
else:
newaxis = [axis]
else:
newaxis = []
for a in axis:
if a < 0:
a += x.type.ndim
newaxis.append(a)
i = 0
newaxis = []
for a in axis:
if a < 0:
a += x.type.ndim
newaxis.append(a)
new_dims = []
for j, _ in enumerate(x.shape):
if j in newaxis:
......
Markdown 格式
0%
您添加了 0 到此讨论。请谨慎行事。
请先完成此评论的编辑!
注册 或者 后发表评论